00001 #ifndef Impala_Core_Training_TrainDataSrcKernelTable_h
00002 #define Impala_Core_Training_TrainDataSrcKernelTable_h
00003 
00004 #include "Core/Array/Pattern/PatSet.h"
00005 #include "Core/Training/TrainDataSrc.h"
00006 #include "Core/Matrix/DistributedAccess.h"
00007 #include "Core/Table/AnnotationTable.h"
00008 #include "Link/Svm/LinkSvm.h"
00009 
00010 namespace Impala
00011 {
00012 namespace Core
00013 {
00014 namespace Training
00015 {
00016 
00017 
00018 class TrainDataSrcKernelTable : public TrainDataSrc
00019 {
00020 public:
00021     TrainDataSrcKernelTable(Feature::FeatureTable* kernelMatrix,
00022                             Table::AnnotationTable* annotation)
00023         : TrainDataSrc(annotation)
00024     {
00025         mTable = kernelMatrix;
00026         SelectValid();
00027     }
00028 
00029     virtual ~TrainDataSrcKernelTable()
00030     {
00031         delete mTable;
00032     }
00033 
00034     virtual svm_problem* MakeSvmProblem()
00035     {
00036         return MakeProblem(mSelection);
00037     }
00038 
00039     virtual svm_problem* MakeSvmProblem(int i)
00040     {
00041         if(i >= mSelection->Size())
00042         {
00043             ILOG_WARNING("MakeSvmProblem(int) : index out of range");
00044             return MakeEmptyProblem();
00045         }
00046         Quid q = mSelection->Get1(i);
00047         Table::QuidTable t;
00048         t.Add(q);
00049         svm_problem* p = MakeProblem(&t);
00050         return p;
00051     }
00052     virtual int GetVectorLength()
00053     {
00054         return 1;
00055     }
00056 
00057     virtual void SelectValid()
00058     {
00059         ILOG_DEBUG("SelectValid called");
00060         ILOG_DEBUG("size annotation: "<< mAnnotation->Size())
00061         Table::QuidTable* tableQuids = mTable->GetQuidTable();
00062         ILOG_DEBUG("nr table quids: "<< tableQuids->Size());
00063         Table::CriterionElement1InSet<Table::AnnotationTable> eq(tableQuids);
00064         SetAnnotation(Select(mAnnotation, eq));
00065         ILOG_DEBUG("size selection: "<< mSelection->Size())
00066         delete tableQuids;
00067     }
00068 
00069     Matrix::Mat*
00070     MakeDataCopy(int maxCol, int maxRow)
00071     {
00072         Matrix::Mat* storage = mTable->GetColumn2()->GetStorage();
00073         if(maxCol == -1 || maxCol > storage->CW())
00074             maxCol = storage->CW();
00075         if(maxRow == -1 || maxRow > storage->CH())
00076             maxRow = storage->CH();
00077         Matrix::Mat* m = new Matrix::Mat(maxCol, maxRow, 0, 0);
00078         Array::Pattern::PatSet(m, storage, 0, 0, maxCol, maxRow, 0, 0);
00079         return m;
00080     }
00081 
00082 private:
00083     svm_problem* MakeProblem(Table::QuidTable* mask)
00084     {
00085         int vectorLength = mTable->GetFeatureVectorLength();
00086         Table::QuidTable* quids = mTable->GetQuidTable();
00087         svm_problem* problem = new svm_problem;
00088         problem->l = mask->Size();
00089         problem->y = new double[problem->l];
00090         problem->x = new struct svm_node *[problem->l];
00091         struct svm_node* nodes = new struct svm_node[problem->l*(2+vectorLength)];
00092         for(int i=0 ; i<problem->l ; i++)
00093         {
00094             Quid q = mask->Get1(i);
00095             problem->y[i] = 0;
00096             if(mAnnotation->IsPositive(q))
00097                 problem->y[i] = 1;
00098             else if(mAnnotation->IsNegative(q))
00099                 problem->y[i] = -1;
00100 
00101             problem->x[i] = &nodes[i*(vectorLength+2)];
00102             problem->x[i][0].index = 0;
00103             problem->x[i][0].value = quids->GetIndex(q)+1;
00104             int rank = mTable->FindQuid(q);
00105             if(rank >= mTable->Size())
00106                 ILOG_ERROR("couldn't find quid in mTable table");
00107             Vector::VectorTem<double> vec = mTable->Get2(rank);
00108             const double* values = vec.GetData();
00109             int j;
00110             for(j=0 ; j<vectorLength ; j++)
00111             {
00112                 problem->x[i][j+1].index = j+1;
00113                 problem->x[i][j+1].value = values[j];
00114             }
00115             problem->x[i][j+1].index = -1;
00116         }
00117         return problem;
00118     }
00119     
00120     
00121 
00122 
00123 
00124 
00125 
00126 
00127 
00128 
00129 
00130 
00131 
00132 
00133 
00134 
00135 
00136 
00137 
00138 
00139 
00140 
00141 
00142 
00143 
00144 
00145 
00146 
00147 
00148 
00149 
00150 
00151 
00152 
00153 
00154 
00155 
00156 
00157 
00158     Feature::FeatureTable* mTable;
00159     ILOG_CLASS;
00160 };
00161 
// Pairs with the ILOG_CLASS declaration inside TrainDataSrcKernelTable;
// presumably sets up the class logger under "Impala.Core.Training.Svm".
// NOTE(review): this sits in a header; if the macro defines a static member,
// including this header from several translation units could produce duplicate
// definitions — confirm against the ILOG_CLASS_INIT macro definition.
ILOG_CLASS_INIT(TrainDataSrcKernelTable, Impala.Core.Training.Svm);
00163 
00164 }
00165 }
00166 }
00167 
00168 
00169 #endif