00001 #ifndef Impala_Core_Trec_TrecSVM_h
00002 #define Impala_Core_Trec_TrecSVM_h
00003
00004 #include "Basis/String.h"
00005 #include "Core/Array/ArrayListDelete.h"
00006 #include "Core/Array/Set.h"
00007
00008 #include "Core/Trec/svm.cpp"
00009
00010 namespace Impala
00011 {
00012 namespace Core
00013 {
00014 namespace Trec
00015 {
00016
00017
00018
00019
00020 class TrecSVM
00021 {
00022 public:
00023
00024 TrecSVM(int nrConcepts, int nrFeatures)
00025 {
00026 mNrConcepts = nrConcepts;
00027 mNrFeatures = nrFeatures;
00028 mModel = 0;
00029 mProbs = 0;
00030 mNodes = 0;
00031 mFirstTime = true;
00032 mScale = false;
00033 mRangeMin = new Real64[mNrFeatures];
00034 mRangeMax = new Real64[mNrFeatures];
00035 mDataFile = std::string("concept_data");
00036 }
00037
00038 void
00039 SetDataFile(std::string fileName)
00040 {
00041 mDataFile = fileName;
00042 }
00043
00044
00045
00046 void
00047 UpdateAnno(int con, int feat, Array2dScalarReal64* im)
00048 {
00049 if (feat == 0)
00050 if (mFeatList.size() > 0)
00051 ArrayListDelete<Array2dScalarReal64>(&mFeatList);
00052
00053 Array2dScalarReal64* cp = 0;
00054 Set(cp, im);
00055 mFeatList.push_back(cp);
00056
00057 if (feat == mNrFeatures-1)
00058 WriteResultsPixel(con);
00059 }
00060
00061 void
00062 WriteResultsPixel(int con)
00063 {
00064 FILE *fp;
00065 std::string wMode = (mFirstTime) ? "w" : "a";
00066 mFirstTime = false;
00067 if ((fp = fopen(mDataFile.c_str(), wMode.c_str())) == 0)
00068 {
00069 std::cout << "TrecSVM::WriteResults : Unable to write to "
00070 << mDataFile << std::endl;
00071 return;
00072 }
00073
00074 int maxSamples = 250;
00075 int width = mFeatList[0]->CW();
00076 int height = mFeatList[0]->CH();
00077 int skip = width * height / maxSamples;
00078 int s = skip;
00079 for (int y=0 ; y<height ; y++)
00080 {
00081 for (int x=0 ; x<width ; x++)
00082 {
00083 if (--s <= 0)
00084 {
00085 fprintf(fp, "%d", con);
00086 for (int f=0 ; f<mFeatList.size() ; f++)
00087 {
00088 Real64 val = *(mFeatList[f]->CPB(x, y));
00089 fprintf(fp, " %d:%f", f+1, val);
00090 }
00091 fprintf(fp, "\n");
00092 s = skip;
00093 }
00094 }
00095 }
00096 fclose(fp);
00097 }
00098
00099 void
00100 AddAnno(int con, Array2dScalarReal64* featVec)
00101 {
00102 if ((featVec->CW() * featVec->CH()) != mNrFeatures)
00103 {
00104 std::cout << "AddAnno : wrong number of features" << std::endl;
00105 return;
00106 }
00107 FILE *fp;
00108 std::string wMode = (mFirstTime) ? "w" : "a";
00109 mFirstTime = false;
00110 if ((fp = fopen(mDataFile.c_str(), wMode.c_str())) == 0)
00111 {
00112 std::cout << "TrecSVM::AddAnno : Unable to write to "
00113 << mDataFile << std::endl;
00114 return;
00115 }
00116
00117 fprintf(fp, "%d", con);
00118 int f = 1;
00119 for (int y=0 ; y<featVec->CH() ; y++)
00120 {
00121 for (int x=0 ; x<featVec->CW() ; x++)
00122 {
00123 Real64 val = *(featVec->CPB(x, y));
00124 fprintf(fp, " %d:%f", f++, val);
00125 }
00126 }
00127 fprintf(fp, "\n");
00128 fclose(fp);
00129 }
00130
00131 void
00132 LoadModel(std::string fileName)
00133 {
00134 if ((mModel = svm_load_model(fileName.c_str())) == 0)
00135 std::cout << "TrecSVM: unable to load model " << fileName << std::endl;
00136 if (!svm_check_probability_model(mModel))
00137 std::cout << "TrecSVM: not a probability model" << std::endl;
00138
00139
00140 mNodes = new struct svm_node[mNrFeatures+1];
00141 for (int i=0 ; i<mNrFeatures ; i++)
00142 mNodes[i].index = i+1;
00143 mNodes[mNrFeatures].index = -1;
00144
00145 mNrClass = svm_get_nr_class(mModel);
00146 mClassLabels = new int[mNrClass];
00147 svm_get_labels(mModel, mClassLabels);
00148
00149
00150 mProbs = new double[mNrClass];
00151 }
00152
00153 void
00154 LoadModel(std::string fileName, std::string rangeFile)
00155 {
00156 LoadModel(fileName);
00157
00158
00159 FILE* fp;
00160 if (! (fp = fopen(rangeFile.c_str(), "r")))
00161 {
00162 std::cout << "TrecSVM: unable to open " << rangeFile << std::endl;
00163 return;
00164 }
00165 char buf[200];
00166 int lineNr = 0;
00167 while (!feof(fp))
00168 {
00169 buf[0] = '\0';
00170 fgets(buf, 200, fp);
00171 if (strlen(buf) == 0)
00172 continue;
00173 lineNr++;
00174 if (lineNr == 1)
00175 continue;
00176 int idx;
00177 double minVal;
00178 double maxVal;
00179 sscanf(buf, "%d %lf %lf", &idx, &minVal, &maxVal);
00180 std::cout << "read : " << idx << " " << minVal << " " << maxVal << std::endl;
00181 mRangeMin[idx-1] = minVal;
00182 mRangeMax[idx-1] = maxVal;
00183 }
00184 fclose(fp);
00185 mScale = true;
00186 }
00187
00188 int
00189 GetNrClasses()
00190 {
00191 return mNrClass;
00192 }
00193
00194 int
00195 Predict(Real64 s1, Real64 s2, Real64 s3, Real64 s4, Real64 s5,
00196 Real64 s6, Real64 s7, Real64 s8, Real64 s9, Real64 s10,
00197 Real64 s11, Real64 s12)
00198 {
00199 mNodes[0].value = s1;
00200 mNodes[1].value = s2;
00201 mNodes[2].value = s3;
00202 mNodes[3].value = s4;
00203 mNodes[4].value = s5;
00204 mNodes[5].value = s6;
00205 mNodes[6].value = s7;
00206 mNodes[7].value = s8;
00207 mNodes[8].value = s9;
00208 mNodes[9].value = s10;
00209 mNodes[10].value = s11;
00210 mNodes[11].value = s12;
00211 return svm_predict(mModel, mNodes);
00212 }
00213
00214 int
00215 Predict(Real64 s1, Real64 s2, Real64 s3, Real64 s4, Real64 s5,
00216 Real64 s6, Real64 s7, Real64 s8, Real64 s9, Real64 s10,
00217 Real64 s11, Real64 s12, Real64 s13, Real64 s14,
00218 Real64 s15, Real64 s16, Real64 s17, Real64 s18,
00219 Real64 s19, Real64 s20, Real64 s21)
00220 {
00221 mNodes[0].value = s1;
00222 mNodes[1].value = s2;
00223 mNodes[2].value = s3;
00224 mNodes[3].value = s4;
00225 mNodes[4].value = s5;
00226 mNodes[5].value = s6;
00227 mNodes[6].value = s7;
00228 mNodes[7].value = s8;
00229 mNodes[8].value = s9;
00230 mNodes[9].value = s10;
00231 mNodes[10].value = s11;
00232 mNodes[11].value = s12;
00233 mNodes[12].value = s13;
00234 mNodes[13].value = s14;
00235 mNodes[14].value = s15;
00236 mNodes[15].value = s16;
00237 mNodes[16].value = s17;
00238 mNodes[17].value = s18;
00239 mNodes[18].value = s19;
00240 mNodes[19].value = s20;
00241 mNodes[20].value = s21;
00242 if (mScale)
00243 for (int i=0 ; i<mNrFeatures ; i++)
00244 mNodes[i].value = ((mNodes[i].value - mRangeMin[i]) /
00245 (mRangeMax[i] - mRangeMin[i]));
00246 return svm_predict(mModel, mNodes);
00247 }
00248
00249
00250
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271 private:
00272
00273 std::vector<Array2dScalarReal64*> mFeatList;
00274
00275 bool mFirstTime;
00276 int mNrConcepts;
00277 int mNrFeatures;
00278 bool mScale;
00279 Real64* mRangeMin;
00280 Real64* mRangeMax;
00281
00282 std::string mDataFile;
00283 struct svm_model* mModel;
00284 int mNrClass;
00285 int* mClassLabels;
00286 struct svm_node* mNodes;
00287 double* mProbs;
00288
00289 };
00290
00291 }
00292 }
00293 }
00294
00295 #endif