00001 #ifndef Impala_Core_Training_TrainDataSrc_h
00002 #define Impala_Core_Training_TrainDataSrc_h
00003
00004 #include "Core/Array/MakeFromValue.h"
00005 #include "Link/Svm/LinkSvm.h"
00006 #include "Core/VideoSet/VideoSet.h"
00007 #include "Core/Feature/FeatureDefinition.h"
00008 #include "Core/Table/QuidTable.h"
00009 #include "Core/Table/Sort.h"
00010 #include "Core/Table/AnnotationTable.h"
00011
00012 namespace Impala
00013 {
00014 namespace Core
00015 {
00016 namespace Training
00017 {
00018
00019 class TrainDataSrc
00020 {
00021 public:
00022 TrainDataSrc(Table::AnnotationTable* annotation)
00023 {
00024 if(annotation)
00025 {
00026 Table::CriterionElement2NotEquals<Table::AnnotationTable> neq(0);
00027 mAnnotation = Select(annotation, neq);
00028 }
00029 else
00030 mAnnotation = new Table::AnnotationTable();
00031 mQuids = mAnnotation->GetAnnotatedQuids();
00032 mSelection = mQuids;
00033 }
00034
00035 virtual
00036 ~TrainDataSrc()
00037 {
00038 Clear();
00039 }
00040
00050 virtual svm_problem*
00051 MakeSvmProblem() = 0;
00052
00057 virtual svm_problem*
00058 MakeSvmProblem(int i) = 0;
00059
00060 virtual int
00061 GetVectorLength() = 0;
00062
00063 virtual Quid
00064 GetQuid(int i)
00065 {
00066 if(i >= mSelection->Size())
00067 ILOG_ERROR("GetQuid: i out of range:"<< i <<">="<<
00068 mSelection->Size());
00069 return mSelection->Get1(i);
00070 }
00071
00072 Table::QuidTable*
00073 GetSelection()
00074 {
00075 return mSelection;
00076 }
00077
00078 virtual int
00079 Size()
00080 {
00081 return mSelection->Size();
00082 }
00083
00084 virtual void
00085 FilterTestFold(int f, int foldCount, int repetition, bool episodeConstrained)
00086 {
00087 std::vector<Table::QuidTable*> folds =
00088 MakeFolds(f, foldCount, repetition, episodeConstrained);
00089 for(int i=0 ; i<folds.size() ; ++i)
00090 {
00091 if(f == i)
00092 mSelection = folds[f];
00093 else
00094 delete folds[i];
00095 }
00096
00097
00098
00099
00100 Sort(mSelection, 1, true);
00101 }
00102
00103 virtual void
00104 FilterTrainFold(int f, int foldCount, int repetition, bool episodeConstrained)
00105 {
00106 std::vector<Table::QuidTable*> folds =
00107 MakeFolds(f, foldCount, repetition, episodeConstrained);
00108 mSelection = new Table::QuidTable(0);
00109 ILOG_DEBUG_NODE("selection size = "<<mSelection->Size()<<")");
00110 for(int i=0 ; i<folds.size() ; ++i)
00111 {
00112 if(f != i)
00113 {
00114 mSelection->Append(folds[i]);
00115 }
00116 delete folds[i];
00117 }
00118 }
00119
00120 virtual void
00121 FreeProblem(svm_problem* p)
00122 {
00123 if(p->x)
00124 {
00125 delete p->x[0];
00126 delete p->x;
00127 }
00128 if(p->y)
00129 {
00130 delete p->y;
00131 }
00132 delete p;
00133 }
00134
00135 int
00136 GetTotalPositiveCount()
00137 {
00138 return mAnnotation->GetNrPositive();
00139 }
00140
00141 int
00142 GetTotalNegativeCount()
00143 {
00144 return mAnnotation->GetNrNegative();
00145 }
00146
00147 void
00148 PrintSelection()
00149 {
00150 for(int i=0 ; i<mSelection->Size() ; ++i)
00151 {
00152 Quid q = mSelection->Get1(i);
00153 std::cout << QuidObj(q).ToString() <<" - "<<
00154 mAnnotation->GetQualification(q) << std::endl;
00155 }
00156 }
00157
00158 virtual Array::Array2dScalarReal64*
00159 MakeDataCopy(int maxCol, int maxRow)
00160 {
00161 ILOG_ERROR("dummy implementation called");
00162 return Array::MakeFromValue<Array::Array2dScalarReal64>(0, 0, 0, 0, 0);
00163 }
00164
00165 protected:
00166 std::vector<Table::QuidTable*>
00167 MakeFolds(int f, int foldCount, int repetition, bool episodeConstrained)
00168 {
00169 ILOG_DEBUG_NODE("MakeFolds called with these params: "<< f <<", "<<
00170 foldCount <<" , "<< repetition);
00171 ClearSelection();
00172 if(f >= foldCount)
00173 {
00174 ILOG_ERROR("MakeFolds invalid input: "<< f <<" >= "<< foldCount);
00175 exit(1);
00176 }
00177 if(episodeConstrained)
00178 return mAnnotation->MakeEpisodeFolds(foldCount, repetition);
00179 else
00180 return mAnnotation->MakeRandomFolds(foldCount, repetition);
00181 }
00182
00183 void
00184 Clear()
00185 {
00186 ILOG_DEBUG("Clear() calling ClearSelection()");
00187 ClearSelection();
00188 ILOG_DEBUG("Clear() calling delete mQuids: "<<(void*)mQuids);
00189 delete mQuids;
00190 ILOG_DEBUG("Clear() calling delete mAnnotation");
00191 delete mAnnotation;
00192 }
00193
00194 void
00195 SetAnnotation(Table::AnnotationTable* anno)
00196 {
00197 Clear();
00198 mAnnotation = anno;
00199 mQuids = mAnnotation->GetAnnotatedQuids();
00200 mSelection = mQuids;
00201 }
00202
00203 svm_problem*
00204 MakeEmptyProblem()
00205 {
00206 svm_problem* p = new svm_problem;
00207 p->l = 0;
00208 p->x = 0;
00209 p->y = 0;
00210 return p;
00211 }
00212
00213 void
00214 ClearSelection()
00215 {
00216 if(mSelection != mQuids)
00217 delete mSelection;
00218 mSelection = mQuids;
00219 }
00220
00221 Table::AnnotationTable* mAnnotation;
00222 Table::QuidTable* mQuids;
00223 Table::QuidTable* mSelection;
00224
00225 ILOG_VAR_DECL;
00226 };
00227
// Pairs with ILOG_VAR_DECL inside TrainDataSrc. It presumably defines
// TrainDataSrc's logger under the category "Impala.Core.Training".
// NOTE(review): this line sits in a header. If the macro expands to a
// non-inline static data member definition, including this header from more
// than one translation unit violates the one-definition rule. Confirm the
// macro's expansion.
00228 ILOG_VAR_INIT(TrainDataSrc, Impala.Core.Training);
00229
00230 }
00231 }
00232 }
00233
00234
00235 #endif