00001 #include "Basis/CmdOptions.h"
00002 #include "Util/Mpi/SelectWorkLoadInPlace.h"
00003 #include "Core/ApplicationFactory.h"
00004
00005
00006 #include "Link/Svm/LinkSvm.cpp"
00007
00008
00009 namespace Impala
00010 {
00011 namespace Application
00012 {
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043 using namespace Core;
00044 using namespace Core::Training;
00045
00046 int TrainModel(Training::Factory* trainFactory,
00047 DataFactory* dataFactory, bool distKernel)
00048 {
00049 ILOG_VAR(main);
00050 std::vector<String> conceptList = dataFactory->MakeConceptList();
00051 if(!distKernel)
00052 Util::Mpi::SelectWorkLoadInPlace(&conceptList);
00053 for (int i=0 ; i<conceptList.size(); ++i)
00054 {
00055 String concept = conceptList[i];
00056
00057 if(!dataFactory->CanMakeConceptModel(concept))
00058 {
00059 ILOG_WARNING_NODE("cannot create model for " << concept <<
00060 "; skipping concept");
00061 continue;
00062 }
00063 ILOG_INFO_NODE("concept " << concept);
00064
00065
00066 Util::PropertySet params =
00067 trainFactory->GetBestFileParams(dataFactory, concept);
00068 ILOG_INFO_NODE(params);
00069 if(params.Size() == 0)
00070 {
00071 ILOG_WARNING_NODE("no best file found for " << concept <<
00072 "; skipping concept");
00073 continue;
00074 }
00075
00076
00077 Table::AnnotationTable* annotation = dataFactory->MakeAnnotation(concept);
00078 if(annotation == 0)
00079 {
00080 ILOG_WARNING_NODE("no annotation found for " << concept <<
00081 "; skipping concept");
00082 continue;
00083 }
00084 annotation->Sort();
00085
00086
00087 TrainDataSrc* src = trainFactory->MakeDataSrc(annotation, dataFactory);
00088 Core::Training::Classifier* svm = trainFactory->MakeClassifier("svm");
00089 svm->Train(¶ms, src);
00090 dataFactory->WriteConceptModel(concept, svm);
00091
00092
00093 Table::ScoreTable* ranking = svm->Predict(src);
00094 Training::AveragePrecision ap(annotation);
00095 double score = ap.Compute(ranking);
00096 if(score < ap.ComputeReversed(ranking))
00097 {
00098 ILOG_ERROR_NODE("reversed ranking > normal ranking => logic error in svm");
00099 }
00100 delete ranking;
00101 dataFactory->WriteScoreOnSelf(concept, score);
00102 ILOG_INFO_NODE("score on self = " << score);
00103 ILOG_DEBUG("deleting svm");
00104 delete svm;
00105 ILOG_DEBUG_NODE("deleting src");
00106 delete src;
00107 ILOG_DEBUG_NODE("deleting annotation");
00108 delete annotation;
00109 ILOG_DEBUG_NODE("done iteration");
00110 }
00111 return 0;
00112 }
00113
00114 int
00115 mainTrainModel(int argc, char** argv)
00116 {
00117 ILOG_VAR(main);
00118 Link::Mpi::Init(&argc, &argv);
00119 CmdOptions& options = CmdOptions::GetInstance();
00120 options.Initialise(false, false, true);
00121 options.AddOption(0, "assume-shotid", "bool", "0");
00122 options.AddOption('m', "cache", "megabytes", "500");
00123 options.AddOption(0, "start", "index of concept to start with", "0");
00124 options.AddOption(0, "number", "number of concepts", "-1");
00125 options.AddOption(0, "concept", "name", "");
00126 options.AddOption(0, "kernel", "string: [linear,poly,rbf,sigmoid,precomputed,hist,dist-precomputed]", "rbf");
00127 options.AddOption(0, "maxVideoId", "index", "-1");
00128 options.AddOption(0, "maxPosPerVideo", "number", "-1");
00129 options.AddOption(0, "maxNegPerVideo", "number", "-1");
00130
00131 options.AddOption(0, "imCacheSize", "size", "1");
00132
00133 if (! options.ParseArgs(argc, argv, "dataSet concepts model featureDef", 4))
00134 {
00135 Link::Mpi::Finalize();
00136 return 1;
00137 }
00138
00139 ApplicationFactory factory(&options);
00140 Util::PropertySet* properties = factory.MakeClassifierProperties();
00141 Training::Factory* trainFactory = factory.MakeTrainFactory(properties);
00142 DataFactory* dataFactory = factory.MakeDataFactory();
00143
00144
00145
00146 int retval = 0;
00147 bool dist = (options.GetString("kernel") == "dist-precomputed");
00148 if(dist && Link::Mpi::MyId() != 0)
00149 {
00150 dataFactory->ServeDistributedAccess();
00151 }
00152 else
00153 {
00154
00155
00156 if(options.GetString("kernel") == "dist-precomputed")
00157 dataFactory->GetDistributedAccess();
00158 retval = TrainModel(trainFactory, dataFactory, dist);
00159 }
00160
00161 ILOG_DEBUG_ONCE("deleting factories and props");
00162 delete dataFactory;
00163 delete trainFactory;
00164 delete properties;
00165 Link::Mpi::Finalize();
00166 return retval;
00167 }
00168
00169 }
00170 }
00171
00172 int
00173 main(int argc, char* argv[])
00174 {
00175 return Impala::Application::mainTrainModel(argc, argv);
00176 }