00001 #include "Basis/ConfigBase.h"
00002 #include "Core/ApplicationFactory.h"
00003 #include "Core/Array/RGB2Intensity.h"
00004 #include "Core/Array/PixMax.h"
00005 #include "Core/Array/MulVal.h"
00006 #include "Core/Feature/Surf.h"
00007 #include "Core/Feature/MakeRandomTree.h"
00008 #include "Core/Feature/RandomForest.h"
00009 #include "Core/Column/Types.h"
00010 #include <fstream>
00011
00012 namespace
00013 {
00014
00015 using namespace Impala;
00016 using namespace Core;
00017
00018 typedef Table::TableTem<Column::ColumnQuid,
00019 Column::ColumnInt32> SelectionTable;
00020
00021 const int cNrImages = 50;
00022 const int cPointPerImage = 250;
00023
00024
00025
00026 void
00027 PickRandomSamples(SelectionTable& quids, Table::QuidTable* src, int classId)
00028 {
00029 Util::QuasiRandomSequenceIterator it(src->Size(), 0);
00030 for(int i=0 ; i<cNrImages ; ++i)
00031 {
00032 quids.Add(src->Get1(*it), classId);
00033 ++it;
00034 }
00035 }
00036
00037 void
00038 SelectQuids(SelectionTable& quids, int& nrClasses, DataFactory* factory)
00039 {
00040 ILOG_VAR(main);
00041 std::vector<String> concepts = factory->MakeConceptList();
00042 if(concepts.size() == 0)
00043 ILOG_ERROR("couldn't open concept list");
00044 nrClasses = concepts.size();
00045 for(int i=0 ; i<nrClasses ; ++i)
00046 {
00047 Table::AnnotationTable* anno = factory->MakeAnnotation(concepts[i]);
00048 Table::QuidTable* pos = anno->GetPositive();
00049 ILOG_INFO(pos->Size() << " positive annotations");
00050 if(pos->Size() < cNrImages)
00051 {
00052 ILOG_WARN("not enough positive examples for concept "<< concepts[i]);
00053 continue;
00054 }
00055 PickRandomSamples(quids, pos, i);
00056 }
00057
00058 ILOG_INFO(quids.Size() << " quids found");
00059 }
00060
00061 void
00062 GetNSamples(Feature::AnnotatedFeatureTable& featureSampling,
00063 const Geometry::InterestPointList& points, int n, int classId)
00064 {
00065 Geometry::InterestPointList::const_iterator itPoint = points.begin();
00066 std::set<int> indices = Util::RandomUniqueNumbers(n, points.size());
00067 int index=0;
00068 for(std::set<int>::iterator itIndex = indices.begin() ;
00069 itIndex != indices.end() ; ++itIndex)
00070 {
00071 while(index<*itIndex)
00072 {
00073 ++itPoint;
00074 ++index;
00075 }
00076 std::vector<Real64>& v = (*itPoint)->mDescriptor;
00077 Vector::VectorTem<double> vector(v.size(), &v[0], true);
00078 featureSampling.Add(vector, classId);
00079 }
00080 }
00081
00082 void
00083 GetFeatures(Feature::AnnotatedFeatureTable& featureSampling,
00084 const SelectionTable& quidSampling, DataFactory* factory,
00085 int dSurfParams[3])
00086 {
00087 ILOG_FUNCTION(GetFeatures);
00088 for(int i=0 ; i<quidSampling.Size() ; ++i)
00089 {
00090 ILOG_PROGRESS(i <<" of "<< quidSampling.Size() <<" processed", 4.);
00091 Quid q = quidSampling.Get1(i);
00092 Array::Array2dVec3UInt8* image = factory->MakeImage(q);
00093 if(image == 0)
00094 ILOG_ERROR("couldn't get image of quid "<< QuidObj(q));
00095 Geometry::InterestPointList pointList;
00096 String descriptor = factory->GetFeatureDefinition().GetName();
00097 ILOG_DEBUG("d="<<descriptor);
00098 descriptor = descriptor.substr(descriptor.find("-")+1);
00099 descriptor = StringReplace(descriptor, "surf", "");
00100 ILOG_DEBUG("d="<<descriptor);
00101 Feature::CalculateSurfDescriptors
00102 (image, pointList, descriptor,
00103 dSurfParams[0], dSurfParams[1], dSurfParams[2]);
00104 int classId = quidSampling.Get2(i);
00105 ILOG_DEBUG("point list size="<< pointList.size());
00106 GetNSamples(featureSampling, pointList, cPointPerImage, classId);
00107 pointList.EraseAndDelete();
00108 delete image;
00109 }
00110 ILOG_PROGRESS_DONE("all processed");
00111 }
00112
00113 void
00114 DumpDescriptors(Feature::AnnotatedFeatureTable& featureSampling)
00115 {
00116 std::ofstream ofs("descriptordump.txt");
00117 if(ofs.is_open())
00118 {
00119 Feature::Dump(&featureSampling, ofs);
00120 ofs.close();
00121 }
00122 }
00123
00124 void
00125 DumpTree(Feature::RandomTree* tree)
00126 {
00127 ILOG_FUNCTION(main);
00128 std::ofstream ofs("treedump.txt");
00129 if(ofs.is_open())
00130 {
00131 tree->Dump(ofs, 10000);
00132 ofs.close();
00133 }
00134 else
00135 ILOG_ERROR("couldn't dump trees");
00136 }
00137
00138 void
00139 DumpTreeCounts(Feature::RandomTree* tree, int i)
00140 {
00141 ILOG_FUNCTION(main);
00142 std::ofstream ofs(("countdump"+MakeString(i)+".txt").c_str());
00143 if(ofs.is_open())
00144 {
00145 tree->DumpCount(ofs);
00146 ofs.close();
00147 }
00148 else
00149 ILOG_ERROR("couldn't dump tree counts");
00150 }
00151
00152 void
00153 ProjectAndDump(Feature::FeatureTable* codebook,
00154 Feature::AnnotatedFeatureTable& featureSampling)
00155 {
00156 ILOG_FUNCTION(main);
00157 int codebookLength = GetCodebookLength(codebook);
00158 int* hist = new int[codebookLength];
00159 for(int i=0 ; i<codebookLength ; ++i)
00160 hist[i] = 0;
00161 Feature::RandomForest forest = ReadRandomForest(codebook);
00162 for(int i=0; i<featureSampling.Size(); ++i)
00163 {
00164 for(int f=0; f<forest.size(); ++f)
00165 {
00166 Feature::RandomTree* tree = forest[f];
00167 int codeword = tree->GetCodeWord(featureSampling.Get1(i));
00168 ++hist[codeword];
00169 }
00170 }
00171 DeleteForest(forest);
00172 std::ofstream ofs("projectiondump.txt");
00173 if(ofs.is_open())
00174 {
00175 for(int i=0 ; i<codebookLength ; ++i)
00176 ofs << i <<": #"<< hist[i] <<"\n";
00177 ofs.close();
00178 }
00179 else
00180 ILOG_ERROR("couldn't dump projection");
00181 delete hist;
00182 }
00183
00184
00185 class RandomForestConfig : public ConfigBase
00186 {
00187 public:
00188 RandomForestConfig() :
00189 depth(10), tries(32), dumpTree(false), dumpTreeCounts(false),
00190 projectAndDump(false)
00191 {
00192
00193 surfParams[0] = 3;
00194 surfParams[1] = 2;
00195 surfParams[2] = 4;
00196 }
00197
00198 void InitOptions(CmdOptions& co)
00199 {
00200 co.AddOption(0, "dumpTree", "", "0");
00201 co.AddOption(0, "dumpTreeCounts", "", "0");
00202 co.AddOption(0, "projectAndDump", "", "0");
00203 co.AddOption(0, "forest-tries", "int", "32");
00204 co.AddOption(0, "forest-depth", "int", "10");
00205 Feature::AddDSurfOptions(co);
00206 }
00207
00208 void RetrieveOptions(CmdOptions& co)
00209 {
00210 depth = co.GetInt("forest-depth", depth);
00211 tries = co.GetInt("forest-tries", tries);
00212 dumpTree = co.GetBool("dumpTree");
00213 dumpTreeCounts = co.GetBool("dumpTreeCounts");
00214 projectAndDump = co.GetBool("projectAndDump");
00215 Feature::GetDSurfOptions(co, surfParams[0], surfParams[1], surfParams[2]);
00216 }
00217
00218 private:
00219 int surfParams[3];
00220 int depth;
00221 int tries;
00222 bool dumpTree, dumpTreeCounts, projectAndDump;
00223
00224 friend int RandomForest(int argc, char** argv);
00225 };
00226
00227 int RandomForest(int argc, char** argv)
00228 {
00229 ILOG_FUNCTION(main);
00230 RandomForestConfig config;
00231 CmdOptions& options = CmdOptions::GetInstance();
00232 options.Initialise(false, false, true);
00233 config.InitOptions(options);
00234 if (options.ParseArgs(argc, argv, "dataSet concepts 0 featureDef", 4))
00235 {
00236
00237 config.RetrieveOptions(options);
00238 Core::ApplicationFactory factory(&options);
00239 DataFactory* dataFactory = factory.MakeDataFactory();
00240 if(!dataFactory->CanWriteCodebook())
00241 {
00242 ILOG_WARNING("codebook already exists; skipping...");
00243 delete dataFactory;
00244 return 0;
00245 }
00246
00247
00248 SelectionTable quidSampling(0);
00249 int nrClasses;
00250 SelectQuids(quidSampling, nrClasses, dataFactory);
00251 Feature::AnnotatedFeatureTable featureSampling
00252 (Vector::ColumnVectorSet(true, 64, 0), Column::ColumnInt32(0));
00253
00254
00255 GetFeatures(featureSampling, quidSampling, dataFactory, config.surfParams);
00256 ILOG_INFO("got "<< featureSampling.Size() <<" samples for random forest");
00257
00258
00259 Feature::RandomTreeTable forest(0);
00260 for(int i=0 ; i<4 ; ++i)
00261 {
00262 Util::SetRandomSeed(i);
00263 Feature::RandomTree* tree = Feature::MakeRandomTree
00264 (&featureSampling, nrClasses, config.depth, config.tries);
00265 Write(tree, &forest);
00266 if(config.dumpTreeCounts)
00267 DumpTreeCounts(tree, i);
00268 if(config.dumpTree)
00269 DumpTree(tree);
00270 delete tree;
00271 }
00272
00273 if(forest.Size() > 0)
00274 {
00275
00276 dataFactory->WriteRandomForest(&forest);
00277 Feature::FeatureTable* ft = Feature::MakeFeatureTable(&forest);
00278 if(config.projectAndDump)
00279 ProjectAndDump(ft, featureSampling);
00280 dataFactory->WriteCodebook(ft);
00281 ILOG_INFO("saved code books");
00282 delete ft;
00283 }
00284 delete dataFactory;
00285 }
00286 return 0;
00287 }
00288
00289
00290 }
00291
00292 int
00293 main(int argc, char* argv[])
00294 {
00295 return RandomForest(argc, argv);
00296 }
00297