00001 #ifndef Impala_Core_Training_TrainDataSrcKernelTable_h
00002 #define Impala_Core_Training_TrainDataSrcKernelTable_h
00003
00004 #include "Core/Array/Pattern/PatSet.h"
00005 #include "Core/Training/TrainDataSrc.h"
00006 #include "Core/Matrix/DistributedAccess.h"
00007 #include "Core/Table/AnnotationTable.h"
00008 #include "Link/Svm/LinkSvm.h"
00009
00010 namespace Impala
00011 {
00012 namespace Core
00013 {
00014 namespace Training
00015 {
00016
00017
00018 class TrainDataSrcKernelTable : public TrainDataSrc
00019 {
00020 public:
00021 TrainDataSrcKernelTable(Feature::FeatureTable* kernelMatrix,
00022 Table::AnnotationTable* annotation)
00023 : TrainDataSrc(annotation)
00024 {
00025 mTable = kernelMatrix;
00026 SelectValid();
00027 }
00028
00029 virtual ~TrainDataSrcKernelTable()
00030 {
00031 delete mTable;
00032 }
00033
00034 virtual svm_problem* MakeSvmProblem()
00035 {
00036 return MakeProblem(mSelection);
00037 }
00038
00039 virtual svm_problem* MakeSvmProblem(int i)
00040 {
00041 if(i >= mSelection->Size())
00042 {
00043 ILOG_WARNING("MakeSvmProblem(int) : index out of range");
00044 return MakeEmptyProblem();
00045 }
00046 Quid q = mSelection->Get1(i);
00047 Table::QuidTable t;
00048 t.Add(q);
00049 svm_problem* p = MakeProblem(&t);
00050 return p;
00051 }
00052 virtual int GetVectorLength()
00053 {
00054 return 1;
00055 }
00056
00057 virtual void SelectValid()
00058 {
00059 ILOG_DEBUG("SelectValid called");
00060 ILOG_DEBUG("size annotation: "<< mAnnotation->Size())
00061 Table::QuidTable* tableQuids = mTable->GetQuidTable();
00062 ILOG_DEBUG("nr table quids: "<< tableQuids->Size());
00063 Table::CriterionElement1InSet<Table::AnnotationTable> eq(tableQuids);
00064 SetAnnotation(Select(mAnnotation, eq));
00065 ILOG_DEBUG("size selection: "<< mSelection->Size())
00066 delete tableQuids;
00067 }
00068
00069 Matrix::Mat*
00070 MakeDataCopy(int maxCol, int maxRow)
00071 {
00072 Matrix::Mat* storage = mTable->GetColumn2()->GetStorage();
00073 if(maxCol == -1 || maxCol > storage->CW())
00074 maxCol = storage->CW();
00075 if(maxRow == -1 || maxRow > storage->CH())
00076 maxRow = storage->CH();
00077 Matrix::Mat* m = new Matrix::Mat(maxCol, maxRow, 0, 0);
00078 Array::Pattern::PatSet(m, storage, 0, 0, maxCol, maxRow, 0, 0);
00079 return m;
00080 }
00081
00082 private:
00083 svm_problem* MakeProblem(Table::QuidTable* mask)
00084 {
00085 int vectorLength = mTable->GetFeatureVectorLength();
00086 Table::QuidTable* quids = mTable->GetQuidTable();
00087 svm_problem* problem = new svm_problem;
00088 problem->l = mask->Size();
00089 problem->y = new double[problem->l];
00090 problem->x = new struct svm_node *[problem->l];
00091 struct svm_node* nodes = new struct svm_node[problem->l*(2+vectorLength)];
00092 for(int i=0 ; i<problem->l ; i++)
00093 {
00094 Quid q = mask->Get1(i);
00095 problem->y[i] = 0;
00096 if(mAnnotation->IsPositive(q))
00097 problem->y[i] = 1;
00098 else if(mAnnotation->IsNegative(q))
00099 problem->y[i] = -1;
00100
00101 problem->x[i] = &nodes[i*(vectorLength+2)];
00102 problem->x[i][0].index = 0;
00103 problem->x[i][0].value = quids->GetIndex(q)+1;
00104 int rank = mTable->FindQuid(q);
00105 if(rank >= mTable->Size())
00106 ILOG_ERROR("couldn't find quid in mTable table");
00107 Vector::VectorTem<double> vec = mTable->Get2(rank);
00108 const double* values = vec.GetData();
00109 int j;
00110 for(j=0 ; j<vectorLength ; j++)
00111 {
00112 problem->x[i][j+1].index = j+1;
00113 problem->x[i][j+1].value = values[j];
00114 }
00115 problem->x[i][j+1].index = -1;
00116 }
00117 return problem;
00118 }
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158 Feature::FeatureTable* mTable;
00159 ILOG_CLASS;
00160 };
00161
00162 ILOG_CLASS_INIT(TrainDataSrcKernelTable, Impala.Core.Training.Svm);
00163
00164 }
00165 }
00166 }
00167
00168
00169 #endif