00001 #ifndef Impala_Core_Training_Factory_h
00002 #define Impala_Core_Training_Factory_h
00003
00004 #include "Core/Training/ParameterSearcher.h"
00005 #include "Core/Training/MakeEvaluation.h"
00006 #include "Core/Training/Svm.h"
00007 #include "Core/Training/BestFile.h"
00008 #include "Core/Training/ClassifierEvaluator.h"
00009 #include "Core/Training/TrainDataSrcFeature.h"
00010 #include "Core/Training/TrainDataSrcKernelDistributed.h"
00011 #include "Core/Database/MakeRawDataSet.h"
00012 #include "Persistency/KeywordListRepository.h"
00013 #include "Persistency/AnnotationTableRepository.h"
00014 #include "Core/Training/LoadDistributedAccess.h"
00015 #include "Util/StlHelpers.h"
00016
00017 namespace Impala
00018 {
00019 namespace Core
00020 {
00021 namespace Training
00022 {
00023
00032 class Factory
00033 {
00034 public:
00035
00036 typedef Persistency::FeatureLocator FeatureLocator;
00037 typedef Persistency::ModelLocator ModelLocator;
00038
00039 Factory(CmdOptions* options, bool doClassifier)
00040 {
00041 mOptions = options;
00042 String setName = mOptions->GetArg(0);
00043 if (Persistency::ImageSetsRepository::GetInstance().Contains(setName))
00044 {
00045 ImageSet::ImageSet* is = ImageSet::MakeImageSet(setName);
00046 is->SetImageSrc(true, true, true);
00047 mDataSet = is;
00048 }
00049 else
00050 {
00051 mDataSet = VideoSet::MakeVideoSet(setName);
00052 }
00053
00054 if (doClassifier)
00055 mProperties = MakeClassifierProperties();
00056 else
00057 mProperties = 0;
00058
00059 mConceptSet = mOptions->GetArg(1);
00060 mModelType = mOptions->GetArg(2);
00061 mFeature = mOptions->GetArg(3);
00062 mModelLoc = ModelLocator(mDataSet->GetLocator(), mConceptSet,
00063 mModelType, mFeature, "dummy");
00064 mCodebookLoc = FeatureLocator(mDataSet->GetLocator(), true, true, "",
00065 mFeature, "");
00066 mConceptStart = mOptions->GetInt("start", 0);
00067 mConceptNumber = mOptions->GetInt("number", -1);
00068 mDistributedAccess = 0;
00069 }
00070
00071 virtual
00072 ~Factory()
00073 {
00074 if (mProperties)
00075 {
00076 delete mProperties;
00077 mProperties = 0;
00078 }
00079 if (mDistributedAccess)
00080 {
00081 mDistributedAccess->Unsubscribe();
00082 delete mDistributedAccess;
00083 }
00084 }
00085
00086 Util::PropertySet*
00087 MakeClassifierProperties()
00088 {
00089 ILOG_VAR(Samples.mainCrossValidate.GetProperties);
00090 Util::PropertySet* properties = new Util::PropertySet;
00091
00092 properties->Add("w1", mOptions->GetString("w1"));
00093 properties->Add("w2", mOptions->GetString("w2"));
00094 properties->Add("autoweight", mOptions->GetString("autoweight"));
00095 if(properties->GetBool("autoweight"))
00096 {
00097 properties->Add("w1", "1");
00098 properties->Add("w2", "1");
00099 }
00100 properties->Add("C", mOptions->GetString("C"));
00101 properties->Add("gamma", mOptions->GetString("gamma"));
00102 properties->Add("cache", mOptions->GetString("cache"));
00103 properties->Add("probability", mOptions->GetInt("probability"));
00104 properties->Add("kernel", mOptions->GetString("kernel"));
00105 properties->Add("precompute-kernel",
00106 mOptions->GetString("precompute-kernel"));
00107 properties->Add("folds", mOptions->GetInt("folds"));
00108 properties->Add("repetitions", mOptions->GetInt("repetitions"));
00109 properties->Add("episode-constrained",
00110 mOptions->GetInt("episode-constrained"));
00111 properties->Add("evaluator", mOptions->GetString("evaluator"));
00112 properties->Add("maxVideoId", mOptions->GetInt("maxVideoId"));
00113 properties->Add("maxPosPerVideo", mOptions->GetInt("maxPosPerVideo"));
00114 properties->Add("maxNegPerVideo", mOptions->GetInt("maxNegPerVideo"));
00115 properties->Add("restrictTestFoldSet",
00116 mOptions->GetInt("restrictTestFoldSet"));
00117 if (properties->GetInt("episode-constrained") && mDataSet->IsImageSet())
00118 {
00119 ILOG_WARN("ImageSet learning doesn't support episode-constrained!");
00120 properties->Add("episode-constrained", 0);
00121 }
00122 if (!properties->GetBool("episode-constrained"))
00123 ILOG_INFO_HEADNODE("not using episode constrained.")
00124 else
00125 ILOG_INFO_HEADNODE("using episode constrained.")
00126 ILOG_INFO_HEADNODE("using " << properties->GetString("evaluator") <<
00127 " for evaluation.");
00128
00129 return properties;
00130 }
00131
00132 Util::PropertySet*
00133 GetProperties()
00134 {
00135 return mProperties;
00136 }
00137
00138 static Classifier*
00139 MakeClassifier(String type)
00140 {
00141 if (type == "svm")
00142 return new Svm();
00143 ILOG_ERROR("unable to make classifier of type " << type);
00144 return 0;
00145 }
00146
00147 ParameterSearcher*
00148 MakeSearcher(Table::AnnotationTable* annotation)
00149 {
00150 String evalDesc = mProperties->GetString("evaluator");
00151 Evaluation* eval = MakeEvaluation(evalDesc, annotation);
00152 Classifier* svm = MakeClassifier("svm");
00153 TrainDataSrc* src = MakeDataSrc(annotation);
00154 ParameterEvaluator* evaluator = new ClassifierEvaluator(svm, src, eval);
00155 Training::ParameterSearcher* searcher;
00156 searcher = new Training::ParameterSearcher(mProperties, evaluator);
00157 if (mProperties->GetString("kernel","rbf") == "dist-precomputed")
00158 searcher->OverrideParallelMode(false);
00159 return searcher;
00160 }
00161
00163 TrainDataSrc*
00164 MakeDataSrc(Table::AnnotationTable* anno)
00165 {
00166 String kernelName = mProperties->GetString("kernel", "rbf");
00167 ILOG_INFO_HEADNODE("MakeDataSrc with kernel " << kernelName);
00168 if (kernelName == "dist-precomputed")
00169 {
00170 Matrix::DistributedAccess* da = GetDistributedAccess();
00171 return new TrainDataSrcKernelDistributed(da, anno);
00172 }
00173 return new TrainDataSrcFeature(anno, mDataSet, mFeature);
00174 }
00175
00176 ModelLocator
00177 GetModelLocator()
00178 {
00179 return mModelLoc;
00180 }
00181
00182 Array::Array2dVec3UInt8*
00183 MakeImage(Quid q)
00184 {
00185 QuidObj qo(q);
00186 if (qo.Class() != mDataSet->GetQuidClass())
00187 {
00188 ILOG_ERROR("[MakeImage] DataSet doesn't contain "<< qo.Class());
00189 return 0;
00190 }
00191 if (mDataSet->IsImageSet())
00192 {
00193 return static_cast<ImageSet::ImageSet*>(mDataSet)->GetImage(qo.Id());
00194 }
00195 else
00196 {
00197 ILOG_WARNING("MakeImage not tested for videosets");
00198 Stream::RgbDataSrc* src =
00199 static_cast<VideoSet::VideoSet*>(mDataSet)->GetVideo(qo.Object());
00200 src->GotoFrame(qo.Id());
00202 Array::Array2dVec3UInt8* im = Array::ArrayCreate<Array::Array2dVec3UInt8>
00203 (src->FrameWidth(), src->FrameHeight(), 0, 0, src->DataPtr(), true);
00204 delete src;
00205 return im;
00206 }
00207 }
00208
00209 std::vector<String>
00210 MakeConceptList()
00211 {
00212 typedef Persistency::KeywordListLocator KeywordListLocator;
00213 typedef Persistency::KeywordListRepository KeywordListRepository;
00214 KeywordListLocator loc(mDataSet->GetLocator(), mConceptSet);
00215 std::vector<String> conceptList = *(KeywordListRepository().Get(loc));
00216
00217
00218 Util::SubSelectInPlace(&conceptList, mConceptStart, mConceptNumber);
00219 return conceptList;
00220 }
00221
00222 Table::AnnotationTable*
00223 MakeAnnotation(String concept)
00224 {
00225 typedef Persistency::AnnotationTableLocator AnnotationTableLocator;
00226 typedef Persistency::AnnotationTableRepository AnnotationTableRepository;
00227 AnnotationTableLocator loc(mDataSet->GetLocator(),
00228 mDataSet->GetQuidClass(),
00229 mConceptSet, concept);
00230 Table::AnnotationTable* annotation = AnnotationTableRepository().Get(loc);
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242 ILOG_INFO_HEADNODE("Annotations: size = " << annotation->Size() );
00243 ILOG_INFO_HEADNODE("nr pos = "<< annotation->GetNrPositive() <<
00244 ", nr neg = "<< annotation->GetNrNegative());
00245 return annotation;
00246 }
00247
00248 Database::RawDataSet*
00249 GetDataSet()
00250 {
00251 return mDataSet;
00252 }
00253
00257 Matrix::DistributedAccess*
00258 GetDistributedAccess()
00259 {
00260 if(mDistributedAccess == 0)
00261 {
00262 if(Link::Mpi::NrProcs() == 1)
00263 {
00264 mDistributedAccess = Training::LoadDistributedAccess
00265 (mModelType, mFeature, mDataSet, 0, 0, 1);
00266 }
00267 else
00268 {
00269 mDistributedAccess = Training::LoadDistributedAccess
00270 (mModelType, mFeature, mDataSet, 0, 1,
00271 Link::Mpi::NrProcs() - 1);
00272 if(Link::Mpi::MyId() == 0)
00273 mDistributedAccess->Subscribe();
00274 }
00275 }
00276 return mDistributedAccess;
00277 }
00278
00279 void
00280 ServeDistributedAccess()
00281 {
00282 Matrix::DistributedAccess* da = GetDistributedAccess();
00283 da->StartEventLoop();
00284 }
00285
00286 bool
00287 CodebookExists()
00288 {
00289 return Persistency::FeatureTableRepository().Exists(mCodebookLoc);
00290 }
00291
00292 void
00293 WriteCodebook(Feature::FeatureTable* forest)
00294 {
00295 Persistency::FeatureTableRepository().Add(mCodebookLoc, forest);
00296 }
00297
00298 Feature::FeatureDefinition
00299 GetFeatureDefinition()
00300 {
00301 return Feature::FeatureDefinition(mFeature);
00302 }
00303
00304 private:
00305
00306 CmdOptions* mOptions;
00307 Database::RawDataSet* mDataSet;
00308 Util::PropertySet* mProperties;
00309 int mConceptStart;
00310 int mConceptNumber;
00311 String mConceptSet;
00312 String mModelType;
00313 String mFeature;
00314 ModelLocator mModelLoc;
00315 FeatureLocator mCodebookLoc;
00316 Matrix::DistributedAccess* mDistributedAccess;
00317
00318 ILOG_VAR_DECL;
00319 };
00320
// Presumably defines the class logger declared by ILOG_VAR_DECL inside
// Factory, under the "Impala.Core.Training" category.
// NOTE(review): this sits in a header. If the macro expands to a non-inline
// static member definition, including this header in more than one
// translation unit will produce duplicate symbols at link time; confirm.
ILOG_VAR_INIT(Factory, Impala.Core.Training);
00322
00323 }
00324 }
00325 }
00326
00327 #endif