00001 #ifndef Impala_Core_Training_TrainDataSrc_h
00002 #define Impala_Core_Training_TrainDataSrc_h
00003
00004 #include "Core/Array/MakeFromValue.h"
00005 #include "Link/Svm/LinkSvm.h"
00006 #include "Core/VideoSet/VideoSet.h"
00007 #include "Core/Feature/FeatureDefinition.h"
00008 #include "Core/Table/QuidTable.h"
00009 #include "Core/Table/Sort.h"
00010 #include "Core/Table/AnnotationTable.h"
00011
00012 namespace Impala
00013 {
00014 namespace Core
00015 {
00016 namespace Training
00017 {
00018
00019
00042 class TrainDataSrc
00043 {
00044 public:
00045 TrainDataSrc(Table::AnnotationTable* annotation)
00046 {
00047 if(annotation)
00048 {
00049 Table::CriterionElement2NotEquals<Table::AnnotationTable> neq(0);
00050 mAnnotation = Select(annotation, neq);
00051 }
00052 else
00053 mAnnotation = new Table::AnnotationTable();
00054 mQuids = mAnnotation->GetAnnotatedQuids();
00055 mSelection = mQuids;
00056 }
00057
00058 virtual
00059 ~TrainDataSrc()
00060 {
00061 Clear();
00062 }
00063
00073 virtual svm_problem*
00074 MakeSvmProblem() = 0;
00075
00080 virtual svm_problem*
00081 MakeSvmProblem(int i) = 0;
00082
00083 virtual int
00084 GetVectorLength() = 0;
00085
00086 virtual Quid
00087 GetQuid(int i)
00088 {
00089 if(i >= mSelection->Size())
00090 ILOG_ERROR("GetQuid: i out of range:"<< i <<">="<<
00091 mSelection->Size());
00092 return mSelection->Get1(i);
00093 }
00094
00095 virtual int
00096 Size()
00097 {
00098 return mSelection->Size();
00099 }
00100
00101 virtual void
00102 FilterTestFold(int f, int foldCount, int repetition,
00103 bool episodeConstrained, int restrictSet)
00104 {
00105 std::vector<Table::QuidTable*> folds =
00106 MakeFolds(f, foldCount, repetition, episodeConstrained);
00107 for(int i=0 ; i<folds.size() ; ++i)
00108 {
00109 if(f == i)
00110 mSelection = folds[f];
00111 else
00112 delete folds[i];
00113 }
00114
00115 if (restrictSet != -1)
00116 {
00117 int oldSize = mSelection->Size();
00118 Table::QuidTable* tmp = 0;
00119 Table::CriterionQuidSetEquals<Table::QuidTable> crit(restrictSet);
00120 tmp = Select(mSelection, crit);
00121 bool hasPositive = false;
00122 for (int i=0 ; i<tmp->Size() ; i++)
00123 {
00124 if (mAnnotation->IsPositive(tmp->Get1(i)))
00125 {
00126 hasPositive = true;
00127 break;
00128 }
00129 }
00130 if (hasPositive)
00131 {
00132 delete mSelection;
00133 mSelection = tmp;
00134 ILOG_INFO("FilterTestFold: restrictSetId: size from " << oldSize
00135 << " to " << mSelection->Size());
00136 }
00137 else
00138 {
00139 ILOG_WARN("FilterTestFold: restrict would eliminiate positives");
00140 delete tmp;
00141 }
00142 }
00143
00144
00145
00146
00147
00148 Sort(mSelection, 1, true);
00149 }
00150
00151 virtual void
00152 FilterTrainFold(int f, int foldCount, int repetition, bool episodeConstrained)
00153 {
00154 std::vector<Table::QuidTable*> folds =
00155 MakeFolds(f, foldCount, repetition, episodeConstrained);
00156 mSelection = new Table::QuidTable(0);
00157 ILOG_DEBUG_NODE("selection size = "<<mSelection->Size()<<")");
00158 for(int i=0 ; i<folds.size() ; ++i)
00159 {
00160 if(f != i)
00161 {
00162 mSelection->Append(folds[i]);
00163 }
00164 delete folds[i];
00165 }
00166 }
00167
00168 virtual void
00169 FreeProblem(svm_problem* p)
00170 {
00171 if(p->x)
00172 {
00173 delete p->x[0];
00174 delete p->x;
00175 }
00176 if(p->y)
00177 {
00178 delete p->y;
00179 }
00180 delete p;
00181 }
00182
00183 int
00184 GetTotalPositiveCount()
00185 {
00186 return mAnnotation->GetNrPositive();
00187 }
00188
00189 int
00190 GetTotalNegativeCount()
00191 {
00192 return mAnnotation->GetNrNegative();
00193 }
00194
00195 void
00196 PrintSelection()
00197 {
00198 std::cout << "selection: size = " << mSelection->Size() << std::endl;
00199 for(int i=0 ; i<mSelection->Size() ; ++i)
00200 {
00201 Quid q = mSelection->Get1(i);
00202 std::cout << QuidObj(q).ToString() <<" - "<<
00203 mAnnotation->GetQualification(q) << std::endl;
00204 }
00205 }
00206
00207 protected:
00208 std::vector<Table::QuidTable*>
00209 MakeFolds(int f, int foldCount, int repetition, bool episodeConstrained)
00210 {
00211 ILOG_DEBUG_NODE("MakeFolds called with these params: "<< f <<", "<<
00212 foldCount <<" , "<< repetition);
00213 ClearSelection();
00214 if(f >= foldCount)
00215 {
00216 ILOG_ERROR("MakeFolds invalid input: "<< f <<" >= "<< foldCount);
00217 exit(1);
00218 }
00219 if(episodeConstrained)
00220 return mAnnotation->MakeEpisodeFolds(foldCount, repetition);
00221 else
00222 return mAnnotation->MakeRandomFolds(foldCount, repetition);
00223 }
00224
00225 void
00226 Clear()
00227 {
00228 ILOG_DEBUG("Clear() calling ClearSelection()");
00229 ClearSelection();
00230 ILOG_DEBUG("Clear() calling delete mQuids: "<<(void*)mQuids);
00231 delete mQuids;
00232 ILOG_DEBUG("Clear() calling delete mAnnotation");
00233 delete mAnnotation;
00234 }
00235
00236 void
00237 SetAnnotation(Table::AnnotationTable* anno)
00238 {
00239 Clear();
00240 mAnnotation = anno;
00241 mQuids = mAnnotation->GetAnnotatedQuids();
00242 mSelection = mQuids;
00243 }
00244
00245 svm_problem*
00246 MakeEmptyProblem()
00247 {
00248 svm_problem* p = new svm_problem;
00249 p->l = 0;
00250 p->x = 0;
00251 p->y = 0;
00252 return p;
00253 }
00254
00255 void
00256 ClearSelection()
00257 {
00258 if(mSelection != mQuids)
00259 delete mSelection;
00260 mSelection = mQuids;
00261 }
00262
00263 Table::AnnotationTable* mAnnotation;
00264 Table::QuidTable* mQuids;
00265 Table::QuidTable* mSelection;
00266
00267 ILOG_VAR_DECL;
00268 };
00269
00270 ILOG_VAR_INIT(TrainDataSrc, Impala.Core.Training);
00271
00272 }
00273 }
00274 }
00275
00276
00277 #endif