Home || Architecture || Video Search || Visual Search || Scripts || Applications || Important Messages || OGL || Src

Tester.h

Go to the documentation of this file.
00001 #ifndef Impala_Core_Training_Test_h
00002 #define Impala_Core_Training_Test_h
00003 
00004 #include "Core/Equals.h"
00005 #include "Core/Table/Equals.h"
00006 #include "Core/Training/AveragePrecision.h"
00007 #include "Core/Training/Equals.h"
00008 #include "Link/Svm/LinkSvm.h"
00009 #include "Core/Training/TesterIoHelper.h"
00010 #include "Core/Array/Equals.h"
00011 
00012 namespace Impala
00013 {
00014 namespace Core
00015 {
00016 namespace Training
00017 {
00018 
00019 using Core::Equals;
00020     
00023 class Tester
00024 {
00025 public:
00026     Tester(DataFactory* input,
00027            VideoSet::VideoSet* reference,
00028            VideoSet::VideoSet* output)
00029         : mIO(input, reference, output)
00030     {
00031         mAnnotation = 0;
00032         mDataSrc = 0;
00033         mFails = 0;
00034     }
00035 
00036     virtual ~Tester()
00037     {
00038         ClearData();
00039     }
00040 
00045     void
00046     Test()
00047     {
00048         mFails = 0;
00049         mIO.ReadAnnotation(mAnnotation);
00050         TestTrainAndApply("precomputed");
00051         TestTrainAndApply("rbf");
00052         TestFolds();
00053         TestMakeSvmProblem();
00054         TestAveragePrecision();
00055         ClearData();
00056         Report();
00057     }
00058 
00059     void
00060     TestMpi()
00061     {
00062         mFails = 0;
00063         mIO.ReadAnnotation(mAnnotation);
00064         TestTrainAndApply("dist-precomputed");
00065         TestDistVsNonDist();
00066         ClearData();
00067         Report();
00068     }
00069 
00070     void
00071     Report()
00072     {
00073         if(mFails == 0)
00074         {
00075             ILOG_INFO("\n\n\nall tests passed.\n");
00076         }
00077         else
00078         {
00079             if(mFails == 1)
00080             {
00081                 ILOG_WARNING("\n\n\n*** 1 test failed! ***\n");
00082             }
00083             else
00084                 ILOG_WARNING("\n\n\n*** " << mFails << " tests failed! ***\n");
00085         }
00086     }
00087 
00088 private:
00089     void
00090     TestTrainAndApply(String kernel)
00091     {
00092         ClearData(mDataSrc);
00093         ILOG_INFO("annotation: "<< mAnnotation);
00094         mIO.MakeTrainData(mDataSrc, mAnnotation, kernel);
00095         TestTrainModel(kernel);
00096         TestApply(kernel);
00097     }
00098 
00099     void
00100     TestFolds()
00101     {
00102         std::vector<Table::QuidTable*> folds =
00103             mAnnotation->MakeEpisodeFolds(3, 0);
00104         for(int i=0 ; i<folds.size() ; ++i)
00105         {
00106             Table::QuidTable* fold = folds[i];
00107             mIO.WriteFold(i, fold);
00108             Table::QuidTable *refFold = mIO.ReadReferenceFold(i);
00109             Test(fold, refFold, "fold test #" + MakeString(i));
00110             delete refFold;
00111             delete fold;
00112         }
00113     }
00114 
00115     void
00116     TestMakeSvmProblem()
00117     {
00118         if(mDataSrc == 0)
00119             mIO.MakeTrainData(mDataSrc, mAnnotation, "rbf");
00120         FirstFold();
00121         svm_problem* problem = mDataSrc->MakeSvmProblem();
00122         mIO.WriteProblem(problem);
00123         svm_problem* refProblem = mIO.ReadReferenceProblem();
00124         Test(problem, refProblem, "make svm problem test");
00125         // ILOG_INFO("dump problem:");
00126         // Dump(problem, 10, 8, std::cout);
00127         // ILOG_INFO("dump ref problem:");
00128         // Dump(refProblem, 10, 8, std::cout);
00129         DestroySvmProblem(refProblem);
00130         mDataSrc->FreeProblem(problem);
00131     }
00132 
00133     void
00134     TestTrainModel(String kernel)
00135     {
00136         FirstFold();
00137         Util::PropertySet props;
00138         props.Add("kernel", kernel);
00139         ILOG_DEBUG("data size = "<< mDataSrc->Size());
00140         Svm svm;
00141         svm.Train(&props, mDataSrc);
00142         mIO.WriteModel(&svm, kernel);
00143         Svm refSvm;
00144         mIO.ReadModel(&refSvm, kernel);
00145         Test(&svm, &refSvm, "train " + kernel + " model");
00146     }
00147 
00150     void
00151     TestApply(String kernel)
00152     {
00153         FirstFoldInv();
00154         Svm svm;
00155         mIO.ReadModel(&svm, kernel);
00156         Table::ScoreTable* ranking = svm.Predict(mDataSrc);
00157 
00158         mIO.WriteRanking(ranking, kernel);
00159         Table::ScoreTable* refRanking = mIO.ReadRanking(kernel);
00166         double tolerance = 1e-14;
00167         Test(ranking, refRanking, "apply concept " + kernel, tolerance);
00168         delete ranking;
00169         delete refRanking;
00170     }
00171      
00172      
00173      
00174 
00177     void
00178     TestAveragePrecision()
00179     {
00180         Table::ScoreTable* refRanking = mIO.ReadRanking("rbf");
00181         AveragePrecision ap(mAnnotation);
00182         double score = ap.Compute(refRanking);
00183         mIO.WriteAveragePrecision(score);
00184         double refScore = mIO.ReadAveragePrecision();
00185         Test(score, refScore, "average precision");
00186         delete refRanking;
00187     }
00188 
00189     void
00190     ClearData()
00191     {
00192         ClearData(mAnnotation);
00193         ClearData(mDataSrc);
00194     }
00195 
00196     template<class type>
00197     void
00198     ClearData(type *& data)
00199     {
00200         if(data)
00201             delete data;
00202         data = 0;
00203     }
00204 
00205     void
00206     FirstFold()
00207     {
00208         // getting the test fold is actually wrong, but backwards compatible
00209         mDataSrc->FilterTestFold(0, 3, 1, false); 
00210         ILOG_DEBUG("first fold in train data size: " << mDataSrc->Size() );
00211     }
00212 
00213     void
00214     FirstFoldInv()
00215     {
00216         // getting the train fold is actually wrong, but backwards compatible
00217         mDataSrc->FilterTrainFold(0, 3, 1, false);
00218         ILOG_DEBUG("fold 2+3 in train data size: " << mDataSrc->Size() );
00219     }
00220 
00221     void
00222     TestDistVsNonDist()
00223     {
00224         TrainDataSrc* dist=0;
00225         mIO.MakeTrainData(dist, mAnnotation, "dist-precomputed");
00226         Matrix::Mat* matDist = dist->MakeDataCopy(-1, -1);
00227         //matDist->WriteTo(std::cout, -1, -1);
00228 
00229         TrainDataSrc* nonDist=0;
00230         mIO.MakeTrainData(nonDist, mAnnotation, "precomputed");
00231         Matrix::Mat* matNonDist = nonDist->MakeDataCopy(-1, -1);
00232         //matNonDist->WriteTo(std::cout, -1, -1);
00233 
00234         Test(matDist, matNonDist, "dist == non-dist");
00235         delete dist;
00236         delete nonDist;
00237     }
00238 
00239 
00240     template<class type>
00241     void
00242     Test(const type item, const type reference, String testname)
00243     {
00244         Test(Equals(item, reference), testname);
00245     }
00246 
00247     template<class type>
00248     void
00249     Test(const type item, const type reference, 
00250               String testname, double tolerance)
00251     {
00252         Test(Equals(item, reference, tolerance), testname);
00253     }
00254 
00255     void
00256     Test(bool passed, String testname)
00257     {
00258         if(passed)
00259         {
00260             ILOG_INFO("passed " << testname);
00261         }
00262         else
00263         {
00264             ++mFails;
00265             ILOG_WARNING("*** FAILED *** " << testname);
00266         }
00267     }
00268 
00269     int mFails;
00270     TesterIOHelper mIO;
00271     Table::AnnotationTable* mAnnotation;
00272     TrainDataSrc* mDataSrc;
00273 
00274     ILOG_VAR_DECL;
00275 };
00276 
00277 ILOG_VAR_INIT(Tester, Impala.Core.Training);
00278 
00279 }//namespace Core
00280 }//namespace Training
00281 }//Impala
00282 
00283 #endif
00284 

Generated on Fri Mar 19 09:31:25 2010 for ImpalaSrc by doxygen 1.5.1