Home || Architecture || Video Search || Visual Search || Scripts || Applications || Important Messages || OGL || Src

TrecSVM.h

Go to the documentation of this file.
00001 #ifndef Impala_Core_Trec_TrecSVM_h
00002 #define Impala_Core_Trec_TrecSVM_h
00003 
00004 #include "Basis/String.h"
00005 #include "Core/Array/ArrayListDelete.h"
00006 #include "Core/Array/Set.h"
00007 //#include "Comp/Database/CxIdSetOps.h"
00008 #include "Core/Trec/svm.cpp"
00009 
00010 namespace Impala
00011 {
00012 namespace Core
00013 {
00014 namespace Trec
00015 {
00016 
00017 
00018 //using namespace Impala::Core::Array;
00019 
00020 class TrecSVM
00021 {
00022 public:
00023 
00024     TrecSVM(int nrConcepts, int nrFeatures)
00025     {
00026         mNrConcepts = nrConcepts;
00027         mNrFeatures = nrFeatures;
00028         mModel = 0;
00029         mProbs = 0;
00030         mNodes = 0;
00031         mFirstTime = true;
00032         mScale = false;
00033         mRangeMin = new Real64[mNrFeatures];
00034         mRangeMax = new Real64[mNrFeatures];
00035         mDataFile = std::string("concept_data");
00036     }
00037 
00038     void
00039     SetDataFile(std::string fileName)
00040     {
00041         mDataFile = fileName;
00042     }
00043 
00044     // each pixel in im is an example of the "feat"-th element of the feature
00045     // vector for concept "con" 
00046     void
00047     UpdateAnno(int con, int feat, Array2dScalarReal64* im)
00048     {
00049         if (feat == 0)
00050             if (mFeatList.size() > 0)
00051                 ArrayListDelete<Array2dScalarReal64>(&mFeatList);
00052 
00053         Array2dScalarReal64* cp = 0;
00054         Set(cp, im);
00055         mFeatList.push_back(cp);
00056 
00057         if (feat == mNrFeatures-1)
00058             WriteResultsPixel(con);
00059     }
00060 
00061     void
00062     WriteResultsPixel(int con)
00063     {
00064         FILE *fp;
00065         std::string wMode = (mFirstTime) ? "w" : "a";
00066         mFirstTime = false;
00067         if ((fp = fopen(mDataFile.c_str(), wMode.c_str())) == 0)
00068         {       
00069             std::cout << "TrecSVM::WriteResults : Unable to write to " 
00070                       << mDataFile << std::endl;
00071             return;
00072         }
00073 
00074         int maxSamples = 250;
00075         int width = mFeatList[0]->CW();
00076         int height = mFeatList[0]->CH();
00077         int skip = width * height / maxSamples;
00078         int s = skip;
00079         for (int y=0 ; y<height ; y++)
00080         {
00081             for (int x=0 ; x<width ; x++)
00082             {
00083                 if (--s <= 0)
00084                 {
00085                     fprintf(fp, "%d", con);
00086                     for (int f=0 ; f<mFeatList.size() ; f++)
00087                     {
00088                         Real64 val = *(mFeatList[f]->CPB(x, y));
00089                         fprintf(fp, " %d:%f", f+1, val);
00090                     }
00091                     fprintf(fp, "\n");
00092                     s = skip;
00093                 }
00094             }
00095         }
00096         fclose(fp);
00097     }
00098 
00099     void
00100     AddAnno(int con, Array2dScalarReal64* featVec)
00101     {
00102         if ((featVec->CW() * featVec->CH()) != mNrFeatures)
00103         {
00104             std::cout << "AddAnno : wrong number of features" << std::endl;
00105             return;
00106         }
00107         FILE *fp;
00108         std::string wMode = (mFirstTime) ? "w" : "a";
00109         mFirstTime = false;
00110         if ((fp = fopen(mDataFile.c_str(), wMode.c_str())) == 0)
00111         {       
00112             std::cout << "TrecSVM::AddAnno : Unable to write to " 
00113                       << mDataFile << std::endl;
00114             return;
00115         }
00116 
00117         fprintf(fp, "%d", con);
00118         int f = 1;
00119         for (int y=0 ; y<featVec->CH() ; y++)
00120         {
00121             for (int x=0 ; x<featVec->CW() ; x++)
00122             {
00123                 Real64 val = *(featVec->CPB(x, y));
00124                 fprintf(fp, " %d:%f", f++, val);
00125             }
00126         }
00127         fprintf(fp, "\n");
00128         fclose(fp);
00129     }
00130 
00131     void
00132     LoadModel(std::string fileName)
00133     {
00134         if ((mModel = svm_load_model(fileName.c_str())) == 0)
00135             std::cout << "TrecSVM: unable to load model " << fileName << std::endl;
00136         if (!svm_check_probability_model(mModel))
00137             std::cout << "TrecSVM: not a probability model" << std::endl;
00138 
00139         // setup svm_node array, assume fixed number of elements
00140         mNodes = new struct svm_node[mNrFeatures+1];
00141         for (int i=0 ; i<mNrFeatures ; i++)
00142             mNodes[i].index = i+1;
00143         mNodes[mNrFeatures].index = -1;
00144 
00145         mNrClass = svm_get_nr_class(mModel);
00146         mClassLabels = new int[mNrClass];
00147         svm_get_labels(mModel, mClassLabels);
00148         //for (int c=0 ; c<mNrClass ; c++)
00149         //    std::cout << "label " << c << " = " << labels[c] << std::endl;
00150         mProbs = new double[mNrClass];
00151     }
00152 
00153     void
00154     LoadModel(std::string fileName, std::string rangeFile)
00155     {
00156         LoadModel(fileName);
00157 
00158         // now load range file
00159         FILE* fp;
00160         if (! (fp = fopen(rangeFile.c_str(), "r")))
00161         {
00162             std::cout << "TrecSVM: unable to open " << rangeFile << std::endl;
00163             return;
00164         }
00165         char buf[200];
00166         int lineNr = 0;
00167         while (!feof(fp))
00168         {
00169             buf[0] = '\0';
00170             fgets(buf, 200, fp);
00171             if (strlen(buf) == 0)
00172                 continue;
00173             lineNr++;
00174             if (lineNr == 1)
00175                 continue; // skip first line
00176             int idx;
00177             double minVal;
00178             double maxVal;
00179             sscanf(buf, "%d %lf %lf", &idx, &minVal, &maxVal);
00180             std::cout << "read : " << idx << " " << minVal << " " << maxVal << std::endl;
00181             mRangeMin[idx-1] = minVal;
00182             mRangeMax[idx-1] = maxVal;
00183         }
00184         fclose(fp);
00185         mScale = true;
00186     }
00187 
00188     int
00189     GetNrClasses()
00190     {
00191         return mNrClass;
00192     }
00193 
00194     int
00195     Predict(Real64 s1, Real64 s2, Real64 s3, Real64 s4, Real64 s5,
00196             Real64 s6, Real64 s7, Real64 s8, Real64 s9, Real64 s10,
00197             Real64 s11, Real64 s12)
00198     {
00199         mNodes[0].value = s1;
00200         mNodes[1].value = s2;
00201         mNodes[2].value = s3;
00202         mNodes[3].value = s4;
00203         mNodes[4].value = s5;
00204         mNodes[5].value = s6;
00205         mNodes[6].value = s7;
00206         mNodes[7].value = s8;
00207         mNodes[8].value = s9;
00208         mNodes[9].value = s10;
00209         mNodes[10].value = s11;
00210         mNodes[11].value = s12;
00211         return svm_predict(mModel, mNodes);
00212     }
00213 
00214     int
00215     Predict(Real64 s1, Real64 s2, Real64 s3, Real64 s4, Real64 s5,
00216             Real64 s6, Real64 s7, Real64 s8, Real64 s9, Real64 s10,
00217             Real64 s11, Real64 s12, Real64 s13, Real64 s14,
00218             Real64 s15, Real64 s16, Real64 s17, Real64 s18,
00219             Real64 s19, Real64 s20, Real64 s21)
00220     {
00221         mNodes[0].value = s1;
00222         mNodes[1].value = s2;
00223         mNodes[2].value = s3;
00224         mNodes[3].value = s4;
00225         mNodes[4].value = s5;
00226         mNodes[5].value = s6;
00227         mNodes[6].value = s7;
00228         mNodes[7].value = s8;
00229         mNodes[8].value = s9;
00230         mNodes[9].value = s10;
00231         mNodes[10].value = s11;
00232         mNodes[11].value = s12;
00233         mNodes[12].value = s13;
00234         mNodes[13].value = s14;
00235         mNodes[14].value = s15;
00236         mNodes[15].value = s16;
00237         mNodes[16].value = s17;
00238         mNodes[17].value = s18;
00239         mNodes[18].value = s19;
00240         mNodes[19].value = s20;
00241         mNodes[20].value = s21;
00242         if (mScale)
00243             for (int i=0 ; i<mNrFeatures ; i++)
00244                 mNodes[i].value = ((mNodes[i].value - mRangeMin[i]) /
00245                                    (mRangeMax[i] - mRangeMin[i]));
00246         return svm_predict(mModel, mNodes);
00247     }
00248 
00249     /*
00250     int
00251     Predict(Array2dScalarReal64* featVec, CxIdSetWithWeight* probs)
00252     {
00253         if ((featVec->CW() * featVec->CH()) != mNrFeatures)
00254         {
00255             std::cout << "Predict : wrong number of features" << std::endl;
00256             return -1;
00257         }
00258         int i=0;
00259         for (int y=0 ; y<featVec->CH() ; y++)
00260             for (int x=0 ; x<featVec->CW() ; x++)
00261                 mNodes[i++].value = *CxArrayCPB(featVec, x, y);
00262         //return svm_predict(mModel, mNodes);
00263         double d = svm_predict_probability(mModel, mNodes, mProbs);
00264         if (probs)
00265             for (int i=0 ; i<mNrClass ; i++)
00266                 probs->AddIdWithWeight(mClassLabels[i], mProbs[i]);
00267         return d;
00268     }
00269     */
00270 
00271 private:
00272 
00273     std::vector<Array2dScalarReal64*> mFeatList;
00274 
00275     bool    mFirstTime;
00276     int     mNrConcepts; // to be replaced with mNrClass
00277     int     mNrFeatures;
00278     bool    mScale;
00279     Real64* mRangeMin;
00280     Real64* mRangeMax;
00281 
00282     std::string       mDataFile;
00283     struct svm_model* mModel;
00284     int               mNrClass;
00285     int*              mClassLabels;
00286     struct svm_node*  mNodes;
00287     double*           mProbs;
00288 
00289 };
00290 
00291 } // namespace Trec
00292 } // namespace Core
00293 } // namespace Impala
00294 
00295 #endif

Generated on Fri Mar 19 09:31:27 2010 for ImpalaSrc by  doxygen 1.5.1