Home || Visual Search || Applications || Architecture || Important Messages || OGL || Src

Svm.h

Go to the documentation of this file.
00001 #ifndef Impala_Core_Training_Svm_h
00002 #define Impala_Core_Training_Svm_h
00003 
00004 #include "Util/FileCopy.h"
00005 #include "Core/Training/Classifier.h"
00006 #include "Core/Training/TrainDataSrcFeature.h"
00007 #include "Core/Training/Equals.h"
00008 #include "Core/Matrix/DistributedAccess.h"
00009 #include "Core/Table/ScoreTable.h"
00010 #include <iomanip>
00011 #include "Link/Svm/LinkSvm.h"
00012 #include "Util/TimeStats.h"
00013 #include "Persistency/SvmProblemRepository.h" // for testMode
00014 #include "Core/VideoSet/MakeVideoSet.h"
00015 
00016 namespace Impala
00017 {
00018 namespace Core
00019 {
00020 namespace Training
00021 {
00022 
00050 class Svm : public Classifier
00051 {
00052 public:
00053     Svm()
00054     {
00055         mModel = 0;
00056         mProbilityIndex = 0;
00057         mIsProbabilityModel = false;
00058     }
00059 
00060     Svm(String s)
00061     {
00062         mModel = 0;
00063         mProbilityIndex = 0;
00064         mIsProbabilityModel = false;
00065 
00066         Util::IOBuffer buf(s.size());
00067         buf.Puts(s);
00068         buf.Rewind();
00069 
00070         String tmpName = FileNameTmp();
00071         FileCopyRemoteToLocal(&buf, tmpName);
00072         svm_model* model = svm_load_model(tmpName.c_str());
00073         unlink(tmpName.c_str());
00074         if (!model)
00075         {
00076             ILOG_ERROR("Unable to load model");
00077         }
00078         else
00079         {
00080             SetModel(model);
00081         }
00082     }
00083 
00084     virtual
00085     ~Svm()
00086     {
00087         if(mModel)
00088             svm_destroy_model(mModel);
00089     }
00090 
00103     virtual void
00104     Train(Util::PropertySet* properties, TrainDataSrc* data)
00105     {
00106         if (properties->GetBool("autoweight"))
00107         {
00108             double pos = data->GetTotalPositiveCount();
00109             double neg = data->GetTotalNegativeCount();
00110             double posweight = (pos+neg) / pos;
00111             double negweight = (pos+neg) / neg;
00112             properties->Add("w1", posweight);
00113             properties->Add("w2", negweight);
00114             ILOG_DEBUG("autoweight: w+1=" << posweight << " w-1= " << negweight);
00115         }
00116         if ((properties->GetBool("probability")) &&
00117             (properties->GetString("kernel") == "rbf"))
00118         {
00119             //ILOG_WARNING("[Train] Random seed reset");
00120             std::srand(1);
00121         }
00122         svm_parameter* parameters = 
00123             MakeSvmParams(properties, data->GetVectorLength(), 0);
00124         svm_problem* p = data->MakeSvmProblem();
00125         CheckTestMode(p);
00126         ILOG_DEBUG_NODE("svm_train called with problem size = " << p->l);
00127         svm_model* model = svm_train(p, parameters);
00128         FixModelDependency(model);
00129         data->FreeProblem(p);
00130         DestroySvmParameters(parameters);
00131         SetModel(model);
00132     }
00133 
00134     virtual Table::ScoreTable*
00135     Predict(TrainDataSrc* data)
00136     {
00137         Util::TimeStats stats;
00138         stats.AddGroup("make problem");
00139         stats.AddGroup("predict");
00140         stats.AddGroup("free problem");
00141 
00142         Table::ScoreTable* result = new Table::ScoreTable();
00143         ILOG_DEBUG_NODE("Predict(): size of data = " << data->Size());
00144         for (int i=0 ; i<data->Size() ; ++i)
00145         {
00146             stats.MeasureFirst();
00147             svm_problem *temp = data->MakeSvmProblem(i);
00148             stats.MeasureNext();
00149             double score = PredictSingle(temp->x[0]);
00150             result->Add(data->GetQuid(i), score);
00151             stats.MeasureNext();
00152             data->FreeProblem(temp);
00153             stats.StopTime();
00154         }
00155         std::ostringstream oss;
00156         stats.Print(oss);
00157         //ILOG_DEBUG_NODE(oss.str());
00158         return result;
00159     }
00160 
00171     virtual void
00172     PredictForActiveLearn(Matrix::DistributedAccess* da,
00173                           Table::QuidTable* columnQuids,
00174                           Core::Table::SimilarityTableSet::SimTableType* result)
00175     {
00176         /* this function can be used when the number of column quids is very,
00177            very small and the matrix is very big.  This function will cache the
00178            whole (!) matrix part used, so the prediction is faster.  WARNING:
00179            the columnQuids are the quids you have learned from! So NOT the quids
00180            to apply to; it will be applied to all quids that are available.
00181         */
00182         if(!mModel)
00183         {
00184             ILOG_ERROR("[PredictForActiveLearn] classifier untrained\n");
00185             return;
00186         }
00187         
00188         result->SetSize(0);
00189         ILOG_INFO("starting PredictForActiveLearn");
00190         
00191         Table::QuidTable* rowQuids = da->GetRowQuids();
00192         Table::QuidTable* allColumns = da->GetColumnQuids();
00193         int rowCount = da->GetRows();
00194         
00195         std::vector<Quid>    sortedColumnQuids;
00196         std::vector<int>     sortedColumnQuidIndices;
00197         std::vector<Real64*> sortedColumnQuidData;
00198         for(int i = 0; i < allColumns->Size(); i++)
00199         {
00200             if(columnQuids->Contains(allColumns->Get1(i)))
00201             {
00202                 // we have this one
00203                 sortedColumnQuids.push_back(allColumns->Get1(i));
00204                 sortedColumnQuidIndices.push_back(i);
00205                 Real64* buf = new Real64[rowCount];
00206                 int received = da->GetColumn(i, buf, rowCount);
00207                 if(received != rowCount)
00208                     ILOG_ERROR("[PredictForActiveLearn] didn't receive column");
00209                 sortedColumnQuidData.push_back(buf);
00210                 //ILOG_INFO("Loaded " << i << " " << allColumns->Get1(i));
00211             }
00212         }
00213         if(sortedColumnQuids.size() != columnQuids->Size())
00214         {
00215             ILOG_ERROR("not all column  quids are in the table?!");
00216             return;
00217         }
00218 
00219         ILOG_INFO("Done loading");
00220         // PATCH THE MODEL
00221         mModel->param.kernel_type = PRECOMPUTED;
00222 
00223         int colCount = da->GetColumns();
00224         struct svm_node* x = new struct svm_node[colCount + 2];
00225         for(int i = 0; i < colCount + 2; i++)
00226         {
00227             x[i].index = 0;
00228             x[i].value = 0;
00229         }
00230         for(int i = 0; i < rowCount; i++)
00231         {
00232             x[0].index = 0;
00233             x[0].value = -1; // for predict this value is not necessary
00234             // fill it with values
00235             for(int j = 0; j < sortedColumnQuids.size(); j++)
00236             {
00237                 int index = sortedColumnQuidIndices[j]+1;
00238                 x[index].index = index;
00239                 x[index].value = sortedColumnQuidData[j][i];
00240             }
00241             x[colCount+1].index = -1; // end-of-vector marker
00242             x[colCount+1].value = 0;  // not used
00243 
00244             //ILOG_INFO("predict " << i);
00245             double score = PredictSingle(x);
00246             result->Add(score);
00247         }
00248         delete x;
00249         for(int i = 0; i < sortedColumnQuidData.size(); i++)
00250         {
00251             delete sortedColumnQuidData[i];
00252         }
00253         ILOG_INFO("done predict");
00254     }
00255 
00256     void
00257     SetModel(svm_model* model)
00258     {
00259         if (mModel)
00260             svm_destroy_model(mModel);
00261         mModel = model;
00262         if (svm_check_probability_model(mModel))
00263             mIsProbabilityModel = true;
00264         else
00265             mIsProbabilityModel = false;
00266 
00267         if (mIsProbabilityModel)
00268         {
00269             int labels[2];
00270             svm_get_labels(mModel, labels);
00271             if (labels[1] == 1)
00272                 mProbilityIndex = 1;
00273         }
00274     }
00275 
00276     virtual void
00277     OverrideModelOptions(Util::PropertySet* properties)
00278     {
00279         MakeSvmParams(properties, 1, &mModel->param);
00280     }
00281 
00282     const svm_model*
00283     GetModel()
00284     {
00285         return mModel;
00286     }
00287 
00288     bool
00289     Equals(const Svm* other) const
00290     {
00291         if (mModel==0 && other->mModel==0)
00292             return true;
00293         if (mModel==0 || other->mModel==0)
00294             return false;
00295 
00296         return Training::Equals(mModel, other->mModel);
00297     }
00298 
00299     int
00300     ReferenceDiff(const Svm* other) const
00301     {
00302         int diff = Diff(other);
00303         if (diff > 0)
00304             ILOG_ERROR("Found "<< diff <<" differences");
00305         return diff;
00306     }
00307     
00308     int
00309     Diff(const Svm* other) const
00310     {
00311         if (mModel==0 && other->mModel==0)
00312             return 0;
00313         if (mModel==0 || other->mModel==0)
00314             return 1;
00315         return Impala::Diff(mModel, other->mModel);
00316     }
00317 
00318 private:
00319 
00320     // Not the best place for this function, awaiting further refactoring
00321     void
00322     CheckTestMode(svm_problem* problem)
00323     {
00324         CmdOptions& options = CmdOptions::GetInstance();
00325         if (! options.GetBool("testMode"))
00326             return;
00327 
00328         VideoSet::VideoSet* vidSet =
00329             VideoSet::MakeVideoSet("trec2005fsd_try.txt");
00330         String filename = "anchor.problem.txt";
00331         Persistency::ModelLocator loc
00332             (vidSet->GetLocator(), "concepts.txt", "testmodel",
00333              "vissem_proto_annotation_nrScales_2_nrRects_130", "anchor");
00334         Persistency::SvmProblemRepository().Add(loc, problem);
00335         ILOG_INFO("TestMode, we're done");
00336         exit(0);
00337     }
00338 
00339     double
00340     PredictSingle(const svm_node* problem)
00341     {
00342         double score;
00343         if(mIsProbabilityModel)
00344         {
00345             double probabilities[2];
00346             svm_predict_probability(mModel, problem, probabilities);
00347             score = probabilities[mProbilityIndex];
00348         }
00349         else
00350         {
00351             score = svm_predict(mModel, problem);
00352         }
00353         return score;
00354     }
00355 
00356     svm_model* mModel;
00357     bool       mIsProbabilityModel;
00358     int        mProbilityIndex;
00359 
00360     ILOG_VAR_DEC;
00361 };
00362 
00363 ILOG_VAR_INIT(Svm, Impala.Core.Training);
00364 
00365 bool
00366 Equals(const Svm* svm1, const Svm* svm2)
00367 {
00368     return svm1->Equals(svm2);
00369 }
00370 
00371 int
00372 Diff(const Svm* svm1, const Svm* svm2)
00373 {
00374     return svm1->Diff(svm2);
00375 }
00376 
00377 }//namespace
00378 }//namespace
00379 }//namespace
00380 
00381 
00382 #endif

Generated on Thu Jan 13 09:04:41 2011 for ImpalaSrc by  doxygen 1.5.1