00001 #ifndef Impala_Core_Feature_MakeRandomTree_h
00002 #define Impala_Core_Feature_MakeRandomTree_h
00003
00004 #include "Core/Histogram/Entropy.h"
00005 #include "Core/Feature/RandomTree.h"
00006 #include "Util/FilterOperations.h"
00007 #include "Util/MakeFilter.h"
00008 #include "Util/Random.h"
00009
00010 namespace Impala
00011 {
00012 namespace Core
00013 {
00014 namespace Feature
00015 {
00016
00017 typedef Table::TableTem<Vector::ColumnVectorSet,
00018 Column::ColumnInt32> AnnotatedFeatureTable;
00019
00020 void
00021 Dump(AnnotatedFeatureTable* t, std::ostream& os)
00022 {
00023 for(int i=0 ; i<t->Size() ; ++i)
00024 {
00025 os << t->Get2(i) <<" ";
00026 Vector::VectorReal64 v = t->Get1(i);
00027 for(int i=0 ; i<v.Size() ; ++i)
00028 os << v[i] <<" ";
00029 os << "\n";
00030 }
00031 }
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00051 Histogram::Histogram1dTem<int>*
00052 MakeHistogram(const AnnotatedFeatureTable* data, int nrClasses, bool* filter)
00053 {
00054 int outLyers;
00055 Histogram::Histogram1dTem<int>* hist =
00056 new Histogram::Histogram1dTem<int>(-0.5, nrClasses+0.5, nrClasses,
00057 &outLyers);
00058 String a=data->GetInfo();
00059 for(int y=0 ; y<data->Size() ; ++y)
00060 if(filter[y])
00061 hist->AddWeight(data->Get2(y), 1.);
00062 return hist;
00063 }
00064
00067 double
00068 Gain(const AnnotatedFeatureTable* data, int nrClasses, bool* left, bool* right)
00069 {
00070 int size = data->Size();
00071 double nrL = Util::Count(left, size);
00072 double nrR = Util::Count(right, size);
00073 double total = nrL+nrR;
00074 if(total==0)
00075 return 0;
00076 Histogram::Histogram1dTem<int>* histL = MakeHistogram(data, nrClasses, left);
00077 Histogram::Histogram1dTem<int>* histR = MakeHistogram(data, nrClasses, right);
00078 double gain = - (nrL / total) * Entropy(histL) - (nrR / total) * Entropy(histR);
00079 delete histL;
00080 delete histR;
00081 return gain;
00082 }
00083
00089 void
00090 SplitSet(bool*& left, bool*& right, int dimension, double value,
00091 const AnnotatedFeatureTable* data, bool* filter)
00092 {
00093 int size = data->Size();
00094 bool* less = Util::MakeFilterElemLess(data->GetColumn1()->GetStorage(),
00095 dimension, value);
00096 left = Util::MakeIntersection(filter, less, size);
00097 right = Util::MakeRelativeComplement(filter, left, size);
00098 delete less;
00099 }
00100
00104 void
00105 TryRandomSplit(int& dimension, double& value, double& gain,
00106 const AnnotatedFeatureTable* data, bool* filter, int nrClasses,
00107 Util::Random& rng)
00108 {
00109 dimension = rng.GetInt(data->GetColumn1()->GetVectorLength(0));
00110 int n = Util::Count(filter, data->Size());
00111 int index;
00112 if(n==0)
00113 index = rng.GetInt(data->Size());
00114 else
00115 index = Util::IndexOfNth(filter, data->Size(), rng.GetInt(n));
00116 value = data->Get1(index)[dimension];
00117 bool* left = 0;
00118 bool* right = 0;
00119 SplitSet(left, right, dimension, value, data, filter);
00120 gain = Gain(data, nrClasses, left, right);
00121 delete left;
00122 delete right;
00123 }
00124
00131 void
00132 FindSplit(int& dimension, double& value, const AnnotatedFeatureTable* data,
00133 bool* filter, int nrClasses, int nrTrials,
00134 Util::Random& rng)
00135 {
00136 int bestDim;
00137 double bestVal;
00138 double bestGain;
00139 TryRandomSplit(bestDim, bestVal, bestGain, data, filter, nrClasses, rng);
00140 for(int i=1 ; i<nrTrials ; ++i)
00141 {
00142 int dim;
00143 double val;
00144 double gain;
00145 TryRandomSplit(dim, val, gain, data, filter, nrClasses, rng);
00146 if(gain > bestGain)
00147 {
00148 bestGain = gain;
00149 bestVal = val;
00150 bestDim = dim;
00151 }
00152 }
00153 dimension = bestDim;
00154 value = bestVal;
00155 }
00156
/** Returns a fresh, process-wide unique code word: 0, 1, 2, ... on
    successive calls. Note: relies on a function-local static counter. */
int
GetCodeWord()
{
    static int sNextCode = 0;
    return sNextCode++;
}
00166
00170 Feature::RandomTree*
00171 MakeRandomTree(const AnnotatedFeatureTable* data, bool* filter,
00172 int nrClasses, int maxDepth, int nrTrials,
00173 Util::Random& rng)
00174 {
00175 ILOG_FUNCTION(MakeRandomTree);
00176 if(maxDepth <= 0)
00177 {
00178 int code = GetCodeWord();
00179 ILOG_PROGRESS("code word "<< code, 4.);
00180 return new RandomTree(code, Util::Count(filter, data->Size()));
00181 }
00182
00183
00184 int dimension;
00185 double value;
00186 FindSplit(dimension, value, data, filter, nrClasses, nrTrials, rng);
00187
00188 bool* leftFilter=0;
00189 bool* rightFilter=0;
00190 SplitSet(leftFilter, rightFilter, dimension, value, data, filter);
00191
00192 RandomTree* left =
00193 MakeRandomTree(data, leftFilter, nrClasses, maxDepth-1, nrTrials, rng);
00194 RandomTree* right =
00195 MakeRandomTree(data, rightFilter, nrClasses, maxDepth-1, nrTrials, rng);
00196 delete leftFilter;
00197 delete rightFilter;
00198 return new RandomTree(dimension, value, left, right);
00199 }
00200
00201
00202 RandomTree*
00203 MakeRandomTree(const AnnotatedFeatureTable* data,
00204 int nrClasses, int maxDepth, int nrTrials,
00205 Util::Random& rng)
00206 {
00207 int size = data->Size();
00208 bool* filter = new bool[size];
00209 for(int i=0 ; i<size ; ++i)
00210 {
00211 filter[i] = true;
00212 }
00213 RandomTree* tree = MakeRandomTree(data, filter, nrClasses, maxDepth, nrTrials, rng);
00214 delete filter;
00215 return tree;
00216 }
00217
00218 }
00219 }
00220 }
00221
00222 #endif