Home || Architecture || Video Search || Visual Search || Scripts || Applications || Important Messages || OGL || Src

TrainDataSrc.h

Go to the documentation of this file.
00001 #ifndef Impala_Core_Training_TrainDataSrc_h
00002 #define Impala_Core_Training_TrainDataSrc_h
00003 
00004 #include "Core/Array/MakeFromValue.h"
00005 #include "Link/Svm/LinkSvm.h"
00006 #include "Core/VideoSet/VideoSet.h"
00007 #include "Core/Feature/FeatureDefinition.h"
00008 #include "Core/Table/QuidTable.h"
00009 #include "Core/Table/Sort.h"
00010 #include "Core/Table/AnnotationTable.h"
00011 
00012 namespace Impala
00013 {
00014 namespace Core
00015 {
00016 namespace Training
00017 {
00018 
00019 class TrainDataSrc
00020 {
00021 public:
00022     TrainDataSrc(Table::AnnotationTable* annotation)
00023     {
00024         if(annotation)
00025         {
00026             Table::CriterionElement2NotEquals<Table::AnnotationTable> neq(0);
00027             mAnnotation = Select(annotation, neq);
00028         }
00029         else
00030             mAnnotation = new Table::AnnotationTable();
00031         mQuids = mAnnotation->GetAnnotatedQuids();
00032         mSelection = mQuids;
00033     }
00034 
00035     virtual
00036     ~TrainDataSrc()
00037     {
00038         Clear();
00039     }
00040     
00050     virtual svm_problem*
00051     MakeSvmProblem() = 0;
00052 
00057     virtual svm_problem*
00058     MakeSvmProblem(int i) = 0;
00059 
00060     virtual int
00061     GetVectorLength() = 0;
00062 
00063     virtual Quid
00064     GetQuid(int i)
00065     {
00066         if(i >= mSelection->Size())
00067             ILOG_ERROR("GetQuid: i out of range:"<< i <<">="<<
00068                        mSelection->Size());
00069         return mSelection->Get1(i);
00070     }
00071 
00072     Table::QuidTable*
00073     GetSelection()
00074     {
00075         return mSelection;
00076     } 
00077     
00078     virtual int
00079     Size()
00080     {
00081         return mSelection->Size();
00082     }
00083 
00084     virtual void
00085     FilterTestFold(int f, int foldCount, int repetition, bool episodeConstrained)
00086     {
00087         std::vector<Table::QuidTable*> folds =
00088             MakeFolds(f, foldCount, repetition, episodeConstrained);
00089         for(int i=0 ; i<folds.size() ; ++i)
00090         {
00091             if(f == i)
00092                 mSelection = folds[f];
00093             else
00094                 delete folds[i];
00095         }
00096         // I'm actually not sure whether a sort here makes sense:
00097         // - on the one hand it makes TrainDataSrc more uniform and easier to debug
00098         // - on the other hand it is unnessecary and a bit slower
00099         // Let's stick with debuggability for now 
00100         Sort(mSelection, 1, true);
00101     }
00102 
00103     virtual void
00104     FilterTrainFold(int f, int foldCount, int repetition, bool episodeConstrained)
00105     {
00106         std::vector<Table::QuidTable*> folds =
00107             MakeFolds(f, foldCount, repetition, episodeConstrained);
00108         mSelection = new Table::QuidTable(0);
00109         ILOG_DEBUG_NODE("selection size = "<<mSelection->Size()<<")");
00110         for(int i=0 ; i<folds.size() ; ++i)
00111         {
00112             if(f != i)
00113             {
00114                 mSelection->Append(folds[i]);
00115             }
00116             delete folds[i];
00117         }
00118     }
00119 
00120     virtual void
00121     FreeProblem(svm_problem* p)
00122     {
00123         if(p->x)
00124         {
00125             delete p->x[0];
00126             delete p->x;
00127         }
00128         if(p->y)
00129         {
00130             delete p->y;
00131         }
00132         delete p;
00133     }
00134 
00135     int
00136     GetTotalPositiveCount()
00137     {
00138         return mAnnotation->GetNrPositive();
00139     }
00140 
00141     int
00142     GetTotalNegativeCount()
00143     {
00144         return mAnnotation->GetNrNegative();
00145     }
00146 
00147     void
00148     PrintSelection()
00149     {
00150         for(int i=0 ; i<mSelection->Size() ; ++i)
00151         {
00152             Quid q = mSelection->Get1(i);
00153             std::cout << QuidObj(q).ToString() <<" - "<<
00154                 mAnnotation->GetQualification(q) << std::endl;
00155         }
00156     }
00157 
00158     virtual Array::Array2dScalarReal64*
00159     MakeDataCopy(int maxCol, int maxRow)
00160     {
00161         ILOG_ERROR("dummy implementation called");
00162         return Array::MakeFromValue<Array::Array2dScalarReal64>(0, 0, 0, 0, 0);
00163     }
00164     
00165 protected:
00166     std::vector<Table::QuidTable*>
00167     MakeFolds(int f, int foldCount, int repetition, bool episodeConstrained)
00168     {
00169         ILOG_DEBUG_NODE("MakeFolds called with these params: "<< f <<", "<<
00170                         foldCount <<" , "<< repetition);
00171         ClearSelection();
00172         if(f >= foldCount)
00173         {
00174             ILOG_ERROR("MakeFolds invalid input: "<< f <<" >= "<< foldCount);
00175             exit(1);
00176         }
00177         if(episodeConstrained)
00178             return mAnnotation->MakeEpisodeFolds(foldCount, repetition);
00179         else
00180             return mAnnotation->MakeRandomFolds(foldCount, repetition);
00181     }
00182 
00183     void
00184     Clear()
00185     {
00186         ILOG_DEBUG("Clear() calling ClearSelection()");
00187         ClearSelection();
00188         ILOG_DEBUG("Clear() calling delete mQuids: "<<(void*)mQuids);
00189         delete mQuids;
00190         ILOG_DEBUG("Clear() calling delete mAnnotation");
00191         delete mAnnotation;
00192     }
00193 
00194     void
00195     SetAnnotation(Table::AnnotationTable* anno)
00196     {
00197         Clear();
00198         mAnnotation = anno;
00199         mQuids = mAnnotation->GetAnnotatedQuids();
00200         mSelection = mQuids;
00201     }
00202     
00203     svm_problem*
00204     MakeEmptyProblem()
00205     {
00206         svm_problem* p = new svm_problem;
00207         p->l = 0;
00208         p->x = 0;
00209         p->y = 0;
00210         return p;
00211     }
00212     
00213     void
00214     ClearSelection()
00215     {
00216         if(mSelection != mQuids)
00217             delete mSelection;
00218         mSelection = mQuids;
00219     }
00220 
00221     Table::AnnotationTable* mAnnotation;
00222     Table::QuidTable* mQuids;
00223     Table::QuidTable* mSelection;
00224 
00225     ILOG_VAR_DECL;
00226 };
00227 
00228 ILOG_VAR_INIT(TrainDataSrc, Impala.Core.Training);
00229 
00230 }//namespace Training
00231 }//namespace Core
00232 }//namespace Impala
00233 
00234 
00235 #endif

Generated on Fri Mar 19 09:31:25 2010 for ImpalaSrc by  doxygen 1.5.1