Home || Architecture || Video Search || Visual Search || Scripts || Applications || Important Messages || OGL || Src

mainTrainModel.cpp

Go to the documentation of this file.
00001 #include "Basis/CmdOptions.h"
00002 #include "Util/Mpi/SelectWorkLoadInPlace.h"
00003 #include "Core/ApplicationFactory.h"
00004 
00005 // since we are not using libraries:
00006 #include "Link/Svm/LinkSvm.cpp"
00007 
00008 
00009 namespace Impala
00010 {
00011 namespace Application
00012 {
00013 
00014 
00015 // truncating dataset / annotation
00016 // code should go to dataFactory?
00017 /*
00018 Table::QuidTable* learnQuids = learnFeatures->GetQuidTable();
00019 Table::AnnotationTable* annotationtemp = new Table::AnnotationTable;
00020 Read(annotationtemp, filename, db);
00021 Table::CriterionElement1InSet<Table::AnnotationTable> c(learnQuids);
00022 Table::AnnotationTable* annotation = Select(annotationtemp, c);
00023 delete annotationtemp;
00024 
00025 int maxV = options.GetInt("maxVideoId");
00026 annotation->SelectQuidObjectMaxId(maxV);
00027 int maxP = options.GetInt("maxPosPerVideo");
00028 annotation->SelectQuidObjectMaxPositive(maxP);
00029 int maxN = options.GetInt("maxNegPerVideo");
00030 annotation->SelectQuidObjectMaxNegative(maxN);
00031 ILOG_INFO(annotation->GetNrPositive() << " positive and " <<
00032           annotation->GetNrNegative() << " negative annotations");
00033 
00034 Table::QuidTable* positive = annotation->GetPositive();
00035 Table::QuidTable* negative = annotation->GetNegative();
00036 delete learnQuids;
00037 learnQuids = positive;
00038 learnQuids->Append(negative);
00039 delete negative;
00040 */
00041 
00042 
00043 using namespace Core;
00044 using namespace Core::Training;
00045 
00046 int TrainModel(Training::Factory* trainFactory,
00047                DataFactory* dataFactory, bool distKernel)
00048 {
00049     ILOG_VAR(main);
00050     std::vector<String> conceptList = dataFactory->MakeConceptList();
00051     if(!distKernel)
00052         Util::Mpi::SelectWorkLoadInPlace(&conceptList);
00053     for (int i=0 ; i<conceptList.size(); ++i)
00054     {
00055         String concept = conceptList[i];
00056 
00057         if(!dataFactory->CanMakeConceptModel(concept))
00058         {
00059             ILOG_WARNING_NODE("cannot create model for " << concept <<
00060                          "; skipping concept");
00061             continue;
00062         }
00063         ILOG_INFO_NODE("concept " << concept);
00064 
00065         // read params
00066         Util::PropertySet params = 
00067             trainFactory->GetBestFileParams(dataFactory, concept);
00068         ILOG_INFO_NODE(params);
00069         if(params.Size() == 0)
00070         {
00071             ILOG_WARNING_NODE("no best file found for " << concept <<
00072                          "; skipping concept");
00073             continue;
00074         }
00075 
00076         //load annotation
00077         Table::AnnotationTable* annotation = dataFactory->MakeAnnotation(concept);
00078         if(annotation == 0)
00079         {
00080             ILOG_WARNING_NODE("no annotation found for " << concept << 
00081                          "; skipping concept");
00082             continue;
00083         }
00084         annotation->Sort();
00085 
00086         //compute model
00087         TrainDataSrc* src = trainFactory->MakeDataSrc(annotation, dataFactory);
00088         Core::Training::Classifier* svm = trainFactory->MakeClassifier("svm");
00089         svm->Train(&params, src);
00090         dataFactory->WriteConceptModel(concept, svm);
00091 
00092         // score on self
00093         Table::ScoreTable* ranking = svm->Predict(src);
00094         Training::AveragePrecision ap(annotation);
00095         double score = ap.Compute(ranking);
00096         if(score < ap.ComputeReversed(ranking))
00097         {
00098             ILOG_ERROR_NODE("reversed ranking > normal ranking => logic error in svm");
00099         }
00100         delete ranking;
00101         dataFactory->WriteScoreOnSelf(concept, score);
00102         ILOG_INFO_NODE("score on self = " << score);
00103         ILOG_DEBUG("deleting svm");
00104         delete svm;
00105         ILOG_DEBUG_NODE("deleting src");
00106         delete src;
00107         ILOG_DEBUG_NODE("deleting annotation");
00108         delete annotation;
00109         ILOG_DEBUG_NODE("done iteration");
00110     }
00111     return 0;
00112 }
00113 
00114 int
00115 mainTrainModel(int argc, char** argv)
00116 {
00117     ILOG_VAR(main);
00118     Link::Mpi::Init(&argc, &argv);
00119     CmdOptions& options = CmdOptions::GetInstance();
00120     options.Initialise(false, false, true);
00121     options.AddOption(0, "assume-shotid", "bool", "0");
00122     options.AddOption('m', "cache", "megabytes", "500");
00123     options.AddOption(0, "start", "index of concept to start with", "0");
00124     options.AddOption(0, "number", "number of concepts", "-1");
00125     options.AddOption(0, "concept", "name", "");
00126     options.AddOption(0, "kernel", "string: [linear,poly,rbf,sigmoid,precomputed,hist,dist-precomputed]", "rbf");
00127     options.AddOption(0, "maxVideoId", "index", "-1");
00128     options.AddOption(0, "maxPosPerVideo", "number", "-1");
00129     options.AddOption(0, "maxNegPerVideo", "number", "-1");
00130     // prevent dataserver from keeping all ImageArchives open
00131     options.AddOption(0, "imCacheSize", "size", "1");
00132 
00133     if (! options.ParseArgs(argc, argv, "dataSet concepts model featureDef", 4))
00134     {
00135         Link::Mpi::Finalize();
00136         return 1;
00137     }
00138 
00139     ApplicationFactory factory(&options);
00140     Util::PropertySet* properties = factory.MakeClassifierProperties();
00141     Training::Factory* trainFactory = factory.MakeTrainFactory(properties);
00142     DataFactory* dataFactory = factory.MakeDataFactory();
00143     /* when we use a distributed kernel matix we assume that node 0 does
00144        computation while other nodes load the distributed kernel matrix */
00145 
00146     int retval = 0;
00147     bool dist = (options.GetString("kernel") == "dist-precomputed");
00148     if(dist && Link::Mpi::MyId() != 0)
00149     {
00150         dataFactory->ServeDistributedAccess();
00151     }
00152     else
00153     {
00154         // make sure we subscribe, even if we are skipping all concepts;
00155         // otherwise the program will hang
00156         if(options.GetString("kernel") == "dist-precomputed")
00157             dataFactory->GetDistributedAccess();
00158         retval = TrainModel(trainFactory, dataFactory, dist);
00159     }
00160 
00161     ILOG_DEBUG_ONCE("deleting factories and props");
00162     delete dataFactory;
00163     delete trainFactory;
00164     delete properties;
00165     Link::Mpi::Finalize();
00166     return retval;
00167 }
00168 
00169 } // namespace Application
00170 } // namespace Impala
00171 
00172 int
00173 main(int argc, char* argv[])
00174 {
00175     return Impala::Application::mainTrainModel(argc, argv);
00176 }

Generated on Fri Mar 19 09:30:30 2010 for ImpalaSrc by  doxygen 1.5.1