00001 #ifndef Impala_Core_Training_TrainDataSrcKernelDistributed_h 00002 #define Impala_Core_Training_TrainDataSrcKernelDistributed_h 00003 00004 #include "Core/Training/TrainDataSrc.h" 00005 #include "Core/Matrix/DistributedAccess.h" 00006 #include "Core/Table/AnnotationTable.h" 00007 #include "Link/Svm/LinkSvm.h" 00008 00009 namespace Impala 00010 { 00011 namespace Core 00012 { 00013 namespace Training 00014 { 00015 00016 00021 class TrainDataSrcKernelDistributed : public TrainDataSrc 00022 { 00023 public: 00024 TrainDataSrcKernelDistributed(Matrix::DistributedAccess* da, 00025 Table::AnnotationTable* annotation) 00026 : TrainDataSrc(annotation) 00027 { 00028 mDA = da; 00029 mKernelQuids = mDA->GetColumnQuids(); 00030 SelectValid(); 00031 } 00032 00033 virtual ~TrainDataSrcKernelDistributed() 00034 { 00035 } 00036 00037 virtual svm_problem* MakeSvmProblem() 00038 { 00039 set_distributed_access(mDA); 00040 return MakeProblem(mSelection); 00041 } 00042 00043 virtual svm_problem* MakeSvmProblem(int i) 00044 { 00045 set_distributed_access(mDA); 00046 if(i >= mSelection->Size()) 00047 { 00048 ILOG_WARNING("MakeSvmProblem(int) : index out of range"); 00049 return MakeEmptyProblem(); 00050 } 00051 Quid q = mSelection->Get1(i); 00052 Table::QuidTable t; 00053 t.Add(q); 00054 svm_problem* p = MakeProblem(&t); 00055 return p; 00056 } 00057 virtual int GetVectorLength() 00058 { 00059 return 1; 00060 } 00061 00062 virtual void SelectValid() 00063 { 00064 ILOG_INFO("SelectValid called, quids before: "<< mQuids->Size()); 00065 Table::CriterionElement1InSet<Table::AnnotationTable> eq(mKernelQuids); 00066 ILOG_INFO("nr quids in kernel quids = " << mKernelQuids->Size()); 00067 SetAnnotation(Select(mAnnotation, eq)); 00068 ILOG_INFO("quids after: "<< mQuids->Size()); 00069 } 00070 00071 private: 00072 Matrix::DistributedAccess* mDA; 00073 Table::QuidTable* mKernelQuids; 00074 00085 svm_problem* MakeProblem(Table::QuidTable* mask) 00086 { 00087 //CheckAnnoQuids(mask); 00088 00089 svm_problem* problem = new svm_problem; 00090 problem->l = mask->Size(); 00091 problem->y = new double[problem->l]; 00092 problem->x = new struct svm_node *[problem->l]; 00093 struct svm_node* nodes = new struct svm_node[problem->l*2]; 00094 for(int i=0 ; i<problem->l ; i++) 00095 { 00096 Quid q = mask->Get1(i); 00097 problem->y[i] = 0; 00098 if(mAnnotation->IsPositive(q)) 00099 problem->y[i] = 1; 00100 else if(mAnnotation->IsNegative(q)) 00101 problem->y[i] = -1; 00102 problem->x[i] = &nodes[i*2]; 00103 problem->x[i][0].index = 0; 00104 problem->x[i][0].value = mKernelQuids->GetIndex(q)+1; 00105 problem->x[i][1].index = -1; 00106 } 00107 return problem; 00108 } 00109 00110 /* 00111 Table::QuidTable* CheckAnnoQuids(Table::QuidTable* mask) 00112 { 00113 bool miss = false; 00114 for(int i=0 ; i<mask->Size() ; ++i) 00115 { 00116 Quid q = mask->Get1(i); 00117 int index = mKernelQuids->GetIndex(q); 00118 if(index >= mKernelQuids->Size()) 00119 { 00120 if(miss == false) 00121 { 00122 ILOG_ERROR("invalid quids in annotation " << 00123 "(keyframes/rkf mismatch?)"); 00124 ILOG_ERROR("CheckAnnoQuids: quids not found:"); 00125 } 00126 miss = true; 00127 std::cout << QuidObj(q) << "\t"; 00128 } 00129 } 00130 if(miss) 00131 { 00132 ILOG_ERROR("these are the kernel quids:"); 00133 //mKernelQuids->Dump(0, 0, -1); 00134 for(int i=0 ; i<mKernelQuids->Size() ; ++i) 00135 { 00136 std::cout << QuidObj(mKernelQuids->Get1(i)) <<"\t"; 00137 } 00138 } 00139 //Table::CriterionElement1InSet<Table::AnnotationTable> eq(mKernelQuids); 00140 //Table::CriterionElement2NotEquals<Table::AnnotationTable> neq(-1); 00141 //return Select(mask, eq); 00142 } 00143 */ 00144 00145 ILOG_VAR_DECL; 00146 }; 00147 00148 ILOG_VAR_INIT(TrainDataSrcKernelDistributed, Impala.Core.Training); 00149 00150 }//namespace Core 00151 }//namespace Training 00152 }//namespace Impala 00153 00154 00155 #endif