Home || Visual Search || Applications || Architecture || Important Messages || OGL || Src

mainTrainModel.cpp

Go to the documentation of this file.
00001 #include "Persistency/BestFileRepository.h"
00002 #include "Persistency/ScoreFileRepository.h"
00003 #include "Persistency/SvmRepository.h"
00004 #include "Basis/CmdOptions.h"
00005 #include "Util/Mpi/SelectWorkLoadInPlace.h"
00006 #include "Core/Training/Factory.h"
00007 
00008 // since we are not using libraries:
00009 #include "Link/Svm/LinkSvm.cpp"
00010 
00011 
00012 namespace Impala
00013 {
00014 namespace Application
00015 {
00016 
00017 
00018 using namespace Core;
00019 using namespace Core::Training;
00020 using namespace Persistency;
00021 
00022 int
00023 TrainModel(Training::Factory* factory, bool distKernel)
00024 {
00025     ILOG_VAR(Impala.Application.TrainModel);
00026     std::vector<String> conceptList = factory->MakeConceptList();
00027     if (!distKernel)
00028         Util::Mpi::SelectWorkLoadInPlace(&conceptList);
00029     for (int i=0 ; i<conceptList.size(); i++)
00030     {
00031         String concept = conceptList[i];
00032         ModelLocator modelLoc = factory->GetModelLocator();
00033 
00034         if (SvmRepository().Exists(modelLoc))
00035         {
00036             ILOG_WARNING_NODE("cannot create model for " << concept <<
00037                               "; skipping concept");
00038             continue;
00039         }
00040         ILOG_INFO_NODE("concept " << concept);
00041 
00042         // read params
00043         modelLoc.SetConcept(concept);
00044         Util::PropertySet* params = BestFileRepository().Get(modelLoc);
00045         params->Add("probability", 1);
00046         int cache = factory->GetProperties()->GetInt("cache");
00047         params->Add("cache", cache);
00048         ILOG_INFO_NODE(*params);
00049         if (params->Size() == 0)
00050         {
00051             ILOG_WARNING_NODE("no best file found for " << concept <<
00052                               "; skipping concept");
00053             continue;
00054         }
00055 
00056         // load annotation
00057         Table::AnnotationTable* annotation = factory->MakeAnnotation(concept);
00058         if (annotation == 0)
00059         {
00060             ILOG_WARNING_NODE("no annotation found for " << concept << 
00061                               "; skipping concept");
00062             continue;
00063         }
00064         annotation->Sort();
00065 
00066         // compute model
00067         TrainDataSrc* src = factory->MakeDataSrc(annotation);
00068         Classifier* classifier = factory->MakeClassifier("svm");
00069         classifier->Train(params, src);
00070         Svm* svm = static_cast<Svm*>(classifier);
00071         if (svm)
00072             SvmRepository().Add(modelLoc, svm);
00073         else
00074             ILOG_ERROR("Rogue classifier");
00075 
00076         // score on self
00077         Table::ScoreTable* ranking = classifier->Predict(src);
00078         Training::AveragePrecision ap(annotation);
00079         double score = ap.Compute(ranking);
00080         double scoreRev = ap.ComputeReversed(ranking);
00081         if (score < scoreRev)
00082         {
00083             ILOG_ERROR_NODE("reversed ranking score (" << scoreRev <<
00084                             ") > normal ranking score (" << score <<
00085                             ") for " << concept << " => logic error in svm" <<
00086                             " (try setting C=1 in .best file)");
00087         }
00088         delete ranking;
00089         ILOG_INFO_NODE("score on self = " << score);
00090         Util::PropertySet scoreProp;
00091         scoreProp.Add("scoreOnSelf", score);
00092         ScoreFileRepository().Add(modelLoc, &scoreProp);
00093 
00094         delete classifier;
00095         delete src;
00096         delete annotation;
00097         delete params;
00098     }
00099     return 0;
00100 }
00101 
00102 int
00103 mainTrainModel(int argc, char** argv)
00104 {
00105     ILOG_VAR(Impala.Application.mainTrainModel);
00106     Link::Mpi::Init(&argc, &argv);
00107     CmdOptions& options = CmdOptions::GetInstance();
00108     options.Initialise(false, false, true);
00109     options.AddOption(0, "assume-shotid", "bool", "0");
00110     options.AddOption('m', "cache", "megabytes", "500");
00111     options.AddOption(0, "start", "index of concept to start with", "0");
00112     options.AddOption(0, "number", "number of concepts", "-1");
00113     options.AddOption(0, "concept", "name", "");
00114     options.AddOption(0, "kernel", "string: [linear,poly,rbf,sigmoid,precomputed,hist,dist-precomputed]", "rbf");
00115     options.AddOption(0, "featureIndexCat", "name", "");
00116     options.AddOption(0, "maxVideoId", "index", "-1");
00117     options.AddOption(0, "maxPosPerVideo", "number", "-1");
00118     options.AddOption(0, "maxNegPerVideo", "number", "-1");
00119     // prevent dataserver from keeping all ImageArchives open
00120     options.AddOption(0, "imCacheSize", "size", "1");
00121 
00122     if (options.ParseArgs(argc, argv, "dataSet concepts model featureDef", 4))
00123     {
00124         Training::Factory factory(&options, true);
00125         /* when we use a distributed kernel matix we assume that node 0 does
00126            computation while other nodes load the distributed kernel matrix */
00127         bool dist = (options.GetString("kernel") == "dist-precomputed");
00128         if (dist && Link::Mpi::MyId() != 0)
00129         {
00130             factory.ServeDistributedAccess();
00131         }
00132         else
00133         {
00134             // make sure we subscribe, even if we are skipping all concepts;
00135             // otherwise the program will hang
00136             if (options.GetString("kernel") == "dist-precomputed")
00137                 factory.GetDistributedAccess();
00138             TrainModel(&factory, dist);
00139         }
00140     }
00141 
00142     int nrOfErrors = ILOG_ERROR_COUNT;
00143     nrOfErrors = Link::Mpi::ReduceSum(nrOfErrors);
00144     ILOG_INFO_HEADNODE("Root: total nr error = " << nrOfErrors);
00145     Link::Mpi::Finalize();
00146     return nrOfErrors;
00147 }
00148 
00149 } // namespace Application
00150 } // namespace Impala
00151 
00152 int
00153 main(int argc, char* argv[])
00154 {
00155     return Impala::Application::mainTrainModel(argc, argv);
00156 }

Generated on Thu Jan 13 09:03:44 2011 for ImpalaSrc by  doxygen 1.5.1