00001 #include "Basis/ConfigBase.h"
00002 #include "Core/Training/Factory.h"
00003 #include "Core/Array/RGB2Intensity.h"
00004 #include "Core/Array/PixMax.h"
00005 #include "Core/Array/MulVal.h"
00006 #include "Core/Feature/Surf.h"
00007 #include "Core/Feature/MakeRandomTree.h"
00008 #include "Core/Feature/RandomForest.h"
00009 #include "Core/Feature/PointDescriptorTable.h"
00010 #include "Core/Column/Types.h"
00011 #include <fstream>
00012
00013
00014 #include "Link/Svm/LinkSvm.cpp"
00015
00016 namespace
00017 {
00018
00019 using namespace Impala;
00020 using namespace Core;
00021
00022 typedef Table::TableTem<Column::ColumnQuid,
00023 Column::ColumnInt32> SelectionTable;
00024
00025 const int cNrImages = 50;
00026 const int cPointPerImage = 250;
00027
00028
00029
00030 void
00031 PickRandomSamples(SelectionTable& quids, Table::QuidTable* src, int classId)
00032 {
00033 Util::QuasiRandomSequenceIterator it(src->Size(), 0);
00034 for(int i=0 ; i<cNrImages ; ++i)
00035 {
00036 quids.Add(src->Get1(*it), classId);
00037 ++it;
00038 }
00039 }
00040
00041 void
00042 SelectQuids(SelectionTable& quids, int& nrClasses, Training::Factory* factory)
00043 {
00044 ILOG_VAR(main);
00045 std::vector<String> concepts = factory->MakeConceptList();
00046 if(concepts.size() == 0)
00047 ILOG_ERROR("couldn't open concept list");
00048 nrClasses = concepts.size();
00049 for(int i=0 ; i<nrClasses ; ++i)
00050 {
00051 Table::AnnotationTable* anno = factory->MakeAnnotation(concepts[i]);
00052 Table::QuidTable* pos = anno->GetPositive();
00053 ILOG_INFO(pos->Size() << " positive annotations");
00054 if(pos->Size() < cNrImages)
00055 {
00056 ILOG_WARN("not enough positive examples for concept "<< concepts[i]);
00057 continue;
00058 }
00059 PickRandomSamples(quids, pos, i);
00060 }
00061
00062 ILOG_INFO(quids.Size() << " quids found");
00063 }
00064
00065 void
00066 GetNSamples(Feature::AnnotatedFeatureTable& featureSampling,
00067 Feature::PointDescriptorTable* pointData, int n, int classId)
00068 {
00069 static Util::Random sRNG;
00070 std::set<int> indices = sRNG.RandomUniqueNumbers(n, pointData->Size());
00071 for(std::set<int>::iterator itIndex = indices.begin() ;
00072 itIndex != indices.end() ; ++itIndex)
00073 {
00074 int index = *itIndex;
00075 Vector::VectorTem<double> vector(pointData->GetDescriptorLength(),
00076 pointData->GetDescriptorData(index), true);
00077 featureSampling.Add(vector, classId);
00078 }
00079 }
00080
00081 void
00082 GetFeatures(Feature::AnnotatedFeatureTable& featureSampling,
00083 const SelectionTable& quidSampling, Training::Factory* factory,
00084 int dSurfParams[3])
00085 {
00086 ILOG_FUNCTION(GetFeatures);
00087 for(int i=0 ; i<quidSampling.Size() ; ++i)
00088 {
00089 ILOG_PROGRESS(i <<" of "<< quidSampling.Size() <<" processed", 4.);
00090 Quid q = quidSampling.Get1(i);
00091 Array::Array2dVec3UInt8* image = factory->MakeImage(q);
00092 if(image == 0)
00093 ILOG_ERROR("couldn't get image of quid "<< QuidObj(q));
00094 Feature::PointDescriptorTable* pointData = new
00095 Feature::PointDescriptorTable(Feature::FeatureDefinition(""));
00096 String descriptor = factory->GetFeatureDefinition().GetName();
00097 ILOG_DEBUG("d="<<descriptor);
00098 descriptor = descriptor.substr(descriptor.find("-")+1);
00099 descriptor = StringReplace(descriptor, "surf", "");
00100 ILOG_DEBUG("d="<<descriptor);
00101 Feature::CalculateSurfDescriptors
00102 (image, pointData, descriptor,
00103 dSurfParams[0], dSurfParams[1], dSurfParams[2]);
00104 int classId = quidSampling.Get2(i);
00105 ILOG_DEBUG("point list size="<< pointData->Size());
00106 GetNSamples(featureSampling, pointData, cPointPerImage, classId);
00107 delete pointData;
00108 delete image;
00109 }
00110 ILOG_PROGRESS_DONE("all processed");
00111 }
00112
00113 void
00114 DumpDescriptors(Feature::AnnotatedFeatureTable& featureSampling)
00115 {
00116 std::ofstream ofs("descriptordump.txt");
00117 if(ofs.is_open())
00118 {
00119 Feature::Dump(&featureSampling, ofs);
00120 ofs.close();
00121 }
00122 }
00123
00124 void
00125 DumpTree(Feature::RandomTree* tree)
00126 {
00127 ILOG_FUNCTION(main);
00128 std::ofstream ofs("treedump.txt");
00129 if(ofs.is_open())
00130 {
00131 tree->Dump(ofs, 10000);
00132 ofs.close();
00133 }
00134 else
00135 ILOG_ERROR("couldn't dump trees");
00136 }
00137
00138 void
00139 DumpTreeCounts(Feature::RandomTree* tree, int i)
00140 {
00141 ILOG_FUNCTION(main);
00142 std::ofstream ofs(("countdump"+MakeString(i)+".txt").c_str());
00143 if(ofs.is_open())
00144 {
00145 tree->DumpCount(ofs);
00146 ofs.close();
00147 }
00148 else
00149 ILOG_ERROR("couldn't dump tree counts");
00150 }
00151
00152 void
00153 ProjectAndDump(Feature::FeatureTable* codebook,
00154 Feature::AnnotatedFeatureTable& featureSampling)
00155 {
00156 ILOG_FUNCTION(main);
00157 int codebookLength = GetCodebookLength(codebook);
00158 int* hist = new int[codebookLength];
00159 for(int i=0 ; i<codebookLength ; ++i)
00160 hist[i] = 0;
00161 Feature::RandomForest forest = ReadRandomForest(codebook);
00162 for(int i=0; i<featureSampling.Size(); ++i)
00163 {
00164 for(int f=0; f<forest.size(); ++f)
00165 {
00166 Feature::RandomTree* tree = forest[f];
00167 int codeword = tree->GetCodeWord(featureSampling.Get1(i));
00168 ++hist[codeword];
00169 }
00170 }
00171 DeleteForest(forest);
00172 std::ofstream ofs("projectiondump.txt");
00173 if(ofs.is_open())
00174 {
00175 for(int i=0 ; i<codebookLength ; ++i)
00176 ofs << i <<": #"<< hist[i] <<"\n";
00177 ofs.close();
00178 }
00179 else
00180 ILOG_ERROR("couldn't dump projection");
00181 delete hist;
00182 }
00183
00184
00185 class RandomForestConfig : public ConfigBase
00186 {
00187 public:
00188 RandomForestConfig() :
00189 depth(10), tries(32), dumpTree(false), dumpTreeCounts(false),
00190 projectAndDump(false)
00191 {
00192
00193 surfParams[0] = 3;
00194 surfParams[1] = 2;
00195 surfParams[2] = 4;
00196 }
00197
00198 void InitOptions(CmdOptions& co)
00199 {
00200 co.AddOption(0, "dumpTree", "", "0");
00201 co.AddOption(0, "dumpTreeCounts", "", "0");
00202 co.AddOption(0, "projectAndDump", "", "0");
00203 co.AddOption(0, "forest-tries", "int", "32");
00204 co.AddOption(0, "forest-depth", "int", "10");
00205 Feature::AddDSurfOptions(co);
00206 }
00207
00208 void RetrieveOptions(CmdOptions& co)
00209 {
00210 depth = co.GetInt("forest-depth", depth);
00211 tries = co.GetInt("forest-tries", tries);
00212 dumpTree = co.GetBool("dumpTree");
00213 dumpTreeCounts = co.GetBool("dumpTreeCounts");
00214 projectAndDump = co.GetBool("projectAndDump");
00215 Feature::GetDSurfOptions(co, surfParams[0], surfParams[1], surfParams[2]);
00216 }
00217
00218 private:
00219 int surfParams[3];
00220 int depth;
00221 int tries;
00222 bool dumpTree, dumpTreeCounts, projectAndDump;
00223
00224 friend int RandomForest(int argc, char** argv);
00225 };
00226
00227 int RandomForest(int argc, char** argv)
00228 {
00229 ILOG_FUNCTION(main);
00230 RandomForestConfig config;
00231 CmdOptions& options = CmdOptions::GetInstance();
00232 options.Initialise(false, false, true);
00233 config.InitOptions(options);
00234 if (options.ParseArgs(argc, argv, "dataSet concepts 0 featureDef", 4))
00235 {
00236
00237 config.RetrieveOptions(options);
00238 Core::Training::Factory factory(&options, false);
00239 if (factory.CodebookExists())
00240 {
00241 ILOG_WARNING("codebook already exists; skipping...");
00242 return ILOG_ERROR_COUNT;
00243 }
00244
00245
00246 SelectionTable quidSampling(0);
00247 int nrClasses;
00248 SelectQuids(quidSampling, nrClasses, &factory);
00249 Feature::AnnotatedFeatureTable featureSampling
00250 (Vector::ColumnVectorSet(true, 64, 0), Column::ColumnInt32(0));
00251
00252
00253 GetFeatures(featureSampling, quidSampling, &factory, config.surfParams);
00254 ILOG_INFO("got "<< featureSampling.Size() <<" samples for random forest");
00255
00256
00257 Feature::RandomTreeTable forest(0);
00258 for(int i=0 ; i<4 ; ++i)
00259 {
00260 Util::Random rng;
00261 rng.SetSeed(i);
00262 Feature::RandomTree* tree = Feature::MakeRandomTree
00263 (&featureSampling, nrClasses, config.depth, config.tries, rng);
00264 Write(tree, &forest);
00265 if(config.dumpTreeCounts)
00266 DumpTreeCounts(tree, i);
00267 if(config.dumpTree)
00268 DumpTree(tree);
00269 delete tree;
00270 }
00271
00272 if(forest.Size() > 0)
00273 {
00274
00275 Feature::FeatureTable* ft = Feature::MakeFeatureTable(&forest);
00276 if(config.projectAndDump)
00277 ProjectAndDump(ft, featureSampling);
00278 factory.WriteCodebook(ft);
00279 ILOG_INFO("saved code books");
00280 delete ft;
00281 }
00282 }
00283 return ILOG_ERROR_COUNT;
00284 }
00285
00286
00287 }
00288
00289 int
00290 main(int argc, char* argv[])
00291 {
00292 return RandomForest(argc, argv);
00293 }
00294