00001 #ifndef Impala_Core_Training_TrainDataSrcKernelDistributed_h
00002 #define Impala_Core_Training_TrainDataSrcKernelDistributed_h
00003
00004 #include "Core/Training/TrainDataSrc.h"
00005 #include "Core/Matrix/DistributedAccess.h"
00006 #include "Core/Table/AnnotationTable.h"
00007 #include "Link/Svm/LinkSvm.h"
00008
00009 namespace Impala
00010 {
00011 namespace Core
00012 {
00013 namespace Training
00014 {
00015
00016
00021 class TrainDataSrcKernelDistributed : public TrainDataSrc
00022 {
00023 public:
00024 TrainDataSrcKernelDistributed(Matrix::DistributedAccess* da,
00025 Table::AnnotationTable* annotation)
00026 : TrainDataSrc(annotation)
00027 {
00028 mDA = da;
00029 mKernelQuids = mDA->GetColumnQuids();
00030 SelectValid();
00031 }
00032
00033 virtual ~TrainDataSrcKernelDistributed()
00034 {
00035 }
00036
00037 virtual svm_problem* MakeSvmProblem()
00038 {
00039
00040 set_distributed_access(mDA);
00041 return MakeProblem(mSelection);
00042 }
00043
00044 virtual svm_problem* MakeSvmProblem(int i)
00045 {
00046
00047 set_distributed_access(mDA);
00048 if(i >= mSelection->Size())
00049 {
00050 ILOG_WARNING("MakeSvmProblem(int) : index out of range");
00051 return MakeEmptyProblem();
00052 }
00053 Quid q = mSelection->Get1(i);
00054 Table::QuidTable t;
00055 t.Add(q);
00056 svm_problem* p = MakeProblem(&t);
00057 return p;
00058 }
00059 virtual int GetVectorLength()
00060 {
00061 return 1;
00062 }
00063
00064 virtual void SelectValid()
00065 {
00066 ILOG_DEBUG("SelectValid called, quids before: "<< mQuids->Size());
00067 Table::CriterionElement1InSet<Table::AnnotationTable> eq(mKernelQuids);
00068 ILOG_DEBUG("nr quids in kernel quids = " << mKernelQuids->Size());
00069 SetAnnotation(Select(mAnnotation, eq));
00070 ILOG_DEBUG("quids after: "<< mQuids->Size());
00071 }
00072
00073 Matrix::Mat*
00074 MakeDataCopy(int maxCol, int maxRow)
00075 {
00076 if(maxCol == -1 || maxCol > mDA->GetColumns())
00077 maxCol = mDA->GetColumns();
00078 if(maxRow == -1 || maxRow > mDA->GetRows())
00079 maxRow = mDA->GetRows();
00080 Matrix::Mat* m = new Matrix::Mat(maxCol, maxRow, 0, 0);
00081 for(int x=0 ; x<maxCol ; ++x)
00082 {
00083 int received = mDA->GetColumn(x, m->CPB(0,x), maxRow);
00084 if(received != maxRow)
00085 {
00086 ILOG_WARNING("received " << received << " values i.s.o. " << maxRow);
00087 }
00088 }
00089 Matrix::MatTranspose(m);
00090 return m;
00091 }
00092
00093 private:
00094 Matrix::DistributedAccess* mDA;
00095 Table::QuidTable* mKernelQuids;
00096
00107 svm_problem* MakeProblem(Table::QuidTable* mask)
00108 {
00109
00110
00111 svm_problem* problem = new svm_problem;
00112 problem->l = mask->Size();
00113 problem->y = new double[problem->l];
00114 problem->x = new struct svm_node *[problem->l];
00115 struct svm_node* nodes = new struct svm_node[problem->l*2];
00116 for(int i=0 ; i<problem->l ; i++)
00117 {
00118 Quid q = mask->Get1(i);
00119 problem->y[i] = 0;
00120 if(mAnnotation->IsPositive(q))
00121 problem->y[i] = 1;
00122 else if(mAnnotation->IsNegative(q))
00123 problem->y[i] = -1;
00124 problem->x[i] = &nodes[i*2];
00125 problem->x[i][0].index = 0;
00126 problem->x[i][0].value = mKernelQuids->GetIndex(q)+1;
00127 problem->x[i][1].index = -1;
00128 }
00129 return problem;
00130 }
00131
00132 Table::QuidTable* CheckAnnoQuids(Table::QuidTable* mask)
00133 {
00134 bool miss = false;
00135 for(int i=0 ; i<mask->Size() ; ++i)
00136 {
00137 Quid q = mask->Get1(i);
00138 int index = mKernelQuids->GetIndex(q);
00139 if(index >= mKernelQuids->Size())
00140 {
00141 if(miss == false)
00142 {
00143 ILOG_ERROR("invalid quids in annotation (keyframes/rkf mismatch?)");
00144 ILOG_ERROR("CheckAnnoQuids: quids not found:");
00145 }
00146 miss = true;
00147 std::cout << QuidObj(q) << "\t";
00148 }
00149 }
00150 if(miss)
00151 {
00152 ILOG_ERROR("these are the kernel quids:");
00153
00154 for(int i=0 ; i<mKernelQuids->Size() ; ++i)
00155 {
00156 std::cout << QuidObj(mKernelQuids->Get1(i)) <<"\t";
00157 }
00158 }
00159
00160
00161
00162 }
00163 ILOG_VAR_DECL;
00164 };
00165
00166 ILOG_VAR_INIT(TrainDataSrcKernelDistributed, Impala.Core.Training);
00167
00168 }
00169 }
00170 }
00171
00172
00173 #endif