Home || Architecture || Video Search || Visual Search || Scripts || Applications || Important Messages || OGL || Src

mainActiveLearner.cpp

Go to the documentation of this file.
00001 #include "Basis/CmdOptions.h"
00002 #include "Core/Table/Select.h"
00003 #include "Core/Training/SvmFile.h"
00004 #include "Core/Training/Svm.h"
00005 #include "Core/Training/Factory.h"
00006 #include "Core/Training/AveragePrecision.h"
00007 #include "Core/VideoSet/MakeVideoSet.h"
00008 #include "Core/ImageSet/MakeImageSet.h"
00009 #include "Util/DatabaseReadString.h"
00010 
00011 // since we are not using libraries:
00012 #include "Link/Svm/LinkSvm.cpp"
00013 
00014 
00015 namespace Impala
00016 {
00017 namespace Application
00018 {
00019 
00020 
00021 using namespace Core;
00022 using namespace Core::Training;
00023 
00024 Table::SimilarityTableSet* 
00025 LearnConceptFromAnnotations(CmdOptions& options, Matrix::DistributedAccess &da,
00026                             String concept,
00027                             String modelname, 
00028                             Table::AnnotationTable *annotations,
00029                             Util::Database* db)
00030 {
00031     ILOG_VAR(Application.LearnConceptFromAnnotations);
00032 
00033 
00034     ILOG_INFO("concept " << concept);
00035 
00036     //load annotation
00037     Table::QuidTable* allQuids = da.GetColumnQuids();
00038     Table::CriterionElement1InSet<Table::AnnotationTable> c(allQuids);
00039     Table::AnnotationTable* annotation = Select(annotations, c);
00040 
00041     Table::QuidTable* positive = annotation->GetPositive();
00042     Table::QuidTable* negative = annotation->GetNegative();
00043     Table::QuidTable* learnQuids = positive;
00044     learnQuids->Append(negative);
00045     delete negative;
00046 
00047     if((positive->Size() == 0) || (negative->Size() == 0))
00048     {
00049         delete learnQuids;
00050         delete annotation;
00051         return 0;
00052     }
00053 
00054     // read params from best file: the seccond line in the file
00055 /*    filename = dataSet->GetFilePathConceptModel
00056         (conceptSet, modelType, featureDef, concept+".best", false, false);
00057     if (filename.empty())
00058     {
00059         ILOG_INFO("Could not open " << concept << ".best, skipping");
00060         continue;
00061     }*/
00062 
00063 /*    // TODO: estimate .best file
00064 
00065     Util::IOBuffer* ioBuf = db->GetIOBuffer(filename, true, false, "");
00066     String buffer = ioBuf->ReadLine();
00067     ILOG_INFO(buffer);
00068     buffer = ioBuf->ReadLine();
00069     Util::PropertySet params(buffer);
00070 */
00071     Util::PropertySet params;
00072     params.Add("probability", 1);
00073     params.Add("cache", options.GetInt("cache"));
00074 
00075     params.Add("gamma", -1);
00076     params.Add("autoweight", 1);
00077     params.Add("kernel", "dist-precomputed");
00078     params.Add("precompute-kernel", "chi2");
00079     params.Add("evaluator", "AP");
00080     params.Add("C", 1.0);
00081     params.Add("w1", 30.0); // estimate these
00082     params.Add("w2", 1.2); // estimate these
00083     for(int i = 0; i < params.Size(); i++)
00084         ILOG_INFO(params.GetName(i) + " " + params.GetValue(i));
00085 
00086     //compute model:
00087     Core::Training::Svm svm;
00088     Core::Training::TrainDataSrcKernelDistributed* dataSrc = new TrainDataSrcKernelDistributed(&da, annotation);
00089     svm.Train(&params, dataSrc);
00090     if(!modelname.empty())    
00091         svm.SaveModel(modelname, db);
00092 
00093     // compute resulting ranking:
00094     std::vector<std::string> names;
00095     names.push_back(concept);
00096     Table::QuidTable* rowQuids = da.GetRowQuids();
00097     Table::SimilarityTableSet *simSet = new Table::SimilarityTableSet(names, rowQuids->Size());
00098     Core::Table::Copy(simSet->GetQuidTable(), rowQuids);
00099 
00100     ILOG_INFO("applying model");
00101     svm.PredictForActiveLearn(da, learnQuids, simSet->GetSimTable(0));
00102     
00103     ILOG_INFO("ranking");
00104     simSet->ComputeRank(0, true);
00105     
00106     delete annotation;
00107     delete learnQuids;
00108     delete dataSrc;
00109 
00110     return simSet;
00111 
00112 }
00113 
00114 int
00115 RunDistributedLearningEngine(CmdOptions& options)
00116 {
00117     ILOG_VAR(Application.RunDistributedLearningEngine);
00118     int quidClass = 0;
00119     Core::Database::RawDataSet* dataSet = 0;
00120     String setName = options.GetArg(0);
00121     Util::Database* db = new Util::Database(setName);
00122     String path = db->GetFilePath("ImageData", setName, false, true);
00123     if (! path.empty())
00124     {
00125         quidClass = QUID_CLASS_IMAGE;
00126         dataSet = ImageSet::MakeImageSet(setName);
00127     }
00128     else
00129     {
00130         quidClass = QUID_CLASS_FRAME;
00131         dataSet = VideoSet::MakeVideoSet(setName);
00132     }
00133     db = dataSet->GetDatabase();
00134 
00135     String modelType = options.GetArg(2);
00136     Feature::FeatureDefinition featureDef(options.GetArg(3));
00137 
00138     int startnode = 1;
00139     int nodes = Link::Mpi::NrProcs() - startnode;
00140     Matrix::DistributedAccess da(options.GetArg(3), dataSet, 0, startnode, nodes);
00141     if(Link::Mpi::MyId() <= nodes + startnode && Link::Mpi::MyId() != 0)
00142         da.StartEventLoop();
00143     else
00144     {
00145         da.Subscribe();
00146         ILOG_INFO_NODE("engine loaded, waiting for annotations.");
00147         String conceptsName = options.GetArg(1); //"conceptsActiveLearn.txt";
00148 
00149         while(true)
00150         {
00151             // announce waiting for table
00152             // receive table
00153 
00154             String identifier;
00155 
00156             while (true)
00157             {
00158                String sfile = dataSet->GetFilePathAnnotation(quidClass, conceptsName, "startlearner.txt", false, true);
00159                 if (!sfile.empty())
00160                 {
00161                     std::vector<String> blaat;
00162                     Util::DatabaseReadStrings(blaat, sfile, db);
00163                     if (blaat.size() > 0)
00164                     {
00165                         identifier = blaat[0];
00166                         ILOG_INFO("Starting model train cycle for " << identifier);
00167                         blaat.clear();
00168 
00169                         Util::DatabaseWriteString(sfile, db, blaat.begin(), blaat.end());
00170                         break;
00171                     }
00172                 }
00173             }
00174 
00175             Table::AnnotationTable* annotationtable = new Table::AnnotationTable;
00176 
00177             String filename = dataSet->GetFilePathAnnotation(quidClass, conceptsName, identifier + ".tab", false, true);
00178             if(filename.empty())
00179             {
00180                 //ILOG_WARN("Annotation file not found for " << identifier);
00181                 continue;
00182             }
00183             Read(annotationtable, filename, db);
00184 
00185             String modelname = "";
00186             if(true)  // save the moel?
00187             {
00188                 modelname = dataSet->GetFilePathConceptModel
00189                     (conceptsName, "activelearn", 
00190                      Feature::FeatureDefinition("activelearn"), 
00191                      identifier+".model", 
00192                      true, false);
00193             }
00194             Timer timer;
00195             Table::SimilarityTableSet* simSet = 
00196                 LearnConceptFromAnnotations(options, da, identifier, modelname, annotationtable, db);
00197             ILOG_INFO("Learned in " << timer.SplitTimeStr());
00198 
00199             if(!simSet)
00200             {
00201                 ILOG_ERROR("Learning aborted, no positives or negatives?");
00202                 continue;
00203             }
00204 
00205             // save concept
00206             simSet->Save(dataSet, conceptsName, "activelearn", "activelearn",
00207                          true);
00208 
00209             delete annotationtable;
00210             delete simSet;
00211 
00212             ILOG_INFO("Loop done: " << timer.SplitTimeStr());
00213         }
00214         da.Unsubscribe();
00215     }
00216     Link::Mpi::Finalize();
00217     return 0;
00218 }
00219 
00220 
00221 int
00222 mainActiveLearner(int argc, char** argv)
00223 {
00224     Link::Mpi::Init(&argc, &argv);
00225     CmdOptions& options = CmdOptions::GetInstance();
00226     options.Initialise(false, false, true);
00227     options.AddOption(0, "assume-shotid", "bool", "0");
00228     options.AddOption('m', "cache", "megabytes", "10");
00229     //options.AddOption(0, "kernel", "string: [linear,poly,rbf,sigmoid,precomputed,hist,dist-precomputed]", "rbf");
00230     options.AddOption(0, "maxVideoId", "index", "-1");
00231     options.AddOption(0, "maxPosPerVideo", "number", "-1");
00232     options.AddOption(0, "maxNegPerVideo", "number", "-1");
00233 
00234     if (! options.ParseArgs(argc, argv, "<video set> <concept definitions> <model type> <features>", 4))
00235     {
00236         Link::Mpi::Finalize();
00237         return 1;
00238     }
00239 
00240     return RunDistributedLearningEngine(options);
00241 }
00242 
00243 } // namespace Application
00244 } // namespace Impala
00245 
00246 int
00247 main(int argc, char* argv[])
00248 {
00249     return Impala::Application::mainActiveLearner(argc, argv);
00250 }

Generated on Fri Mar 19 09:30:27 2010 for ImpalaSrc by  doxygen 1.5.1