00001 #ifndef Impala_Core_Training_TesterIoHelper_h
00002 #define Impala_Core_Training_TesterIoHelper_h
00003
#include <cstdlib>
#include <fstream>
#include <vector>

#include "Core/DataFactory.h"
#include "Core/Training/Factory.h"
#include "Util/Database.h"
#include "Core/VideoSet/VideoSet.h"
#include "Core/Training/Svm.h"
#include "Core/Training/SvmFile.h"
00010
00011 namespace Impala
00012 {
00013 namespace Core
00014 {
00015 namespace Training
00016 {
00017
00018 class TesterIOHelper
00019 {
00020 public:
00021 TesterIOHelper(DataFactory* input,
00022 Database::RawDataSet* reference,
00023 Database::RawDataSet* output)
00024 {
00025 mInput = input;
00026 mReference = reference;
00027 mOutput = output;
00028
00029
00030 Database::RawDataSet* set = input->GetDataSet();
00031 String conceptFile =
00032 set->GetFilePathAnnotation(cConceptSet, false, false);
00033 std::vector<String> conceptList;
00034 Util::Database* db = set->GetDatabase();
00035 Util::DatabaseReadStrings(conceptList, conceptFile, db);
00036 mConcept = conceptList[0];
00037 }
00038
00039 void MakeTrainData(TrainDataSrc*& data, Table::AnnotationTable* annotation,
00040 String kerneltype)
00041 {
00042 Util::PropertySet pset;
00043 pset.Add("kernel", kerneltype);
00044 Factory factory(&pset);
00045 data = factory.MakeDataSrc(annotation, mInput);
00046 ILOG_DEBUG("Train data size: " << data->Size() );
00047 }
00048
00049 void ReadAnnotation(Table::AnnotationTable*& annotation)
00050 {
00051 annotation = mInput->MakeAnnotation(mConcept);
00052 ILOG_DEBUG("Annotations size: " << annotation->Size() );
00053 }
00054
00055
00056 void WriteFold(int nr, Table::QuidTable *fold)
00057 {
00058 String filename = FoldFilename(nr);
00059 String path = mOutput->GetFilePathAnnotation
00060 (QUID_CLASS_FRAME, cConceptSet, filename, true, false);
00061 AssertFileFound(path, "could not write fold " + filename);
00062 Write(fold, path, mOutput->GetDatabase(), false);
00063 }
00064
00065 Table::QuidTable* ReadReferenceFold(int nr)
00066 {
00067 String filename = FoldFilename(nr);
00068 String path = mReference->GetFilePathAnnotation
00069 (QUID_CLASS_FRAME, cConceptSet, filename, false, false);
00070 AssertFileFound(path, "could not read reference fold " + filename);
00071 Table::QuidTable* refFold = new Table::QuidTable;
00072 Read(refFold, path, mReference->GetDatabase());
00073 return refFold;
00074 }
00075
00076 void WriteProblem(svm_problem* problem)
00077 {
00078 String filename = ProblemFilename();
00079 String path = mOutput->GetFilePathConceptModel
00080 (cConceptSet, cModel, cFeature, filename, true, false);
00081 AssertFileFound(path, "could not write problem " + filename);
00082 WriteSvmFile(problem, path);
00083 }
00084
00085 svm_problem* ReadReferenceProblem()
00086 {
00087 String filename = ProblemFilename();
00088 String path = mReference->GetFilePathConceptModel
00089 (cConceptSet, cModel, cFeature, filename, false, false);
00090 AssertFileFound(path, "could not read reference problem " + filename);
00091 return ReadSvmFile(path);
00092 }
00093
00094 void ReadModel(Svm* svm, String type)
00095 {
00096 String filename = ModelFilename(type);
00097 String path = mReference->GetFilePathConceptModel
00098 (cConceptSet, cModel, cFeature, filename, false, false);
00099 AssertFileFound(path, "could not read model " + filename);
00100 svm->LoadModel(path, mReference->GetDatabase());
00101 }
00102
00103 void WriteModel(Svm* svm, String type)
00104 {
00105 String filename = ModelFilename(type);
00106 String path = mOutput->GetFilePathConceptModel
00107 (cConceptSet, cModel, cFeature, filename, true, false);
00108 AssertFileFound(path, "could not write model " + filename);
00109 svm->SaveModel(path, mOutput->GetDatabase());
00110 }
00111
00112 Table::ScoreTable* ReadRanking(String type)
00113 {
00114 String filename = type + RankingFilename();
00115 Table::ScoreTable* ranking = new Table::ScoreTable;
00116 String path = mReference->GetFilePathConceptModel
00117 (cConceptSet, cModel, cFeature, filename, false, false);
00118 AssertFileFound(path, "could not read ranking " + filename);
00119 Read(ranking, path, mReference->GetDatabase());
00120 return ranking;
00121 }
00122
00123 void WriteRanking(Table::ScoreTable* ranking, String type)
00124 {
00125 String filename = type + RankingFilename();
00126 String path = mOutput->GetFilePathConceptModel
00127 (cConceptSet, cModel, cFeature, filename, true, false);
00128 AssertFileFound(path, "could not write ranking " + filename);
00129 Write(ranking, path, mOutput->GetDatabase(), true);
00130 }
00131
00132 double ReadAveragePrecision()
00133 {
00134 String filename = ScoreFilename();
00135 Table::ScoreTable* ranking = new Table::ScoreTable;
00136 String path = mReference->GetFilePathConceptModel
00137 (cConceptSet, cModel, cFeature, filename, false, false);
00138 AssertFileFound(path, "could not read AP " + filename);
00139
00140 double ap = -1;
00141 std::ifstream ifs(path.c_str());
00142 if(ifs.is_open())
00143 {
00144 ifs.read((char*) &ap, sizeof(double));
00145 ifs.close();
00146 }
00147 else
00148 ILOG_ERROR("couldn't open file " << path << "for reading");
00149 return ap;
00150 }
00151
00152 void WriteAveragePrecision(double ap)
00153 {
00154 String filename = ScoreFilename();
00155 String path = mOutput->GetFilePathConceptModel
00156 (cConceptSet, cModel, cFeature, filename, true, false);
00157 AssertFileFound(path, "could not write AP "+ filename);
00158
00159 std::ofstream ofs(path.c_str());
00160 if(ofs.is_open())
00161 {
00162 ofs.write((char*) &ap, sizeof(double));
00163 ofs.close();
00164 }
00165 else
00166 ILOG_ERROR("couldn't open file " << path << "for reading");
00167 }
00168
00169 private:
00170 void AssertFileFound(String filename, String errormsg)
00171 {
00172 if(filename.empty())
00173 {
00174 ILOG_ERROR(errormsg);
00175 exit(1);
00176 }
00177 }
00178
00179 String ProblemFilename()
00180 {
00181 return mConcept + ".problem.txt";
00182 }
00183 String FoldFilename(int i)
00184 {
00185 return mConcept +".fold"+ MakeString(i) +"of3(r0).tab";
00186 }
00187 String ModelFilename(String type)
00188 {
00189 return mConcept + "." + type + ".model";
00190 }
00191 String RankingFilename()
00192 {
00193 return mConcept+".ranking";
00194 }
00195 String ScoreFilename()
00196 {
00197 return mConcept+".score";
00198 }
00199
00200
00201 DataFactory* mInput;
00202 Database::RawDataSet* mReference;
00203 Database::RawDataSet* mOutput;
00204 String mConcept;
00205
00206 static const String cConceptSet;
00207 static const Feature::FeatureDefinition cFeature;
00208 static const String cModel;
00209 static const String cKernelMatrix;
00210 ILOG_VAR_DECL;
00211 };
00212
00213 const String TesterIOHelper::cConceptSet = "test.txt";
00214 const Feature::FeatureDefinition
00215 TesterIOHelper::cFeature("vissem_proto_annotation_nrScales_2_nrRects_130");
00216 const String TesterIOHelper::cModel("testmodel");
00217 const String TesterIOHelper::cKernelMatrix = "testsingle";
00218 ILOG_VAR_INIT(TesterIOHelper, Impala.Core.Training);
00219
00220
00221 }
00222 }
00223 }
00224
00225 #endif