Home || Visual Search || Applications || Architecture || Important Messages || OGL || Src

Factory.h

Go to the documentation of this file.
00001 #ifndef Impala_Core_Training_Factory_h
00002 #define Impala_Core_Training_Factory_h
00003 
00004 #include "Core/Training/ParameterSearcher.h"
00005 #include "Core/Training/MakeEvaluation.h"
00006 #include "Core/Training/Svm.h"
00007 #include "Core/Training/BestFile.h"
00008 #include "Core/Training/ClassifierEvaluator.h"
00009 #include "Core/Training/TrainDataSrcFeature.h"
00010 #include "Core/Training/TrainDataSrcKernelDistributed.h"
00011 #include "Core/Database/MakeRawDataSet.h"
00012 #include "Persistency/KeywordListRepository.h"
00013 #include "Persistency/AnnotationTableRepository.h"
00014 #include "Core/Training/LoadDistributedAccess.h"
00015 #include "Util/StlHelpers.h"
00016 
00017 namespace Impala
00018 {
00019 namespace Core
00020 {
00021 namespace Training
00022 {
00023 
00032 class Factory
00033 {
00034 public:
00035 
00036     typedef Persistency::FeatureLocator FeatureLocator;
00037     typedef Persistency::ModelLocator ModelLocator;
00038 
00039     Factory(CmdOptions* options, bool doClassifier)
00040     {
00041         mOptions = options;
00042         String setName = mOptions->GetArg(0);
00043         if (Persistency::ImageSetsRepository::GetInstance().Contains(setName))
00044         {
00045             ImageSet::ImageSet* is = ImageSet::MakeImageSet(setName);
00046             is->SetImageSrc(true, true, true);
00047             mDataSet = is;
00048         }
00049         else
00050         {
00051             mDataSet = VideoSet::MakeVideoSet(setName);
00052         }
00053 
00054         if (doClassifier)
00055             mProperties = MakeClassifierProperties();
00056         else
00057             mProperties = 0;
00058 
00059         mConceptSet = mOptions->GetArg(1);
00060         mModelType = mOptions->GetArg(2);
00061         mFeature = mOptions->GetArg(3);
00062         mModelLoc = ModelLocator(mDataSet->GetLocator(), mConceptSet,
00063                                  mModelType, mFeature, "dummy");
00064         mCodebookLoc = FeatureLocator(mDataSet->GetLocator(), true, true, "",
00065                                       mFeature, "");
00066         mConceptStart = mOptions->GetInt("start", 0);
00067         mConceptNumber = mOptions->GetInt("number", -1);
00068         mDistributedAccess = 0;
00069     }
00070 
00071     virtual
00072     ~Factory()
00073     {
00074         if (mProperties)
00075         {
00076             delete mProperties;
00077             mProperties = 0;
00078         }
00079         if (mDistributedAccess)
00080         {
00081             mDistributedAccess->Unsubscribe();
00082             delete mDistributedAccess;
00083         }
00084     }
00085 
00086     Util::PropertySet*
00087     MakeClassifierProperties()
00088     {
00089         ILOG_VAR(Samples.mainCrossValidate.GetProperties);
00090         Util::PropertySet* properties = new Util::PropertySet;
00091         // todo: maybe we can change these calls into a loop.
00092         properties->Add("w1", mOptions->GetString("w1"));
00093         properties->Add("w2", mOptions->GetString("w2"));
00094         properties->Add("autoweight", mOptions->GetString("autoweight"));
00095         if(properties->GetBool("autoweight"))
00096         {
00097             properties->Add("w1", "1");
00098             properties->Add("w2", "1");
00099         }
00100         properties->Add("C", mOptions->GetString("C"));
00101         properties->Add("gamma", mOptions->GetString("gamma"));
00102         properties->Add("cache", mOptions->GetString("cache"));
00103         properties->Add("probability", mOptions->GetInt("probability"));
00104         properties->Add("kernel", mOptions->GetString("kernel"));
00105         properties->Add("precompute-kernel", 
00106                         mOptions->GetString("precompute-kernel"));
00107         properties->Add("folds", mOptions->GetInt("folds"));
00108         properties->Add("repetitions", mOptions->GetInt("repetitions"));
00109         properties->Add("episode-constrained",
00110                         mOptions->GetInt("episode-constrained"));
00111         properties->Add("evaluator", mOptions->GetString("evaluator"));
00112         properties->Add("maxVideoId", mOptions->GetInt("maxVideoId"));
00113         properties->Add("maxPosPerVideo", mOptions->GetInt("maxPosPerVideo"));
00114         properties->Add("maxNegPerVideo", mOptions->GetInt("maxNegPerVideo"));
00115         properties->Add("restrictTestFoldSet",
00116                         mOptions->GetInt("restrictTestFoldSet"));
00117         if (properties->GetInt("episode-constrained") && mDataSet->IsImageSet())
00118         {
00119             ILOG_WARN("ImageSet learning doesn't support episode-constrained!");
00120             properties->Add("episode-constrained", 0);
00121         }
00122         if (!properties->GetBool("episode-constrained"))
00123             ILOG_INFO_HEADNODE("not using episode constrained.")
00124         else
00125             ILOG_INFO_HEADNODE("using episode constrained.")
00126         ILOG_INFO_HEADNODE("using " << properties->GetString("evaluator") <<
00127                            " for evaluation.");
00128 
00129         return properties;
00130     }
00131 
00132     Util::PropertySet*
00133     GetProperties()
00134     {
00135         return mProperties;
00136     }
00137 
00138     static Classifier*
00139     MakeClassifier(String type)
00140     {
00141         if (type == "svm")
00142             return new Svm();
00143         ILOG_ERROR("unable to make classifier of type " << type);
00144         return 0;
00145     }
00146 
00147     ParameterSearcher*
00148     MakeSearcher(Table::AnnotationTable* annotation)
00149     {
00150         String evalDesc = mProperties->GetString("evaluator");
00151         Evaluation* eval = MakeEvaluation(evalDesc, annotation);
00152         Classifier* svm = MakeClassifier("svm");
00153         TrainDataSrc* src = MakeDataSrc(annotation);
00154         ParameterEvaluator* evaluator = new ClassifierEvaluator(svm, src, eval);
00155         Training::ParameterSearcher* searcher;
00156         searcher = new Training::ParameterSearcher(mProperties, evaluator);
00157         if (mProperties->GetString("kernel","rbf") == "dist-precomputed")
00158             searcher->OverrideParallelMode(false);
00159         return searcher;
00160     }
00161 
00163     TrainDataSrc*
00164     MakeDataSrc(Table::AnnotationTable* anno)
00165     {
00166         String kernelName = mProperties->GetString("kernel", "rbf");
00167         ILOG_INFO_HEADNODE("MakeDataSrc with kernel " << kernelName);
00168         if (kernelName == "dist-precomputed")
00169         {
00170             Matrix::DistributedAccess* da = GetDistributedAccess();
00171             return new TrainDataSrcKernelDistributed(da, anno);
00172         }
00173         return new TrainDataSrcFeature(anno, mDataSet, mFeature);
00174     }
00175 
00176     ModelLocator
00177     GetModelLocator()
00178     {
00179         return mModelLoc;
00180     }
00181 
00182     Array::Array2dVec3UInt8*
00183     MakeImage(Quid q)
00184     {
00185         QuidObj qo(q);
00186         if (qo.Class() != mDataSet->GetQuidClass())
00187         {
00188             ILOG_ERROR("[MakeImage] DataSet doesn't contain "<< qo.Class());
00189             return 0;
00190         }
00191         if (mDataSet->IsImageSet())
00192         {
00193             return static_cast<ImageSet::ImageSet*>(mDataSet)->GetImage(qo.Id());
00194         }
00195         else
00196         {
00197             ILOG_WARNING("MakeImage not tested for videosets");
00198             Stream::RgbDataSrc* src =
00199                 static_cast<VideoSet::VideoSet*>(mDataSet)->GetVideo(qo.Object());
00200             src->GotoFrame(qo.Id());
00202             Array::Array2dVec3UInt8* im = Array::ArrayCreate<Array::Array2dVec3UInt8>
00203                 (src->FrameWidth(), src->FrameHeight(), 0, 0, src->DataPtr(), true);
00204             delete src;
00205             return im;
00206         }
00207     }
00208 
00209     std::vector<String>
00210     MakeConceptList()
00211     {
00212         typedef Persistency::KeywordListLocator KeywordListLocator;
00213         typedef Persistency::KeywordListRepository KeywordListRepository;
00214         KeywordListLocator loc(mDataSet->GetLocator(), mConceptSet);
00215         std::vector<String> conceptList = *(KeywordListRepository().Get(loc));
00216 
00217         // truncate the concept list according to conceptStart and conceptNumber
00218         Util::SubSelectInPlace(&conceptList, mConceptStart, mConceptNumber);
00219         return conceptList;
00220     }
00221 
00222     Table::AnnotationTable*
00223     MakeAnnotation(String concept)
00224     {
00225         typedef Persistency::AnnotationTableLocator AnnotationTableLocator;
00226         typedef Persistency::AnnotationTableRepository AnnotationTableRepository;
00227         AnnotationTableLocator loc(mDataSet->GetLocator(),
00228                                    mDataSet->GetQuidClass(),
00229                                    mConceptSet, concept);
00230         Table::AnnotationTable* annotation = AnnotationTableRepository().Get(loc);
00232         /*
00233         int maxV = mProperties->GetInt("maxVideoId");
00234         mAnnotation->SelectQuidObjectMaxId(maxV);
00235         int maxP = mProperties->GetInt("maxPosPerVideo");
00236         mAnnotation->SelectQuidObjectMaxPositive(maxP);
00237         int maxN = mProperties->GetInt("maxNegPerVideo");
00238         mAnnotation->SelectQuidObjectMaxNegative(maxN);
00239         double pos = mAnnotation->GetNrPositive();
00240         double neg = mAnnotation->GetNrNegative();
00241         */
00242         ILOG_INFO_HEADNODE("Annotations: size = " << annotation->Size() );
00243         ILOG_INFO_HEADNODE("nr pos = "<< annotation->GetNrPositive() <<
00244                            ", nr neg = "<< annotation->GetNrNegative());
00245         return annotation;
00246     }
00247 
00248     Database::RawDataSet*
00249     GetDataSet()
00250     {
00251         return mDataSet;
00252     }
00253 
00257     Matrix::DistributedAccess*
00258     GetDistributedAccess()
00259     {
00260         if(mDistributedAccess == 0)
00261         {
00262             if(Link::Mpi::NrProcs() == 1)
00263             {
00264                 mDistributedAccess = Training::LoadDistributedAccess
00265                     (mModelType, mFeature, mDataSet, 0, 0, 1);
00266             }
00267             else
00268             {
00269                 mDistributedAccess = Training::LoadDistributedAccess
00270                     (mModelType, mFeature, mDataSet, 0, 1,
00271                      Link::Mpi::NrProcs() - 1);
00272                 if(Link::Mpi::MyId() == 0)
00273                     mDistributedAccess->Subscribe();
00274             }
00275         }
00276         return mDistributedAccess;
00277     }
00278 
00279     void
00280     ServeDistributedAccess()
00281     {
00282         Matrix::DistributedAccess* da = GetDistributedAccess();
00283         da->StartEventLoop();
00284     }
00285 
00286     bool
00287     CodebookExists()
00288     {
00289         return Persistency::FeatureTableRepository().Exists(mCodebookLoc);
00290     }
00291 
00292     void
00293     WriteCodebook(Feature::FeatureTable* forest)
00294     {
00295         Persistency::FeatureTableRepository().Add(mCodebookLoc, forest);
00296     }
00297 
00298     Feature::FeatureDefinition
00299     GetFeatureDefinition()
00300     {
00301         return Feature::FeatureDefinition(mFeature);
00302     }
00303 
00304 private:
00305 
00306     CmdOptions* mOptions;
00307     Database::RawDataSet* mDataSet;
00308     Util::PropertySet* mProperties;
00309     int mConceptStart;
00310     int mConceptNumber;
00311     String mConceptSet;
00312     String mModelType;
00313     String mFeature;
00314     ModelLocator mModelLoc;
00315     FeatureLocator mCodebookLoc;
00316     Matrix::DistributedAccess* mDistributedAccess;
00317 
00318     ILOG_VAR_DECL;
00319 };
00320 
00321 ILOG_VAR_INIT(Factory, Impala.Core.Training);
00322 
00323 } //namespace Training
00324 } //namespace Core
00325 } //namespace Impala
00326 
00327 #endif

Generated on Thu Jan 13 09:04:41 2011 for ImpalaSrc by doxygen 1.5.1