
Svm.h

#ifndef Impala_Core_Training_Svm_h
#define Impala_Core_Training_Svm_h

#include "Util/FileCopy.h"
#include "Core/Training/Classifier.h"
#include "Core/Training/TrainDataSrcFeature.h"
#include "Core/Training/Equals.h"
#include "Core/Matrix/DistributedAccess.h"
#include "Core/Table/ScoreTable.h"
#include <iomanip>
#include "Link/Svm/LinkSvm.h"
#include "Util/TimeStats.h"

namespace Impala
{
namespace Core
{
namespace Training
{

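// Helper that wraps a single feature vector in a one-sample svm_problem in
// libsvm's sparse node format; the definition is at the bottom of this file.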
svm_problem* ConvertToSvmProblem(const Core::Vector::VectorTem<double>* feature);

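// Binary SVM classifier built on top of libsvm (see Link/Svm/LinkSvm.h).
// The class owns an svm_model and offers training, per-item prediction,
// model load/save via the database layer, and a cached prediction path
// for active learning on a precomputed kernel matrix.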
class Svm : public Classifier
{
public:
    Svm()
    {
        mModel = 0;
        mProbabilityIndex = 0;
        mIsProbabilityModel = false;
    }

    virtual
    ~Svm()
    {
        if(mModel)
            svm_destroy_model(mModel);
    }

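    // Trains a model on the given data: optionally balances the class weights
    // by the inverse class frequencies ("autoweight"), builds the svm_parameter
    // set and the svm_problem, runs svm_train, and stores the resulting model.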
    virtual void
    Train(Util::PropertySet* properties, TrainDataSrc* data)
    {
        if(properties->GetBool("autoweight"))
        {
            double pos = data->GetTotalPositiveCount();
            double neg = data->GetTotalNegativeCount();
            double posweight = (pos+neg) / pos;
            double negweight = (pos+neg) / neg;
            properties->Add("w1", posweight);
            properties->Add("w2", negweight);
            ILOG_DEBUG("autoweight: w+1=" << posweight << " w-1= " << negweight);
        }
        svm_parameter* parameters =
            MakeSvmParams(properties, data->GetVectorLength(), 0);
        svm_problem* p = data->MakeSvmProblem();
        ILOG_DEBUG_NODE("svm_train called with problem size = " << p->l);
        svm_model* model = svm_train(p, parameters);
        FixModelDependency(model);
        data->FreeProblem(p);
        DestroySvmParameters(parameters);
        SetModel(model);
    }

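    // Predicts a score for every item in the data source, one item at a time,
    // while collecting timing statistics for building, predicting on, and
    // freeing each single-sample svm_problem.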
    virtual Table::ScoreTable*
    Predict(TrainDataSrc* data)
    {
        //ILOG_DEBUG_NODE("predict called, datasize = " << data->Size());
        Util::TimeStats stats;
        stats.AddGroup("make problem");
        stats.AddGroup("predict");
        stats.AddGroup("free problem");

        Table::ScoreTable* result = new Table::ScoreTable();
        ILOG_DEBUG_NODE("Predict(): size of data = " << data->Size());
        for (int i=0 ; i<data->Size() ; ++i)
        {
            stats.MeasureFirst();
            ILOG_DEBUG_NODE("making svm problem");
            svm_problem *temp = data->MakeSvmProblem(i);
            stats.MeasureNext();
            ILOG_DEBUG_NODE("calling predictsingle");
            double score = PredictSingle(temp->x[0]);
            ILOG_DEBUG_NODE("adding result");
            result->Add(data->GetQuid(i), score);
            stats.MeasureNext();
            ILOG_DEBUG_NODE("freeing problem");
            data->FreeProblem(temp);
            stats.StopTime();
        }
        std::ostringstream oss;
        stats.Print(oss);
        //ILOG_DEBUG_NODE(oss.str());
        return result;
    }

    virtual void
    PredictForActiveLearn(Matrix::DistributedAccess& da,
                          Table::QuidTable* columnQuids,
                          Core::Table::SimilarityTableSet::SimTableType* result)
    {
        /* This function can be used when the number of column quids is very,
           very small and the matrix is very big.  It caches the whole (!)
           part of the matrix it uses, so prediction is faster.  WARNING:
           the columnQuids are the quids that were learned from, NOT the quids
           to apply to; the classifier is applied to all available quids.
        */
        if(!mModel)
        {
            ILOG_ERROR("[PredictForActiveLearn] classifier untrained\n");
            return;
        }

        result->SetSize(0);
        ILOG_INFO("starting PredictForActiveLearn");

        Table::QuidTable* rowQuids = da.GetRowQuids();
        Table::QuidTable* allColumns = da.GetColumnQuids();
        int rowCount = da.GetRows();

        std::vector<Quid>    sortedColumnQuids;
        std::vector<int>     sortedColumnQuidIndices;
        std::vector<Real64*> sortedColumnQuidData;
        for(int i = 0; i < allColumns->Size(); i++)
        {
            if(columnQuids->Contains(allColumns->Get1(i)))
            {
                // we have this one
                sortedColumnQuids.push_back(allColumns->Get1(i));
                sortedColumnQuidIndices.push_back(i);
                Real64* buf = new Real64[rowCount];
                int received = da.GetColumn(i, buf, rowCount);
                if(received != rowCount)
                    ILOG_ERROR("[PredictForActiveLearn] didn't receive column");
                sortedColumnQuidData.push_back(buf);
                //ILOG_INFO("Loaded " << i << " " << allColumns->Get1(i));
            }
        }
        if(sortedColumnQuids.size() != columnQuids->Size())
        {
            ILOG_ERROR("not all column quids are in the table?!");
            return;
        }

        ILOG_INFO("Done loading");
        // PATCH THE MODEL
        mModel->param.kernel_type = PRECOMPUTED;

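        // With the kernel type patched to PRECOMPUTED, svm_predict reads the
        // kernel values directly from the node vector built below: x[0] is
        // reserved for the sample id, the cached column values are filled in
        // per row, and the vector ends with the usual index = -1 terminator.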
        int colCount = da.GetColumns();
        struct svm_node* x = new struct svm_node[colCount + 2];
        for(int i = 0; i < colCount + 2; i++)
        {
            x[i].index = 0;
            x[i].value = 0;
        }
        for(int i = 0; i < rowCount; i++)
        {
            x[0].index = 0;
            x[0].value = -1; // for predict this value is not necessary
            // fill it with values
            for(int j = 0; j < sortedColumnQuids.size(); j++)
            {
                int index = sortedColumnQuidIndices[j];
                x[index].index = index;
                x[index].value = sortedColumnQuidData[j][i];
            }
            x[colCount+1].index = -1; // end-of-vector marker
            x[colCount+1].value = 0;  // not used

            //ILOG_INFO("predict " << i);
            double score = PredictSingle(x);
            result->Add(score);
        }
        delete [] x; // allocated with new[], so delete[] is required
        for(int i = 0; i < sortedColumnQuidData.size(); i++)
        {
            delete [] sortedColumnQuidData[i];
        }
        ILOG_INFO("done predict");
    }

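    // Predicts a score for a single feature vector by converting it to a
    // one-sample svm_problem; returns 0.5 when no model has been trained.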
    double
    Predict(const Impala::Core::Vector::VectorTem<double>* feature)
    {
        if(!mModel)
        {
            ILOG_ERROR("[Svm::Predict] untrained classifier can't predict");
            return 0.5;
        }
        svm_problem* p = ConvertToSvmProblem(feature);
        double score = PredictSingle(p->x[0]);
        DestroySvmProblem(p);
        return score;
    }

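    // Loads a libsvm model from the database; when a remote data channel is
    // used, the model is first copied to a temporary local file so that
    // svm_load_model can read it, and the temporary file is removed afterwards.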
    void
    LoadModel(const std::string& name, Util::Database* db)
    {
        svm_model* model = 0;
        if (db->GetDataChannel())
        {
            std::string tmpName = FileNameTmp();
            Util::IOBuffer* buf = db->GetIOBuffer(name, true, false, "");
            FileCopyRemoteToLocal(buf, tmpName);
            delete buf;
            model = svm_load_model(tmpName.c_str());
            unlink(tmpName.c_str());
        }
        else
        {
            model = svm_load_model(name.c_str());
        }
        if(!model)
        {
            ILOG_ERROR("[LoadModel] unable to load " << name);
        }
        else
            SetModel(model);
    }

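    // Saves the current model, going through a temporary local file when a
    // remote data channel is used.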
    void
    SaveModel(const std::string& name, Util::Database* db)
    {
        if (db->GetDataChannel())
        {
            std::string tmpName = FileNameTmp();
            if(svm_save_model(tmpName.c_str(), mModel) == -1)
                ILOG_ERROR("[SaveModel] unable to save " << name);
            Util::FileCopyLocalToRemote(tmpName, name);
            unlink(tmpName.c_str());
        }
        else
        {
            if(svm_save_model(name.c_str(), mModel) == -1)
                ILOG_ERROR("[Svm::SaveModel] unable to save " << name);
        }
    }

    void
    OverrideModelOptions(Util::PropertySet* properties)
    {
        MakeSvmParams(properties, 1, &mModel->param);
    }

    const svm_model*
    GetModel()
    {
        return mModel;
    }

    bool
    Equals(const Svm* other) const
    {
        if(mModel==0 && other->mModel==0)
            return true;
        if(mModel==0 || other->mModel==0)
            return false;

        return Training::Equals(mModel, other->mModel);
    }

private:
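    // Returns the estimated probability of the positive class when the model
    // was trained with probability estimates, otherwise the plain svm_predict
    // output (the predicted label).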
    double
    PredictSingle(const svm_node* problem)
    {
        double score;
        if(mIsProbabilityModel)
        {
            double probabilities[2];
            svm_predict_probability(mModel, problem, probabilities);
            score = probabilities[mProbabilityIndex];
        }
        else
        {
            score = svm_predict(mModel, problem);
        }
        //ILOG_DEBUG(problem[0].value <<", "<< problem[1].value <<
        //           "... -> "<< score);
        return score;
    }

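    // Takes ownership of the given model, destroying any previous one, and
    // records whether probability estimates are available and which
    // probability slot corresponds to the positive label.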
    void
    SetModel(svm_model* model)
    {
        if(mModel)
            svm_destroy_model(mModel);
        mModel = model;
        if(svm_check_probability_model(mModel))
            mIsProbabilityModel = true;
        else
            mIsProbabilityModel = false;

        if(mIsProbabilityModel)
        {
            int labels[2];
            svm_get_labels(mModel, labels);
            if(labels[1] == 1)
                mProbabilityIndex = 1;
            else
                mProbabilityIndex = 0;
        }
    }

    svm_model* mModel;
    bool mIsProbabilityModel;
    int mProbabilityIndex;
    ILOG_VAR_DEC;
};

ILOG_VAR_INIT(Svm, Impala.Core.Training);


bool
Equals(const Svm* svm1, const Svm* svm2)
{
    return svm1->Equals(svm2);
}

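// Wraps a single feature vector in a one-sample svm_problem: node indices are
// 1-based and the node list is terminated with index = -1, as libsvm expects.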
svm_problem*
ConvertToSvmProblem(const Impala::Core::Vector::VectorTem<double>* feature)
{
    svm_problem* svm = new svm_problem;
    svm->l = 1;
    svm->y = new double[svm->l];
    svm->x = new struct svm_node *[svm->l];
    struct svm_node* nodes = new struct svm_node[svm->l*(feature->Size() + 1)];
    for(int i=0 ; i<svm->l ; i++)
    {
        svm->y[i] = 0;
        svm->x[i] = nodes;
        const double* values = feature->GetData();
        int j;
        for(j=0 ; j<feature->Size() ; j++)
        {
            svm->x[i][j].index = j+1;
            svm->x[i][j].value = values[j];
        }
        svm->x[i][j].index = -1;
    }
    return svm;
}

} // namespace Training
} // namespace Core
} // namespace Impala


#endif
