00001 #ifndef Impala_Core_Training_SvmProblemBuilder_h
00002 #define Impala_Core_Training_SvmProblemBuilder_h
00003
00004 #include "Core/Table/AnnotationTable.h"
00005 #include "Core/Feature/FeatureTable.h"
00006 #include "Link/Svm/LinkSvm.h"
00007 #include "Util/SimpleMap.h"
00008
00009 namespace Impala
00010 {
00011 namespace Core
00012 {
00013 namespace Training
00014 {
00015
00026 class SvmProblemBuilder
00027 {
00028 public:
00029 SvmProblemBuilder(Table::AnnotationTable* annotation,
00030 Table::QuidTable* usedQuids)
00031 {
00032 mAnnotation = annotation;
00033 mUsedQuids = usedQuids;
00034 mModelLength = 0;
00035 }
00036
00037 virtual
00038 ~SvmProblemBuilder()
00039 {
00040
00041
00042 for(int i=0 ; i<mModelNodes.size() ; ++i)
00043 delete[] mModelNodes[i];
00044 }
00045
00046 void
00047 AddFeatureTable(Feature::FeatureTable* table)
00048 {
00049 if(table == 0)
00050 {
00051 ILOG_DEBUG("AddFeatureTable: empty table");
00052 return;
00053 }
00054
00055 for(int i=0 ; i<table->Size() ; ++i)
00056 {
00057 Quid q = table->Get1(i);
00058
00062 if(mUsedQuids->GetIndex(q) < mUsedQuids->Size())
00063 {
00064 Vector::VectorTem<double> v = table->Get2(i);
00065 AddFeature(q, &v);
00066 }
00067 }
00068 ILOG_DEBUG("AddFeatureTable: added " << table->Size() <<
00069 " features, current size = " << mModelLength);
00070 }
00071
00072 svm_problem*
00073 MakeProblem()
00074 {
00075
00076
00077 if(mModelLength != mModelLabels.size() ||
00078 mModelLength != mModelNodes.size())
00079 {
00080 ILOG_ERROR("invalid state: model length not consistent");
00081 Clear();
00082 return MakeProblem(0);
00083 }
00084 if(mModelLength > mUsedQuids->Size())
00085 {
00086 ILOG_ERROR_NODE("logic error, quitting");
00087 Clear();
00088 return MakeProblem(0);
00089 }
00090 if(mModelLength < mUsedQuids->Size())
00091 {
00092 ILOG_WARNING_NODE("not all quids of fold/annotation found in features"
00093 <<" model: "<< mModelLength <<", quids: "<<
00094 mUsedQuids->Size());
00095 }
00096
00097
00098
00099
00100 svm_problem* problem = Impala::MakeProblem(mModelLength);
00101 int modelIndex = 0;
00102 for(int i=0 ; i<mUsedQuids->Size() ; ++i)
00103 {
00104 Quid q = mUsedQuids->Get1(i);
00105
00106 int index;
00107 if(mQuidsIndices.Get(q, index))
00108 {
00109
00110 problem->y[modelIndex] = mModelLabels[index];
00111 problem->x[modelIndex] = mModelNodes[index];
00112 ++modelIndex;
00113 }
00114 else
00115 {
00116
00117
00118 }
00119 }
00120 ILOG_DEBUG("problem built; used quids: "<< mUsedQuids->Size() <<
00121 ", problem size: "<< mModelLength);
00122
00123
00124
00125 Clear();
00126 return problem;
00127 }
00128
00129 svm_problem*
00130 MakeProblem(Vector::VectorTem<double>* feature)
00131 {
00132 if(mModelLength > 0)
00133 ILOG_ERROR_NODE("wrong MakeProblem called: model not empty");
00134 AddFeature(0, feature);
00135 svm_problem* problem = Impala::MakeProblem(1);
00136 problem->x[0] = mModelNodes[0];
00137 Clear();
00138 return problem;
00139 }
00140
00141 private:
00142 void
00143 AddFeature(Quid q, Vector::VectorTem<double>* feature)
00144 {
00145 int length = feature->Size();
00146
00147 struct svm_node* nodes = new struct svm_node[length+1];
00148 const double* values = feature->GetData();
00149 for(int i=0 ; i<length ; ++i)
00150 {
00151 nodes[i].index = i+1;
00152 nodes[i].value = values[i];
00153 }
00154 nodes[length].index = -1;
00155 double label;
00156 label = ClassLabel(q);
00157
00158 mQuidsIndices.Add(q, mModelLabels.size());
00159 mModelNodes.push_back(nodes);
00160 mModelLabels.push_back(label);
00161 ++mModelLength;
00162 }
00163
00164 void
00165 Clear()
00166 {
00167 mModelLength = 0;
00168 mModelLabels.clear();
00169 mModelNodes.clear();
00170 mQuidsIndices.Clear();
00171
00172
00173 }
00174
00175 double
00176 ClassLabel(Quid q) const
00177 {
00178 if(mAnnotation == 0)
00179 return 0;
00180 if(mAnnotation->IsPositive(q))
00181 return 1;
00182 if(mAnnotation->IsNegative(q))
00183 return -1;
00184 return 0;
00185 }
00186
00187 void
00188 AddFeature(Quid q, Vector::VectorTem<double>* feature, double& label,
00189 svm_node* nodes)
00190 {
00191 }
00192
00193 Table::AnnotationTable* mAnnotation;
00194 Table::QuidTable* mUsedQuids;
00195
00196
00197 int mModelLength;
00198 std::deque<int> mModelLabels;
00199 std::deque<svm_node*> mModelNodes;
00200 Util::SimpleMap<Quid, int> mQuidsIndices;
00201
00202 ILOG_VAR_DECL;
00203 };
00204
00205 ILOG_VAR_INIT(SvmProblemBuilder, Impala.Core.Training);
00206
00207 }
00208 }
00209 }
00210
00211
00212 #endif