00001 #include "Basis/CmdOptions.h"
00002 #include "Core/Table/Select.h"
00003 #include "Core/Training/TrainDataSrcKernelDistributed.h"
00004 #include "Core/Training/LoadDistributedAccess.h"
00005 #include "Core/Database/MakeRawDataSet.h"
00006 #include "Util/DatabaseReadString.h"
00007 #include "Persistency/SvmRepository.h"
00008 #include "Persistency/AnnotationTableRepository.h"
00009 #include "Persistency/SimilarityTableSetRepository.h"
00010
00011
00012 #include "Link/Svm/LinkSvm.cpp"
00013
00014
00015 namespace Impala
00016 {
00017 namespace Application
00018 {
00019
00020
00021 using namespace Core;
00022 using namespace Core::Training;
00023 using namespace Persistency;
00024
00025 Table::SimilarityTableSet*
00026 LearnConceptFromAnnotations(CmdOptions& options, Matrix::DistributedAccess* da,
00027 String concept,
00028 const ModelLocator& modelLoc,
00029 Table::AnnotationTable *annotations)
00030 {
00031 ILOG_VAR(Application.LearnConceptFromAnnotations);
00032
00033
00034 ILOG_INFO("concept " << concept);
00035
00036
00037 Table::QuidTable* allQuids = da->GetColumnQuids();
00038 Table::CriterionElement1InSet<Table::AnnotationTable> c(allQuids);
00039 Table::AnnotationTable* annotation = Select(annotations, c);
00040
00041 Table::QuidTable* positive = annotation->GetPositive();
00042 Table::QuidTable* negative = annotation->GetNegative();
00043 Table::QuidTable* learnQuids = positive;
00044 learnQuids->Append(negative);
00045 delete negative;
00046
00047 if ((positive->Size() == 0) || (negative->Size() == 0))
00048 {
00049 delete learnQuids;
00050 delete annotation;
00051 return 0;
00052 }
00053
00054 Util::PropertySet params;
00055 params.Add("probability", 1);
00056 params.Add("cache", options.GetInt("cache"));
00057
00058 params.Add("gamma", -1);
00059 params.Add("autoweight", 1);
00060 params.Add("kernel", "dist-precomputed");
00061 params.Add("precompute-kernel", "chi2");
00062 params.Add("evaluator", "AP");
00063 params.Add("C", 1.0);
00064 params.Add("w1", 30.0);
00065 params.Add("w2", 1.2);
00066 for(int i = 0; i < params.Size(); i++)
00067 ILOG_INFO(params.GetName(i) + " " + params.GetValue(i));
00068
00069
00070 Training::Svm svm;
00071 Training::TrainDataSrcKernelDistributed* dataSrc =
00072 new TrainDataSrcKernelDistributed(da, annotation);
00073 svm.Train(¶ms, dataSrc);
00074 ILOG_INFO("saving model");
00075 SvmRepository().Add(modelLoc, &svm);
00076
00077
00078 std::vector<String> names;
00079 names.push_back(concept);
00080 Table::QuidTable* rowQuids = da->GetRowQuids();
00081 Table::SimilarityTableSet *simSet =
00082 new Table::SimilarityTableSet(names, rowQuids->Size());
00083 Table::Copy(simSet->GetQuidTable(), rowQuids);
00084
00085 ILOG_INFO("applying model");
00086 svm.PredictForActiveLearn(da, learnQuids, simSet->GetSimTable(0));
00087
00088 ILOG_INFO("ranking");
00089 simSet->ComputeRank(0, true);
00090
00091 delete annotation;
00092 delete learnQuids;
00093 delete dataSrc;
00094 return simSet;
00095 }
00096
00097 int
00098 RunDistributedLearningEngine(CmdOptions& options)
00099 {
00100 ILOG_VAR(Application.RunDistributedLearningEngine);
00101 String setName = options.GetArg(0);
00102 String conceptsName = options.GetArg(1);
00103 String model = options.GetArg(2);
00104 String kernel = options.GetArg(3);
00105
00106 Database::RawDataSet* dataSet = Database::MakeRawDataSet(setName);
00107 int quidClass = dataSet->GetQuidClass();
00108
00109 int startnode = 1;
00110 int nodes = Link::Mpi::NrProcs() - startnode;
00111 Matrix::DistributedAccess* da =
00112 LoadDistributedAccess(model, kernel, dataSet, 0, startnode, nodes);
00113 if ((Link::Mpi::MyId() <= nodes + startnode) && (Link::Mpi::MyId() != 0))
00114 {
00115 da->StartEventLoop();
00116 }
00117 else
00118 {
00119 da->Subscribe();
00120 ILOG_INFO_NODE("engine loaded, waiting for annotations.");
00121
00122 while (true)
00123 {
00124
00125
00126 String identifier;
00127 while (true)
00128 {
00129 String fName = "Annotations/" + QuidClassToString(quidClass)
00130 + "/" + conceptsName + "/startlearner.txt";
00131 FileLocator fLoc(dataSet->GetLocator(), fName);
00132 Persistency::File file = RepositoryGetFile(fLoc, false, true);
00133 if (file.Valid())
00134 {
00135 std::vector<String> ids;
00136 file.ReadStrings(ids);
00137 if (ids.size() > 0)
00138 {
00139 identifier = ids[0];
00140 ILOG_INFO("Starting model train cycle for "<<identifier);
00141 ids.clear();
00142 file = RepositoryGetFile(fLoc, true, false);
00143 file.WriteStrings(ids.begin(), ids.end());
00144 break;
00145 }
00146 }
00147 }
00148
00149 AnnotationTableLocator annoLoc(dataSet->GetLocator(), quidClass,
00150 conceptsName, identifier);
00151 if (!AnnotationTableRepository().Exists(annoLoc))
00152 continue;
00153 Table::AnnotationTable* annotationtable =
00154 AnnotationTableRepository().Get(annoLoc);
00155
00156 ModelLocator modelLoc(dataSet->GetLocator(), conceptsName,
00157 "activelearn", "activelearn", identifier);
00158 Timer timer;
00159 Table::SimilarityTableSet* simSet =
00160 LearnConceptFromAnnotations(options, da, identifier, modelLoc,
00161 annotationtable);
00162 ILOG_INFO("Learned in " << timer.SplitTimeStr());
00163
00164 if (!simSet)
00165 {
00166 ILOG_ERROR("Learning aborted, no positives or negatives?");
00167 continue;
00168 }
00169
00170
00171 SimilarityTableSetLocator simLoc(dataSet->GetLocator(), true, "",
00172 conceptsName, "activelearn",
00173 "activelearn", "");
00174 SimilarityTableSetRepository().Add(simLoc, simSet);
00175
00176 delete annotationtable;
00177 delete simSet;
00178
00179 ILOG_INFO("Loop done: " << timer.SplitTimeStr());
00180 }
00181 da->Unsubscribe();
00182 }
00183 delete da;
00184 Link::Mpi::Finalize();
00185 return 0;
00186 }
00187
00188
00189 int
00190 mainActiveLearner(int argc, char** argv)
00191 {
00192 Link::Mpi::Init(&argc, &argv);
00193 CmdOptions& options = CmdOptions::GetInstance();
00194 options.Initialise(false, false, true);
00195 options.AddOption(0, "assume-shotid", "bool", "0");
00196 options.AddOption('m', "cache", "megabytes", "10");
00197
00198 options.AddOption(0, "maxVideoId", "index", "-1");
00199 options.AddOption(0, "maxPosPerVideo", "number", "-1");
00200 options.AddOption(0, "maxNegPerVideo", "number", "-1");
00201
00202 if (! options.ParseArgs(argc, argv, "<video set> <concept definitions> <model type> <features>", 4))
00203 {
00204 Link::Mpi::Finalize();
00205 return 1;
00206 }
00207
00208 return RunDistributedLearningEngine(options);
00209 }
00210
00211 }
00212 }
00213
00214 int
00215 main(int argc, char* argv[])
00216 {
00217 return Impala::Application::mainActiveLearner(argc, argv);
00218 }