Home || Architecture || Video Search || Visual Search || Scripts || Applications || Important Messages || OGL || Src

mainRandomForest.cpp

Go to the documentation of this file.
00001 #include "Basis/ConfigBase.h"
00002 #include "Core/ApplicationFactory.h"
00003 #include "Core/Array/RGB2Intensity.h"
00004 #include "Core/Array/PixMax.h"
00005 #include "Core/Array/MulVal.h"
00006 #include "Core/Feature/Surf.h"
00007 #include "Core/Feature/MakeRandomTree.h"
00008 #include "Core/Feature/RandomForest.h"
00009 #include "Core/Column/Types.h"
00010 #include <fstream>
00011 
00012 namespace
00013 {
00014 
00015 using namespace Impala;
00016 using namespace Core;
00017 
00018 typedef Table::TableTem<Column::ColumnQuid,
00019                         Column::ColumnInt32> SelectionTable;
00020 
00021 const int cNrImages = 50;
00022 const int cPointPerImage = 250;
00023 // const int cNrImages = 10;
00024 //const int cPointPerImage = 50;
00025     
00026 void
00027 PickRandomSamples(SelectionTable& quids, Table::QuidTable* src, int classId)
00028 {
00029     Util::QuasiRandomSequenceIterator it(src->Size(), 0);
00030     for(int i=0 ; i<cNrImages ; ++i)
00031     {
00032         quids.Add(src->Get1(*it), classId);
00033         ++it;
00034     }
00035 }
00036     
00037 void
00038 SelectQuids(SelectionTable& quids, int& nrClasses, DataFactory* factory)
00039 {
00040     ILOG_VAR(main);
00041     std::vector<String> concepts = factory->MakeConceptList();
00042     if(concepts.size() == 0)
00043         ILOG_ERROR("couldn't open concept list");
00044     nrClasses = concepts.size();
00045     for(int i=0 ; i<nrClasses ; ++i)
00046     {
00047         Table::AnnotationTable* anno = factory->MakeAnnotation(concepts[i]);
00048         Table::QuidTable* pos = anno->GetPositive();
00049         ILOG_INFO(pos->Size() << " positive annotations");
00050         if(pos->Size() < cNrImages)
00051         {
00052             ILOG_WARN("not enough positive examples for concept "<< concepts[i]);
00053             continue;
00054         }
00055         PickRandomSamples(quids, pos, i);
00056     }
00057     // check size of quids for validity?
00058     ILOG_INFO(quids.Size() << " quids found");
00059 }
00060 
00061 void
00062 GetNSamples(Feature::AnnotatedFeatureTable& featureSampling,
00063             const Geometry::InterestPointList& points, int n, int classId)
00064 {
00065     Geometry::InterestPointList::const_iterator itPoint = points.begin();
00066     std::set<int> indices = Util::RandomUniqueNumbers(n, points.size());
00067     int index=0;
00068     for(std::set<int>::iterator itIndex = indices.begin() ;
00069         itIndex != indices.end() ; ++itIndex)
00070     {
00071         while(index<*itIndex)
00072         {
00073             ++itPoint;
00074             ++index;
00075         }
00076         std::vector<Real64>& v = (*itPoint)->mDescriptor;
00077         Vector::VectorTem<double> vector(v.size(), &v[0], true);
00078         featureSampling.Add(vector, classId);
00079     }
00080 }
00081     
00082 void
00083 GetFeatures(Feature::AnnotatedFeatureTable& featureSampling,
00084             const SelectionTable& quidSampling, DataFactory* factory,
00085             int dSurfParams[3])
00086 {
00087     ILOG_FUNCTION(GetFeatures);
00088     for(int i=0 ; i<quidSampling.Size() ; ++i)
00089     {
00090         ILOG_PROGRESS(i <<" of "<< quidSampling.Size() <<" processed", 4.);
00091         Quid q = quidSampling.Get1(i);
00092         Array::Array2dVec3UInt8* image = factory->MakeImage(q);
00093         if(image == 0)
00094             ILOG_ERROR("couldn't get image of quid "<< QuidObj(q));
00095         Geometry::InterestPointList pointList;
00096         String descriptor = factory->GetFeatureDefinition().GetName();
00097         ILOG_DEBUG("d="<<descriptor);
00098         descriptor = descriptor.substr(descriptor.find("-")+1);
00099         descriptor = StringReplace(descriptor, "surf", "");
00100         ILOG_DEBUG("d="<<descriptor);
00101         Feature::CalculateSurfDescriptors
00102             (image, pointList, descriptor,
00103              dSurfParams[0], dSurfParams[1], dSurfParams[2]);
00104         int classId = quidSampling.Get2(i);
00105         ILOG_DEBUG("point list size="<< pointList.size());
00106         GetNSamples(featureSampling, pointList, cPointPerImage, classId);
00107         pointList.EraseAndDelete();
00108         delete image;
00109     }
00110     ILOG_PROGRESS_DONE("all processed");
00111 }
00112 
00113 void
00114 DumpDescriptors(Feature::AnnotatedFeatureTable& featureSampling)
00115 {
00116     std::ofstream ofs("descriptordump.txt");
00117     if(ofs.is_open())
00118     {
00119         Feature::Dump(&featureSampling, ofs);
00120         ofs.close();
00121     }
00122 }
00123 
00124 void
00125 DumpTree(Feature::RandomTree* tree)
00126 {
00127     ILOG_FUNCTION(main);
00128     std::ofstream ofs("treedump.txt");
00129     if(ofs.is_open())
00130     {
00131         tree->Dump(ofs, 10000);
00132         ofs.close();
00133     }
00134     else
00135         ILOG_ERROR("couldn't dump trees");
00136 }
00137 
00138 void
00139 DumpTreeCounts(Feature::RandomTree* tree, int i)
00140 {
00141     ILOG_FUNCTION(main);
00142     std::ofstream ofs(("countdump"+MakeString(i)+".txt").c_str());
00143     if(ofs.is_open())
00144     {
00145         tree->DumpCount(ofs);
00146         ofs.close();
00147     }
00148     else
00149         ILOG_ERROR("couldn't dump tree counts");
00150 }
00151 
00152 void
00153 ProjectAndDump(Feature::FeatureTable* codebook,
00154                Feature::AnnotatedFeatureTable& featureSampling)
00155 {
00156     ILOG_FUNCTION(main);
00157     int codebookLength = GetCodebookLength(codebook);
00158     int* hist = new int[codebookLength];
00159     for(int i=0 ; i<codebookLength ; ++i)
00160         hist[i] = 0;
00161     Feature::RandomForest forest = ReadRandomForest(codebook);
00162     for(int i=0; i<featureSampling.Size(); ++i)
00163     {
00164         for(int f=0; f<forest.size(); ++f)
00165         {
00166             Feature::RandomTree* tree = forest[f];
00167             int codeword = tree->GetCodeWord(featureSampling.Get1(i));
00168             ++hist[codeword];
00169         }
00170     }
00171     DeleteForest(forest);
00172     std::ofstream ofs("projectiondump.txt");
00173     if(ofs.is_open())
00174     {
00175         for(int i=0 ; i<codebookLength ; ++i)
00176             ofs << i <<": #"<< hist[i] <<"\n";
00177         ofs.close();
00178     }
00179     else
00180         ILOG_ERROR("couldn't dump projection");
00181     delete hist;
00182 }
00183 
00184 
00185 class RandomForestConfig : public ConfigBase
00186 {
00187 public:
00188     RandomForestConfig() :
00189         depth(10), tries(32), dumpTree(false), dumpTreeCounts(false),
00190         projectAndDump(false)
00191     {
00192         // don't know about next part, defaults come from Feature::AddSurfOptions
00193         surfParams[0] = 3;
00194         surfParams[1] = 2;
00195         surfParams[2] = 4;
00196     }
00197 
00198     void InitOptions(CmdOptions& co)
00199     {
00200         co.AddOption(0, "dumpTree", "", "0");
00201         co.AddOption(0, "dumpTreeCounts", "", "0");
00202         co.AddOption(0, "projectAndDump", "", "0");
00203         co.AddOption(0, "forest-tries", "int", "32");
00204         co.AddOption(0, "forest-depth", "int", "10");
00205         Feature::AddDSurfOptions(co);
00206     }
00207 
00208     void RetrieveOptions(CmdOptions& co)
00209     {
00210         depth  = co.GetInt("forest-depth", depth);
00211         tries  = co.GetInt("forest-tries", tries);
00212         dumpTree = co.GetBool("dumpTree");
00213         dumpTreeCounts = co.GetBool("dumpTreeCounts");
00214         projectAndDump = co.GetBool("projectAndDump");
00215         Feature::GetDSurfOptions(co, surfParams[0], surfParams[1], surfParams[2]);
00216     }
00217 
00218 private:
00219     int surfParams[3];
00220     int depth;                                   
00221     int tries;                                   
00222     bool dumpTree, dumpTreeCounts, projectAndDump;
00223 
00224     friend int RandomForest(int argc, char** argv);
00225 };
00226 
00227 int RandomForest(int argc, char** argv)
00228 {
00229     ILOG_FUNCTION(main);
00230     RandomForestConfig config;
00231     CmdOptions& options = CmdOptions::GetInstance();
00232     options.Initialise(false, false, true);
00233     config.InitOptions(options);
00234     if (options.ParseArgs(argc, argv, "dataSet concepts 0 featureDef", 4))
00235     {
00236         // setup the application
00237         config.RetrieveOptions(options);
00238         Core::ApplicationFactory factory(&options);
00239         DataFactory* dataFactory = factory.MakeDataFactory();
00240         if(!dataFactory->CanWriteCodebook())
00241         {
00242             ILOG_WARNING("codebook already exists; skipping...");
00243             delete dataFactory;
00244             return 0;
00245         }
00246         
00247         // get the quids of the classes
00248         SelectionTable quidSampling(0);
00249         int nrClasses;
00250         SelectQuids(quidSampling, nrClasses, dataFactory);
00251         Feature::AnnotatedFeatureTable featureSampling
00252             (Vector::ColumnVectorSet(true, 64, 0), Column::ColumnInt32(0));
00253         
00254         // get the features
00255         GetFeatures(featureSampling, quidSampling, dataFactory, config.surfParams);
00256         ILOG_INFO("got "<< featureSampling.Size() <<" samples for random forest");
00257         
00258         // make the forest
00259         Feature::RandomTreeTable forest(0);
00260         for(int i=0 ; i<4 ; ++i)
00261         {
00262             Util::SetRandomSeed(i);
00263             Feature::RandomTree* tree = Feature::MakeRandomTree
00264                 (&featureSampling, nrClasses, config.depth, config.tries);
00265             Write(tree, &forest);
00266             if(config.dumpTreeCounts)
00267                 DumpTreeCounts(tree, i);
00268             if(config.dumpTree)
00269                 DumpTree(tree);
00270             delete tree;
00271         }
00272 
00273         if(forest.Size() > 0)
00274         {
00275             // save results to disk
00276             dataFactory->WriteRandomForest(&forest);
00277             Feature::FeatureTable* ft = Feature::MakeFeatureTable(&forest);
00278             if(config.projectAndDump)
00279                 ProjectAndDump(ft, featureSampling);
00280             dataFactory->WriteCodebook(ft);
00281             ILOG_INFO("saved code books");
00282             delete ft;
00283         }
00284         delete dataFactory;
00285     }
00286     return 0;
00287 }
00288 
00289 
00290 } // namespace
00291 
00292 int
00293 main(int argc, char* argv[])
00294 {
00295     return RandomForest(argc, argv);
00296 }
00297 

Generated on Fri Mar 19 09:30:29 2010 for ImpalaSrc by  doxygen 1.5.1