00001 #ifndef Impala_Core_Training_TrainDataSrcKernelMatrix_h
00002 #define Impala_Core_Training_TrainDataSrcKernelMatrix_h
00003
#include <cstddef>

#include "Core/Training/TrainDataSrc.h"
#include "Core/Table/AnnotationTable.h"
#include "Link/Svm/LinkSvm.h"
00007
00008 namespace Impala
00009 {
00010 namespace Core
00011 {
00012 namespace Training
00013 {
00014
00015
00022 class TrainDataSrcKernelMatrix : public TrainDataSrc
00023 {
00024 public:
00025 TrainDataSrcKernelMatrix(Matrix::Mat* kernelMatrix)
00026 : TrainDataSrc(0)
00027 {
00028 mMatrix = kernelMatrix;
00029 }
00030
00031 virtual ~TrainDataSrcKernelMatrix()
00032 {
00033 }
00034
00035 virtual svm_problem* MakeSvmProblem()
00036 {
00037 return MakeProblem(0, mMatrix->CW());
00038 }
00039
00040 virtual svm_problem* MakeSvmProblem(int i)
00041 {
00042 if(i >= mMatrix->CH())
00043 {
00044 ILOG_WARNING("MakeSvmProblem(int) : index out of range");
00045 return MakeEmptyProblem();
00046 }
00047 svm_problem* p = MakeProblem(i, i+1);
00048 return p;
00049 }
00050 virtual int GetVectorLength()
00051 {
00052 return 1;
00053 }
00054
00055 virtual int
00056 Size()
00057 {
00058 ILOG_DEBUG("correct Size called");
00059 return mMatrix->CH();
00060 }
00061
00062 virtual Quid
00063 GetQuid(int i)
00064 {
00065 ILOG_DEBUG("correct GetQuid called, but a matrix doesn't have quids");
00066 return 0;
00067 }
00068
00069
00070 private:
00071 Matrix::Mat* mMatrix;
00072
00084 svm_problem* MakeProblem(int start, int end)
00085 {
00086 ILOG_VAR(Core.Training.Svm.KernelSvmProblem);
00087 int vectorLength = mMatrix->CW();
00088 svm_problem* problem = new svm_problem;
00089 problem->l = end - start;
00090 problem->y = new double[problem->l];
00091 problem->x = new struct svm_node *[problem->l];
00092 struct svm_node* nodes = new struct svm_node[problem->l*(2+vectorLength)];
00093 for(int i=0 ; i<problem->l ; i++)
00094 {
00095 problem->y[i] = 0;
00096 problem->x[i] = &nodes[i*(vectorLength+2)];
00097 problem->x[i][0].index = 0;
00098 problem->x[i][0].value = 0;
00099 const double* values = mMatrix->CPB(0, start+i);
00100 int j;
00101 for(j=0 ; j<vectorLength ; j++)
00102 {
00103 problem->x[i][j+1].index = j+1;
00104 problem->x[i][j+1].value = values[j];
00105 }
00106 problem->x[i][j+1].index = -1;
00107 }
00108 return problem;
00109 }
00110 };
00111
00112 }
00113 }
00114 }
00115
00116
00117 #endif