00001 #ifndef Impala_Core_Training_Svm_h
00002 #define Impala_Core_Training_Svm_h
00003
00004 #include "Util/FileCopy.h"
00005 #include "Core/Training/Classifier.h"
00006 #include "Core/Training/TrainDataSrcFeature.h"
00007 #include "Core/Training/Equals.h"
00008 #include "Core/Matrix/DistributedAccess.h"
00009 #include "Core/Table/ScoreTable.h"
00010 #include <iomanip>
00011 #include "Link/Svm/LinkSvm.h"
00012 #include "Util/TimeStats.h"
00013 #include "Persistency/SvmProblemRepository.h"
00014 #include "Core/VideoSet/MakeVideoSet.h"
00015
00016 namespace Impala
00017 {
00018 namespace Core
00019 {
00020 namespace Training
00021 {
00022
00050 class Svm : public Classifier
00051 {
00052 public:
00053 Svm()
00054 {
00055 mModel = 0;
00056 mProbilityIndex = 0;
00057 mIsProbabilityModel = false;
00058 }
00059
00060 Svm(String s)
00061 {
00062 mModel = 0;
00063 mProbilityIndex = 0;
00064 mIsProbabilityModel = false;
00065
00066 Util::IOBuffer buf(s.size());
00067 buf.Puts(s);
00068 buf.Rewind();
00069
00070 String tmpName = FileNameTmp();
00071 FileCopyRemoteToLocal(&buf, tmpName);
00072 svm_model* model = svm_load_model(tmpName.c_str());
00073 unlink(tmpName.c_str());
00074 if (!model)
00075 {
00076 ILOG_ERROR("Unable to load model");
00077 }
00078 else
00079 {
00080 SetModel(model);
00081 }
00082 }
00083
00084 virtual
00085 ~Svm()
00086 {
00087 if(mModel)
00088 svm_destroy_model(mModel);
00089 }
00090
00103 virtual void
00104 Train(Util::PropertySet* properties, TrainDataSrc* data)
00105 {
00106 if (properties->GetBool("autoweight"))
00107 {
00108 double pos = data->GetTotalPositiveCount();
00109 double neg = data->GetTotalNegativeCount();
00110 double posweight = (pos+neg) / pos;
00111 double negweight = (pos+neg) / neg;
00112 properties->Add("w1", posweight);
00113 properties->Add("w2", negweight);
00114 ILOG_DEBUG("autoweight: w+1=" << posweight << " w-1= " << negweight);
00115 }
00116 if ((properties->GetBool("probability")) &&
00117 (properties->GetString("kernel") == "rbf"))
00118 {
00119
00120 std::srand(1);
00121 }
00122 svm_parameter* parameters =
00123 MakeSvmParams(properties, data->GetVectorLength(), 0);
00124 svm_problem* p = data->MakeSvmProblem();
00125 CheckTestMode(p);
00126 ILOG_DEBUG_NODE("svm_train called with problem size = " << p->l);
00127 svm_model* model = svm_train(p, parameters);
00128 FixModelDependency(model);
00129 data->FreeProblem(p);
00130 DestroySvmParameters(parameters);
00131 SetModel(model);
00132 }
00133
00134 virtual Table::ScoreTable*
00135 Predict(TrainDataSrc* data)
00136 {
00137 Util::TimeStats stats;
00138 stats.AddGroup("make problem");
00139 stats.AddGroup("predict");
00140 stats.AddGroup("free problem");
00141
00142 Table::ScoreTable* result = new Table::ScoreTable();
00143 ILOG_DEBUG_NODE("Predict(): size of data = " << data->Size());
00144 for (int i=0 ; i<data->Size() ; ++i)
00145 {
00146 stats.MeasureFirst();
00147 svm_problem *temp = data->MakeSvmProblem(i);
00148 stats.MeasureNext();
00149 double score = PredictSingle(temp->x[0]);
00150 result->Add(data->GetQuid(i), score);
00151 stats.MeasureNext();
00152 data->FreeProblem(temp);
00153 stats.StopTime();
00154 }
00155 std::ostringstream oss;
00156 stats.Print(oss);
00157
00158 return result;
00159 }
00160
00171 virtual void
00172 PredictForActiveLearn(Matrix::DistributedAccess* da,
00173 Table::QuidTable* columnQuids,
00174 Core::Table::SimilarityTableSet::SimTableType* result)
00175 {
00176
00177
00178
00179
00180
00181
00182 if(!mModel)
00183 {
00184 ILOG_ERROR("[PredictForActiveLearn] classifier untrained\n");
00185 return;
00186 }
00187
00188 result->SetSize(0);
00189 ILOG_INFO("starting PredictForActiveLearn");
00190
00191 Table::QuidTable* rowQuids = da->GetRowQuids();
00192 Table::QuidTable* allColumns = da->GetColumnQuids();
00193 int rowCount = da->GetRows();
00194
00195 std::vector<Quid> sortedColumnQuids;
00196 std::vector<int> sortedColumnQuidIndices;
00197 std::vector<Real64*> sortedColumnQuidData;
00198 for(int i = 0; i < allColumns->Size(); i++)
00199 {
00200 if(columnQuids->Contains(allColumns->Get1(i)))
00201 {
00202
00203 sortedColumnQuids.push_back(allColumns->Get1(i));
00204 sortedColumnQuidIndices.push_back(i);
00205 Real64* buf = new Real64[rowCount];
00206 int received = da->GetColumn(i, buf, rowCount);
00207 if(received != rowCount)
00208 ILOG_ERROR("[PredictForActiveLearn] didn't receive column");
00209 sortedColumnQuidData.push_back(buf);
00210
00211 }
00212 }
00213 if(sortedColumnQuids.size() != columnQuids->Size())
00214 {
00215 ILOG_ERROR("not all column quids are in the table?!");
00216 return;
00217 }
00218
00219 ILOG_INFO("Done loading");
00220
00221 mModel->param.kernel_type = PRECOMPUTED;
00222
00223 int colCount = da->GetColumns();
00224 struct svm_node* x = new struct svm_node[colCount + 2];
00225 for(int i = 0; i < colCount + 2; i++)
00226 {
00227 x[i].index = 0;
00228 x[i].value = 0;
00229 }
00230 for(int i = 0; i < rowCount; i++)
00231 {
00232 x[0].index = 0;
00233 x[0].value = -1;
00234
00235 for(int j = 0; j < sortedColumnQuids.size(); j++)
00236 {
00237 int index = sortedColumnQuidIndices[j]+1;
00238 x[index].index = index;
00239 x[index].value = sortedColumnQuidData[j][i];
00240 }
00241 x[colCount+1].index = -1;
00242 x[colCount+1].value = 0;
00243
00244
00245 double score = PredictSingle(x);
00246 result->Add(score);
00247 }
00248 delete x;
00249 for(int i = 0; i < sortedColumnQuidData.size(); i++)
00250 {
00251 delete sortedColumnQuidData[i];
00252 }
00253 ILOG_INFO("done predict");
00254 }
00255
00256 void
00257 SetModel(svm_model* model)
00258 {
00259 if (mModel)
00260 svm_destroy_model(mModel);
00261 mModel = model;
00262 if (svm_check_probability_model(mModel))
00263 mIsProbabilityModel = true;
00264 else
00265 mIsProbabilityModel = false;
00266
00267 if (mIsProbabilityModel)
00268 {
00269 int labels[2];
00270 svm_get_labels(mModel, labels);
00271 if (labels[1] == 1)
00272 mProbilityIndex = 1;
00273 }
00274 }
00275
00276 virtual void
00277 OverrideModelOptions(Util::PropertySet* properties)
00278 {
00279 MakeSvmParams(properties, 1, &mModel->param);
00280 }
00281
00282 const svm_model*
00283 GetModel()
00284 {
00285 return mModel;
00286 }
00287
00288 bool
00289 Equals(const Svm* other) const
00290 {
00291 if (mModel==0 && other->mModel==0)
00292 return true;
00293 if (mModel==0 || other->mModel==0)
00294 return false;
00295
00296 return Training::Equals(mModel, other->mModel);
00297 }
00298
00299 int
00300 ReferenceDiff(const Svm* other) const
00301 {
00302 int diff = Diff(other);
00303 if (diff > 0)
00304 ILOG_ERROR("Found "<< diff <<" differences");
00305 return diff;
00306 }
00307
00308 int
00309 Diff(const Svm* other) const
00310 {
00311 if (mModel==0 && other->mModel==0)
00312 return 0;
00313 if (mModel==0 || other->mModel==0)
00314 return 1;
00315 return Impala::Diff(mModel, other->mModel);
00316 }
00317
00318 private:
00319
00320
00321 void
00322 CheckTestMode(svm_problem* problem)
00323 {
00324 CmdOptions& options = CmdOptions::GetInstance();
00325 if (! options.GetBool("testMode"))
00326 return;
00327
00328 VideoSet::VideoSet* vidSet =
00329 VideoSet::MakeVideoSet("trec2005fsd_try.txt");
00330 String filename = "anchor.problem.txt";
00331 Persistency::ModelLocator loc
00332 (vidSet->GetLocator(), "concepts.txt", "testmodel",
00333 "vissem_proto_annotation_nrScales_2_nrRects_130", "anchor");
00334 Persistency::SvmProblemRepository().Add(loc, problem);
00335 ILOG_INFO("TestMode, we're done");
00336 exit(0);
00337 }
00338
00339 double
00340 PredictSingle(const svm_node* problem)
00341 {
00342 double score;
00343 if(mIsProbabilityModel)
00344 {
00345 double probabilities[2];
00346 svm_predict_probability(mModel, problem, probabilities);
00347 score = probabilities[mProbilityIndex];
00348 }
00349 else
00350 {
00351 score = svm_predict(mModel, problem);
00352 }
00353 return score;
00354 }
00355
00356 svm_model* mModel;
00357 bool mIsProbabilityModel;
00358 int mProbilityIndex;
00359
00360 ILOG_VAR_DEC;
00361 };
00362
00363 ILOG_VAR_INIT(Svm, Impala.Core.Training);
00364
00365 bool
00366 Equals(const Svm* svm1, const Svm* svm2)
00367 {
00368 return svm1->Equals(svm2);
00369 }
00370
00371 int
00372 Diff(const Svm* svm1, const Svm* svm2)
00373 {
00374 return svm1->Diff(svm2);
00375 }
00376
00377 }
00378 }
00379 }
00380
00381
00382 #endif