Home || Visual Search || Applications || Architecture || Important Messages || OGL || Src

mainActiveLearner.cpp

Go to the documentation of this file.
00001 #include "Basis/CmdOptions.h"
00002 #include "Core/Table/Select.h"
00003 #include "Core/Training/TrainDataSrcKernelDistributed.h"
00004 #include "Core/Training/LoadDistributedAccess.h"
00005 #include "Core/Database/MakeRawDataSet.h"
00006 #include "Util/DatabaseReadString.h"
00007 #include "Persistency/SvmRepository.h"
00008 #include "Persistency/AnnotationTableRepository.h"
00009 #include "Persistency/SimilarityTableSetRepository.h"
00010 
00011 // since we are not using libraries:
00012 #include "Link/Svm/LinkSvm.cpp"
00013 
00014 
00015 namespace Impala
00016 {
00017 namespace Application
00018 {
00019 
00020 
00021 using namespace Core;
00022 using namespace Core::Training;
00023 using namespace Persistency;
00024 
00025 Table::SimilarityTableSet* 
00026 LearnConceptFromAnnotations(CmdOptions& options, Matrix::DistributedAccess* da,
00027                             String concept,
00028                             const ModelLocator& modelLoc,
00029                             Table::AnnotationTable *annotations)
00030 {
00031     ILOG_VAR(Application.LearnConceptFromAnnotations);
00032 
00033 
00034     ILOG_INFO("concept " << concept);
00035 
00036     //load annotation
00037     Table::QuidTable* allQuids = da->GetColumnQuids();
00038     Table::CriterionElement1InSet<Table::AnnotationTable> c(allQuids);
00039     Table::AnnotationTable* annotation = Select(annotations, c);
00040 
00041     Table::QuidTable* positive = annotation->GetPositive();
00042     Table::QuidTable* negative = annotation->GetNegative();
00043     Table::QuidTable* learnQuids = positive;
00044     learnQuids->Append(negative);
00045     delete negative;
00046 
00047     if ((positive->Size() == 0) || (negative->Size() == 0))
00048     {
00049         delete learnQuids;
00050         delete annotation;
00051         return 0;
00052     }
00053 
00054     Util::PropertySet params;
00055     params.Add("probability", 1);
00056     params.Add("cache", options.GetInt("cache"));
00057 
00058     params.Add("gamma", -1);
00059     params.Add("autoweight", 1);
00060     params.Add("kernel", "dist-precomputed");
00061     params.Add("precompute-kernel", "chi2");
00062     params.Add("evaluator", "AP");
00063     params.Add("C", 1.0);
00064     params.Add("w1", 30.0); // estimate these
00065     params.Add("w2", 1.2); // estimate these
00066     for(int i = 0; i < params.Size(); i++)
00067         ILOG_INFO(params.GetName(i) + " " + params.GetValue(i));
00068 
00069     //compute model:
00070     Training::Svm svm;
00071     Training::TrainDataSrcKernelDistributed* dataSrc =
00072         new TrainDataSrcKernelDistributed(da, annotation);
00073     svm.Train(&params, dataSrc);
00074     ILOG_INFO("saving model");
00075     SvmRepository().Add(modelLoc, &svm);
00076 
00077     // compute resulting ranking:
00078     std::vector<String> names;
00079     names.push_back(concept);
00080     Table::QuidTable* rowQuids = da->GetRowQuids();
00081     Table::SimilarityTableSet *simSet =
00082         new Table::SimilarityTableSet(names, rowQuids->Size());
00083     Table::Copy(simSet->GetQuidTable(), rowQuids);
00084 
00085     ILOG_INFO("applying model");
00086     svm.PredictForActiveLearn(da, learnQuids, simSet->GetSimTable(0));
00087     
00088     ILOG_INFO("ranking");
00089     simSet->ComputeRank(0, true);
00090     
00091     delete annotation;
00092     delete learnQuids;
00093     delete dataSrc;
00094     return simSet;
00095 }
00096 
00097 int
00098 RunDistributedLearningEngine(CmdOptions& options)
00099 {
00100     ILOG_VAR(Application.RunDistributedLearningEngine);
00101     String setName = options.GetArg(0);
00102     String conceptsName = options.GetArg(1); //"conceptsActiveLearn.txt";
00103     String model = options.GetArg(2);
00104     String kernel = options.GetArg(3);
00105 
00106     Database::RawDataSet* dataSet = Database::MakeRawDataSet(setName);
00107     int quidClass = dataSet->GetQuidClass();
00108 
00109     int startnode = 1;
00110     int nodes = Link::Mpi::NrProcs() - startnode;
00111     Matrix::DistributedAccess* da =
00112         LoadDistributedAccess(model, kernel, dataSet, 0, startnode, nodes);
00113     if ((Link::Mpi::MyId() <= nodes + startnode) && (Link::Mpi::MyId() != 0))
00114     {
00115         da->StartEventLoop();
00116     }
00117     else
00118     {
00119         da->Subscribe();
00120         ILOG_INFO_NODE("engine loaded, waiting for annotations.");
00121 
00122         while (true)
00123         {
00124             // announce waiting for table
00125             // receive table
00126             String identifier;
00127             while (true)
00128             {
00129                 String fName = "Annotations/" + QuidClassToString(quidClass)
00130                     + "/" + conceptsName + "/startlearner.txt";
00131                 FileLocator fLoc(dataSet->GetLocator(), fName);
00132                 Persistency::File file = RepositoryGetFile(fLoc, false, true);
00133                 if (file.Valid())
00134                 {
00135                     std::vector<String> ids;
00136                     file.ReadStrings(ids);
00137                     if (ids.size() > 0)
00138                     {
00139                         identifier = ids[0];
00140                         ILOG_INFO("Starting model train cycle for "<<identifier);
00141                         ids.clear();
00142                         file = RepositoryGetFile(fLoc, true, false);
00143                         file.WriteStrings(ids.begin(), ids.end());
00144                         break;
00145                     }
00146                 }
00147             }
00148 
00149             AnnotationTableLocator annoLoc(dataSet->GetLocator(), quidClass,
00150                                            conceptsName, identifier);
00151             if (!AnnotationTableRepository().Exists(annoLoc))
00152                 continue;
00153             Table::AnnotationTable* annotationtable =
00154                 AnnotationTableRepository().Get(annoLoc);
00155 
00156             ModelLocator modelLoc(dataSet->GetLocator(), conceptsName,
00157                                   "activelearn", "activelearn", identifier);
00158             Timer timer;
00159             Table::SimilarityTableSet* simSet = 
00160                 LearnConceptFromAnnotations(options, da, identifier, modelLoc,
00161                                             annotationtable);
00162             ILOG_INFO("Learned in " << timer.SplitTimeStr());
00163 
00164             if (!simSet)
00165             {
00166                 ILOG_ERROR("Learning aborted, no positives or negatives?");
00167                 continue;
00168             }
00169 
00170             // save concept
00171             SimilarityTableSetLocator simLoc(dataSet->GetLocator(), true, "",
00172                                              conceptsName, "activelearn",
00173                                              "activelearn", "");
00174             SimilarityTableSetRepository().Add(simLoc, simSet);
00175 
00176             delete annotationtable;
00177             delete simSet;
00178 
00179             ILOG_INFO("Loop done: " << timer.SplitTimeStr());
00180         }
00181         da->Unsubscribe();
00182     }
00183     delete da;
00184     Link::Mpi::Finalize();
00185     return 0;
00186 }
00187 
00188 
00189 int
00190 mainActiveLearner(int argc, char** argv)
00191 {
00192     Link::Mpi::Init(&argc, &argv);
00193     CmdOptions& options = CmdOptions::GetInstance();
00194     options.Initialise(false, false, true);
00195     options.AddOption(0, "assume-shotid", "bool", "0");
00196     options.AddOption('m', "cache", "megabytes", "10");
00197     //options.AddOption(0, "kernel", "string: [linear,poly,rbf,sigmoid,precomputed,hist,dist-precomputed]", "rbf");
00198     options.AddOption(0, "maxVideoId", "index", "-1");
00199     options.AddOption(0, "maxPosPerVideo", "number", "-1");
00200     options.AddOption(0, "maxNegPerVideo", "number", "-1");
00201 
00202     if (! options.ParseArgs(argc, argv, "<video set> <concept definitions> <model type> <features>", 4))
00203     {
00204         Link::Mpi::Finalize();
00205         return 1;
00206     }
00207 
00208     return RunDistributedLearningEngine(options);
00209 }
00210 
00211 } // namespace Application
00212 } // namespace Impala
00213 
00214 int
00215 main(int argc, char* argv[])
00216 {
00217     return Impala::Application::mainActiveLearner(argc, argv);
00218 }

Generated on Thu Jan 13 09:03:41 2011 for ImpalaSrc by  doxygen 1.5.1