Home || Visual Search || Applications || Architecture || Important Messages || OGL || Src

mainRandomForest.cpp

Go to the documentation of this file.
00001 #include "Basis/ConfigBase.h"
00002 #include "Core/Training/Factory.h"
00003 #include "Core/Array/RGB2Intensity.h"
00004 #include "Core/Array/PixMax.h"
00005 #include "Core/Array/MulVal.h"
00006 #include "Core/Feature/Surf.h"
00007 #include "Core/Feature/MakeRandomTree.h"
00008 #include "Core/Feature/RandomForest.h"
00009 #include "Core/Feature/PointDescriptorTable.h"
00010 #include "Core/Column/Types.h"
00011 #include <fstream>
00012 
00013 // not used, but needed because of factory. sigh...
00014 #include "Link/Svm/LinkSvm.cpp"
00015 
00016 namespace
00017 {
00018 
00019 using namespace Impala;
00020 using namespace Core;
00021 
00022 typedef Table::TableTem<Column::ColumnQuid,
00023                         Column::ColumnInt32> SelectionTable;
00024 
00025 const int cNrImages = 50;
00026 const int cPointPerImage = 250;
00027 // const int cNrImages = 10;
00028 //const int cPointPerImage = 50;
00029     
00030 void
00031 PickRandomSamples(SelectionTable& quids, Table::QuidTable* src, int classId)
00032 {
00033     Util::QuasiRandomSequenceIterator it(src->Size(), 0);
00034     for(int i=0 ; i<cNrImages ; ++i)
00035     {
00036         quids.Add(src->Get1(*it), classId);
00037         ++it;
00038     }
00039 }
00040     
00041 void
00042 SelectQuids(SelectionTable& quids, int& nrClasses, Training::Factory* factory)
00043 {
00044     ILOG_VAR(main);
00045     std::vector<String> concepts = factory->MakeConceptList();
00046     if(concepts.size() == 0)
00047         ILOG_ERROR("couldn't open concept list");
00048     nrClasses = concepts.size();
00049     for(int i=0 ; i<nrClasses ; ++i)
00050     {
00051         Table::AnnotationTable* anno = factory->MakeAnnotation(concepts[i]);
00052         Table::QuidTable* pos = anno->GetPositive();
00053         ILOG_INFO(pos->Size() << " positive annotations");
00054         if(pos->Size() < cNrImages)
00055         {
00056             ILOG_WARN("not enough positive examples for concept "<< concepts[i]);
00057             continue;
00058         }
00059         PickRandomSamples(quids, pos, i);
00060     }
00061     // check size of quids for validity?
00062     ILOG_INFO(quids.Size() << " quids found");
00063 }
00064 
00065 void
00066 GetNSamples(Feature::AnnotatedFeatureTable& featureSampling,
00067             Feature::PointDescriptorTable* pointData, int n, int classId)
00068 {
00069     static Util::Random sRNG;
00070     std::set<int> indices = sRNG.RandomUniqueNumbers(n, pointData->Size());
00071     for(std::set<int>::iterator itIndex = indices.begin() ;
00072         itIndex != indices.end() ; ++itIndex)
00073     {
00074         int index = *itIndex;
00075         Vector::VectorTem<double> vector(pointData->GetDescriptorLength(), 
00076             pointData->GetDescriptorData(index), true);
00077         featureSampling.Add(vector, classId);
00078     }
00079 }
00080     
00081 void
00082 GetFeatures(Feature::AnnotatedFeatureTable& featureSampling,
00083             const SelectionTable& quidSampling, Training::Factory* factory,
00084             int dSurfParams[3])
00085 {
00086     ILOG_FUNCTION(GetFeatures);
00087     for(int i=0 ; i<quidSampling.Size() ; ++i)
00088     {
00089         ILOG_PROGRESS(i <<" of "<< quidSampling.Size() <<" processed", 4.);
00090         Quid q = quidSampling.Get1(i);
00091         Array::Array2dVec3UInt8* image = factory->MakeImage(q);
00092         if(image == 0)
00093             ILOG_ERROR("couldn't get image of quid "<< QuidObj(q));
00094         Feature::PointDescriptorTable* pointData = new 
00095             Feature::PointDescriptorTable(Feature::FeatureDefinition(""));
00096         String descriptor = factory->GetFeatureDefinition().GetName();
00097         ILOG_DEBUG("d="<<descriptor);
00098         descriptor = descriptor.substr(descriptor.find("-")+1);
00099         descriptor = StringReplace(descriptor, "surf", "");
00100         ILOG_DEBUG("d="<<descriptor);
00101         Feature::CalculateSurfDescriptors
00102             (image, pointData, descriptor,
00103              dSurfParams[0], dSurfParams[1], dSurfParams[2]);
00104         int classId = quidSampling.Get2(i);
00105         ILOG_DEBUG("point list size="<< pointData->Size());
00106         GetNSamples(featureSampling, pointData, cPointPerImage, classId);
00107         delete pointData;
00108         delete image;
00109     }
00110     ILOG_PROGRESS_DONE("all processed");
00111 }
00112 
00113 void
00114 DumpDescriptors(Feature::AnnotatedFeatureTable& featureSampling)
00115 {
00116     std::ofstream ofs("descriptordump.txt");
00117     if(ofs.is_open())
00118     {
00119         Feature::Dump(&featureSampling, ofs);
00120         ofs.close();
00121     }
00122 }
00123 
00124 void
00125 DumpTree(Feature::RandomTree* tree)
00126 {
00127     ILOG_FUNCTION(main);
00128     std::ofstream ofs("treedump.txt");
00129     if(ofs.is_open())
00130     {
00131         tree->Dump(ofs, 10000);
00132         ofs.close();
00133     }
00134     else
00135         ILOG_ERROR("couldn't dump trees");
00136 }
00137 
00138 void
00139 DumpTreeCounts(Feature::RandomTree* tree, int i)
00140 {
00141     ILOG_FUNCTION(main);
00142     std::ofstream ofs(("countdump"+MakeString(i)+".txt").c_str());
00143     if(ofs.is_open())
00144     {
00145         tree->DumpCount(ofs);
00146         ofs.close();
00147     }
00148     else
00149         ILOG_ERROR("couldn't dump tree counts");
00150 }
00151 
00152 void
00153 ProjectAndDump(Feature::FeatureTable* codebook,
00154                Feature::AnnotatedFeatureTable& featureSampling)
00155 {
00156     ILOG_FUNCTION(main);
00157     int codebookLength = GetCodebookLength(codebook);
00158     int* hist = new int[codebookLength];
00159     for(int i=0 ; i<codebookLength ; ++i)
00160         hist[i] = 0;
00161     Feature::RandomForest forest = ReadRandomForest(codebook);
00162     for(int i=0; i<featureSampling.Size(); ++i)
00163     {
00164         for(int f=0; f<forest.size(); ++f)
00165         {
00166             Feature::RandomTree* tree = forest[f];
00167             int codeword = tree->GetCodeWord(featureSampling.Get1(i));
00168             ++hist[codeword];
00169         }
00170     }
00171     DeleteForest(forest);
00172     std::ofstream ofs("projectiondump.txt");
00173     if(ofs.is_open())
00174     {
00175         for(int i=0 ; i<codebookLength ; ++i)
00176             ofs << i <<": #"<< hist[i] <<"\n";
00177         ofs.close();
00178     }
00179     else
00180         ILOG_ERROR("couldn't dump projection");
00181     delete hist;
00182 }
00183 
00184 
00185 class RandomForestConfig : public ConfigBase
00186 {
00187 public:
00188     RandomForestConfig() :
00189         depth(10), tries(32), dumpTree(false), dumpTreeCounts(false),
00190         projectAndDump(false)
00191     {
00192         // don't know about next part, defaults come from Feature::AddSurfOptions
00193         surfParams[0] = 3;
00194         surfParams[1] = 2;
00195         surfParams[2] = 4;
00196     }
00197 
00198     void InitOptions(CmdOptions& co)
00199     {
00200         co.AddOption(0, "dumpTree", "", "0");
00201         co.AddOption(0, "dumpTreeCounts", "", "0");
00202         co.AddOption(0, "projectAndDump", "", "0");
00203         co.AddOption(0, "forest-tries", "int", "32");
00204         co.AddOption(0, "forest-depth", "int", "10");
00205         Feature::AddDSurfOptions(co);
00206     }
00207 
00208     void RetrieveOptions(CmdOptions& co)
00209     {
00210         depth  = co.GetInt("forest-depth", depth);
00211         tries  = co.GetInt("forest-tries", tries);
00212         dumpTree = co.GetBool("dumpTree");
00213         dumpTreeCounts = co.GetBool("dumpTreeCounts");
00214         projectAndDump = co.GetBool("projectAndDump");
00215         Feature::GetDSurfOptions(co, surfParams[0], surfParams[1], surfParams[2]);
00216     }
00217 
00218 private:
00219     int surfParams[3];
00220     int depth;                                   
00221     int tries;                                   
00222     bool dumpTree, dumpTreeCounts, projectAndDump;
00223 
00224     friend int RandomForest(int argc, char** argv);
00225 };
00226 
00227 int RandomForest(int argc, char** argv)
00228 {
00229     ILOG_FUNCTION(main);
00230     RandomForestConfig config;
00231     CmdOptions& options = CmdOptions::GetInstance();
00232     options.Initialise(false, false, true);
00233     config.InitOptions(options);
00234     if (options.ParseArgs(argc, argv, "dataSet concepts 0 featureDef", 4))
00235     {
00236         // setup the application
00237         config.RetrieveOptions(options);
00238         Core::Training::Factory factory(&options, false);
00239         if (factory.CodebookExists())
00240         {
00241             ILOG_WARNING("codebook already exists; skipping...");
00242             return ILOG_ERROR_COUNT;
00243         }
00244         
00245         // get the quids of the classes
00246         SelectionTable quidSampling(0);
00247         int nrClasses;
00248         SelectQuids(quidSampling, nrClasses, &factory);
00249         Feature::AnnotatedFeatureTable featureSampling
00250             (Vector::ColumnVectorSet(true, 64, 0), Column::ColumnInt32(0));
00251         
00252         // get the features
00253         GetFeatures(featureSampling, quidSampling, &factory, config.surfParams);
00254         ILOG_INFO("got "<< featureSampling.Size() <<" samples for random forest");
00255         
00256         // make the forest
00257         Feature::RandomTreeTable forest(0);
00258         for(int i=0 ; i<4 ; ++i)
00259         {
00260             Util::Random rng;
00261             rng.SetSeed(i);
00262             Feature::RandomTree* tree = Feature::MakeRandomTree
00263                 (&featureSampling, nrClasses, config.depth, config.tries, rng);
00264             Write(tree, &forest);
00265             if(config.dumpTreeCounts)
00266                 DumpTreeCounts(tree, i);
00267             if(config.dumpTree)
00268                 DumpTree(tree);
00269             delete tree;
00270         }
00271 
00272         if(forest.Size() > 0)
00273         {
00274             // save results to disk
00275             Feature::FeatureTable* ft = Feature::MakeFeatureTable(&forest);
00276             if(config.projectAndDump)
00277                 ProjectAndDump(ft, featureSampling);
00278             factory.WriteCodebook(ft);
00279             ILOG_INFO("saved code books");
00280             delete ft;
00281         }
00282     }
00283     return ILOG_ERROR_COUNT;
00284 }
00285 
00286 
00287 } // namespace
00288 
00289 int
00290 main(int argc, char* argv[])
00291 {
00292     return RandomForest(argc, argv);
00293 }
00294 

Generated on Thu Jan 13 09:03:43 2011 for ImpalaSrc by  doxygen 1.5.1