00001 #include "Persistency/BestFileRepository.h"
00002 #include "Persistency/ScoreFileRepository.h"
00003 #include "Persistency/SvmRepository.h"
00004 #include "Basis/CmdOptions.h"
00005 #include "Util/Mpi/SelectWorkLoadInPlace.h"
00006 #include "Core/Training/Factory.h"
00007
00008
00009 #include "Link/Svm/LinkSvm.cpp"
00010
00011
00012 namespace Impala
00013 {
00014 namespace Application
00015 {
00016
00017
00018 using namespace Core;
00019 using namespace Core::Training;
00020 using namespace Persistency;
00021
00022 int
00023 TrainModel(Training::Factory* factory, bool distKernel)
00024 {
00025 ILOG_VAR(Impala.Application.TrainModel);
00026 std::vector<String> conceptList = factory->MakeConceptList();
00027 if (!distKernel)
00028 Util::Mpi::SelectWorkLoadInPlace(&conceptList);
00029 for (int i=0 ; i<conceptList.size(); i++)
00030 {
00031 String concept = conceptList[i];
00032 ModelLocator modelLoc = factory->GetModelLocator();
00033
00034 if (SvmRepository().Exists(modelLoc))
00035 {
00036 ILOG_WARNING_NODE("cannot create model for " << concept <<
00037 "; skipping concept");
00038 continue;
00039 }
00040 ILOG_INFO_NODE("concept " << concept);
00041
00042
00043 modelLoc.SetConcept(concept);
00044 Util::PropertySet* params = BestFileRepository().Get(modelLoc);
00045 params->Add("probability", 1);
00046 int cache = factory->GetProperties()->GetInt("cache");
00047 params->Add("cache", cache);
00048 ILOG_INFO_NODE(*params);
00049 if (params->Size() == 0)
00050 {
00051 ILOG_WARNING_NODE("no best file found for " << concept <<
00052 "; skipping concept");
00053 continue;
00054 }
00055
00056
00057 Table::AnnotationTable* annotation = factory->MakeAnnotation(concept);
00058 if (annotation == 0)
00059 {
00060 ILOG_WARNING_NODE("no annotation found for " << concept <<
00061 "; skipping concept");
00062 continue;
00063 }
00064 annotation->Sort();
00065
00066
00067 TrainDataSrc* src = factory->MakeDataSrc(annotation);
00068 Classifier* classifier = factory->MakeClassifier("svm");
00069 classifier->Train(params, src);
00070 Svm* svm = static_cast<Svm*>(classifier);
00071 if (svm)
00072 SvmRepository().Add(modelLoc, svm);
00073 else
00074 ILOG_ERROR("Rogue classifier");
00075
00076
00077 Table::ScoreTable* ranking = classifier->Predict(src);
00078 Training::AveragePrecision ap(annotation);
00079 double score = ap.Compute(ranking);
00080 double scoreRev = ap.ComputeReversed(ranking);
00081 if (score < scoreRev)
00082 {
00083 ILOG_ERROR_NODE("reversed ranking score (" << scoreRev <<
00084 ") > normal ranking score (" << score <<
00085 ") for " << concept << " => logic error in svm" <<
00086 " (try setting C=1 in .best file)");
00087 }
00088 delete ranking;
00089 ILOG_INFO_NODE("score on self = " << score);
00090 Util::PropertySet scoreProp;
00091 scoreProp.Add("scoreOnSelf", score);
00092 ScoreFileRepository().Add(modelLoc, &scoreProp);
00093
00094 delete classifier;
00095 delete src;
00096 delete annotation;
00097 delete params;
00098 }
00099 return 0;
00100 }
00101
00102 int
00103 mainTrainModel(int argc, char** argv)
00104 {
00105 ILOG_VAR(Impala.Application.mainTrainModel);
00106 Link::Mpi::Init(&argc, &argv);
00107 CmdOptions& options = CmdOptions::GetInstance();
00108 options.Initialise(false, false, true);
00109 options.AddOption(0, "assume-shotid", "bool", "0");
00110 options.AddOption('m', "cache", "megabytes", "500");
00111 options.AddOption(0, "start", "index of concept to start with", "0");
00112 options.AddOption(0, "number", "number of concepts", "-1");
00113 options.AddOption(0, "concept", "name", "");
00114 options.AddOption(0, "kernel", "string: [linear,poly,rbf,sigmoid,precomputed,hist,dist-precomputed]", "rbf");
00115 options.AddOption(0, "featureIndexCat", "name", "");
00116 options.AddOption(0, "maxVideoId", "index", "-1");
00117 options.AddOption(0, "maxPosPerVideo", "number", "-1");
00118 options.AddOption(0, "maxNegPerVideo", "number", "-1");
00119
00120 options.AddOption(0, "imCacheSize", "size", "1");
00121
00122 if (options.ParseArgs(argc, argv, "dataSet concepts model featureDef", 4))
00123 {
00124 Training::Factory factory(&options, true);
00125
00126
00127 bool dist = (options.GetString("kernel") == "dist-precomputed");
00128 if (dist && Link::Mpi::MyId() != 0)
00129 {
00130 factory.ServeDistributedAccess();
00131 }
00132 else
00133 {
00134
00135
00136 if (options.GetString("kernel") == "dist-precomputed")
00137 factory.GetDistributedAccess();
00138 TrainModel(&factory, dist);
00139 }
00140 }
00141
00142 int nrOfErrors = ILOG_ERROR_COUNT;
00143 nrOfErrors = Link::Mpi::ReduceSum(nrOfErrors);
00144 ILOG_INFO_HEADNODE("Root: total nr error = " << nrOfErrors);
00145 Link::Mpi::Finalize();
00146 return nrOfErrors;
00147 }
00148
00149 }
00150 }
00151
00152 int
00153 main(int argc, char* argv[])
00154 {
00155 return Impala::Application::mainTrainModel(argc, argv);
00156 }