Home || Visual Search || Applications || Architecture || Important Messages || OGL || Src

TrainDataSrc.h

Go to the documentation of this file.
00001 #ifndef Impala_Core_Training_TrainDataSrc_h
00002 #define Impala_Core_Training_TrainDataSrc_h
00003 
00004 #include "Core/Array/MakeFromValue.h"
00005 #include "Link/Svm/LinkSvm.h"
00006 #include "Core/VideoSet/VideoSet.h"
00007 #include "Core/Feature/FeatureDefinition.h"
00008 #include "Core/Table/QuidTable.h"
00009 #include "Core/Table/Sort.h"
00010 #include "Core/Table/AnnotationTable.h"
00011 
00012 namespace Impala
00013 {
00014 namespace Core
00015 {
00016 namespace Training
00017 {
00018 
00019 
00042 class TrainDataSrc
00043 {
00044 public:
00045     TrainDataSrc(Table::AnnotationTable* annotation)
00046     {
00047         if(annotation)
00048         {
00049             Table::CriterionElement2NotEquals<Table::AnnotationTable> neq(0);
00050             mAnnotation = Select(annotation, neq);
00051         }
00052         else
00053             mAnnotation = new Table::AnnotationTable();
00054         mQuids = mAnnotation->GetAnnotatedQuids();
00055         mSelection = mQuids;
00056     }
00057 
00058     virtual
00059     ~TrainDataSrc()
00060     {
00061         Clear();
00062     }
00063     
00073     virtual svm_problem*
00074     MakeSvmProblem() = 0;
00075 
00080     virtual svm_problem*
00081     MakeSvmProblem(int i) = 0;
00082 
00083     virtual int
00084     GetVectorLength() = 0;
00085 
00086     virtual Quid
00087     GetQuid(int i)
00088     {
00089         if(i >= mSelection->Size())
00090             ILOG_ERROR("GetQuid: i out of range:"<< i <<">="<<
00091                        mSelection->Size());
00092         return mSelection->Get1(i);
00093     }
00094 
00095     virtual int
00096     Size()
00097     {
00098         return mSelection->Size();
00099     }
00100 
00101     virtual void
00102     FilterTestFold(int f, int foldCount, int repetition,
00103                    bool episodeConstrained, int restrictSet)
00104     {
00105         std::vector<Table::QuidTable*> folds =
00106             MakeFolds(f, foldCount, repetition, episodeConstrained);
00107         for(int i=0 ; i<folds.size() ; ++i)
00108         {
00109             if(f == i)
00110                 mSelection = folds[f];
00111             else
00112                 delete folds[i];
00113         }
00114 
00115         if (restrictSet != -1)
00116         {
00117             int oldSize = mSelection->Size();
00118             Table::QuidTable* tmp = 0;
00119             Table::CriterionQuidSetEquals<Table::QuidTable> crit(restrictSet);
00120             tmp = Select(mSelection, crit);
00121             bool hasPositive = false;
00122             for (int i=0 ; i<tmp->Size() ; i++)
00123             {
00124                 if (mAnnotation->IsPositive(tmp->Get1(i)))
00125                 {
00126                     hasPositive = true;
00127                     break;
00128                 }
00129             }
00130             if (hasPositive)
00131             {
00132                 delete mSelection;
00133                 mSelection = tmp;
00134                 ILOG_INFO("FilterTestFold: restrictSetId: size from " << oldSize
00135                           << " to " << mSelection->Size());
00136             }
00137             else
00138             {
00139                 ILOG_WARN("FilterTestFold: restrict would eliminiate positives");
00140                 delete tmp;
00141             }
00142         }
00143 
00144         // I'm actually not sure whether a sort here makes sense:
00145         // - on the one hand it makes TrainDataSrc more uniform and easier to debug
00146         // - on the other hand it is unnessecary and a bit slower
00147         // Let's stick with debuggability for now 
00148         Sort(mSelection, 1, true);
00149     }
00150 
00151     virtual void
00152     FilterTrainFold(int f, int foldCount, int repetition, bool episodeConstrained)
00153     {
00154         std::vector<Table::QuidTable*> folds =
00155             MakeFolds(f, foldCount, repetition, episodeConstrained);
00156         mSelection = new Table::QuidTable(0);
00157         ILOG_DEBUG_NODE("selection size = "<<mSelection->Size()<<")");
00158         for(int i=0 ; i<folds.size() ; ++i)
00159         {
00160             if(f != i)
00161             {
00162                 mSelection->Append(folds[i]);
00163             }
00164             delete folds[i];
00165         }
00166     }
00167 
00168     virtual void
00169     FreeProblem(svm_problem* p)
00170     {
00171         if(p->x)
00172         {
00173             delete p->x[0];
00174             delete p->x;
00175         }
00176         if(p->y)
00177         {
00178             delete p->y;
00179         }
00180         delete p;
00181     }
00182 
00183     int
00184     GetTotalPositiveCount()
00185     {
00186         return mAnnotation->GetNrPositive();
00187     }
00188 
00189     int
00190     GetTotalNegativeCount()
00191     {
00192         return mAnnotation->GetNrNegative();
00193     }
00194 
00195     void
00196     PrintSelection()
00197     {
00198         std::cout << "selection: size = " << mSelection->Size() << std::endl;
00199         for(int i=0 ; i<mSelection->Size() ; ++i)
00200         {
00201             Quid q = mSelection->Get1(i);
00202             std::cout << QuidObj(q).ToString() <<" - "<<
00203                 mAnnotation->GetQualification(q) << std::endl;
00204         }
00205     }
00206 
00207 protected:
00208     std::vector<Table::QuidTable*>
00209     MakeFolds(int f, int foldCount, int repetition, bool episodeConstrained)
00210     {
00211         ILOG_DEBUG_NODE("MakeFolds called with these params: "<< f <<", "<<
00212                         foldCount <<" , "<< repetition);
00213         ClearSelection();
00214         if(f >= foldCount)
00215         {
00216             ILOG_ERROR("MakeFolds invalid input: "<< f <<" >= "<< foldCount);
00217             exit(1);
00218         }
00219         if(episodeConstrained)
00220             return mAnnotation->MakeEpisodeFolds(foldCount, repetition);
00221         else
00222             return mAnnotation->MakeRandomFolds(foldCount, repetition);
00223     }
00224 
00225     void
00226     Clear()
00227     {
00228         ILOG_DEBUG("Clear() calling ClearSelection()");
00229         ClearSelection();
00230         ILOG_DEBUG("Clear() calling delete mQuids: "<<(void*)mQuids);
00231         delete mQuids;
00232         ILOG_DEBUG("Clear() calling delete mAnnotation");
00233         delete mAnnotation;
00234     }
00235 
00236     void
00237     SetAnnotation(Table::AnnotationTable* anno)
00238     {
00239         Clear();
00240         mAnnotation = anno;
00241         mQuids = mAnnotation->GetAnnotatedQuids();
00242         mSelection = mQuids;
00243     }
00244     
00245     svm_problem*
00246     MakeEmptyProblem()
00247     {
00248         svm_problem* p = new svm_problem;
00249         p->l = 0;
00250         p->x = 0;
00251         p->y = 0;
00252         return p;
00253     }
00254     
00255     void
00256     ClearSelection()
00257     {
00258         if(mSelection != mQuids)
00259             delete mSelection;
00260         mSelection = mQuids;
00261     }
00262 
00263     Table::AnnotationTable* mAnnotation;
00264     Table::QuidTable* mQuids;
00265     Table::QuidTable* mSelection;
00266 
00267     ILOG_VAR_DECL;
00268 };
00269 
00270 ILOG_VAR_INIT(TrainDataSrc, Impala.Core.Training);
00271 
00272 }//namespace Training
00273 }//namespace Core
00274 }//namespace Impala
00275 
00276 
00277 #endif

Generated on Thu Jan 13 09:04:42 2011 for ImpalaSrc by  doxygen 1.5.1