00001 #ifndef Impala_Core_Feature_MakeRandomTree_h
00002 #define Impala_Core_Feature_MakeRandomTree_h
00003
00004 #include "Core/Histogram/Entropy.h"
00005 #include "Core/Feature/RandomTree.h"
00006 #include "Util/FilterOperations.h"
00007 #include "Util/MakeFilter.h"
00008 #include "Util/Random.h"
00009
00010 namespace Impala
00011 {
00012 namespace Core
00013 {
00014 namespace Feature
00015 {
00016
// Two-column table type used throughout this header: column 1 is a
// Vector::ColumnVectorSet (read per row via Get1() as a feature vector) and
// column 2 is a Column::ColumnInt32 (read per row via Get2(); MakeHistogram
// below feeds that value into a per-class histogram).
00017 typedef Table::TableTem<Vector::ColumnVectorSet,
00018 Column::ColumnInt32> AnnotatedFeatureTable;
00019
00020 void
00021 Dump(AnnotatedFeatureTable* t, std::ostream& os)
00022 {
00023 for(int i=0 ; i<t->Size() ; ++i)
00024 {
00025 os << t->Get2(i) <<" ";
00026 Vector::VectorReal64 v = t->Get1(i);
00027 for(int i=0 ; i<v.Size() ; ++i)
00028 os << v[i] <<" ";
00029 os << "\n";
00030 }
00031 }
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00051 Histogram::Histogram1dTem<int>*
00052 MakeHistogram(const AnnotatedFeatureTable* data, int nrClasses, bool* filter)
00053 {
00054 int outLyers;
00055 Histogram::Histogram1dTem<int>* hist =
00056 new Histogram::Histogram1dTem<int>(-0.5, nrClasses+0.5, nrClasses,
00057 &outLyers);
00058 String a=data->GetInfo();
00059 for(int y=0 ; y<data->Size() ; ++y)
00060 if(filter[y])
00061 hist->AddWeight(data->Get2(y), 1.);
00062 return hist;
00063 }
00064
00067 double
00068 Gain(const AnnotatedFeatureTable* data, int nrClasses, bool* left, bool* right)
00069 {
00070 int size = data->Size();
00071 double nrL = Util::Count(left, size);
00072 double nrR = Util::Count(right, size);
00073 double total = nrL+nrR;
00074 if(total==0)
00075 return 0;
00076 Histogram::Histogram1dTem<int>* histL = MakeHistogram(data, nrClasses, left);
00077 Histogram::Histogram1dTem<int>* histR = MakeHistogram(data, nrClasses, right);
00078 double gain = - (nrL / total) * Entropy(histL) - (nrR / total) * Entropy(histR);
00079 delete histL;
00080 delete histR;
00081 return gain;
00082 }
00083
00089 void
00090 SplitSet(bool*& left, bool*& right, int dimension, double value,
00091 const AnnotatedFeatureTable* data, bool* filter)
00092 {
00093 int size = data->Size();
00094 bool* less = Util::MakeFilterElemLess(data->GetColumn1()->GetStorage(),
00095 dimension, value);
00096 left = Util::MakeIntersection(filter, less, size);
00097 right = Util::MakeRelativeComplement(filter, left, size);
00098 delete less;
00099 }
00100
00104 void
00105 TryRandomSplit(int& dimension, double& value, double& gain,
00106 const AnnotatedFeatureTable* data, bool* filter, int nrClasses)
00107 {
00108 dimension = Util::RandomInt(data->GetColumn1()->GetVectorLength(0));
00109 int n = Util::Count(filter, data->Size());
00110 int index;
00111 if(n==0)
00112 index = Util::RandomInt(data->Size());
00113 else
00114 index = Util::IndexOfNth(filter, data->Size(), Util::RandomInt(n));
00115 value = data->Get1(index)[dimension];
00116 bool* left = 0;
00117 bool* right = 0;
00118 SplitSet(left, right, dimension, value, data, filter);
00119 gain = Gain(data, nrClasses, left, right);
00120 delete left;
00121 delete right;
00122 }
00123
00130 void
00131 FindSplit(int& dimension, double& value, const AnnotatedFeatureTable* data,
00132 bool* filter, int nrClasses, int nrTrials)
00133 {
00134 int bestDim;
00135 double bestVal;
00136 double bestGain;
00137 TryRandomSplit(bestDim, bestVal, bestGain, data, filter, nrClasses);
00138 for(int i=1 ; i<nrTrials ; ++i)
00139 {
00140 int dim;
00141 double val;
00142 double gain;
00143 TryRandomSplit(dim, val, gain, data, filter, nrClasses);
00144 if(gain > bestGain)
00145 {
00146 bestGain = gain;
00147 bestVal = val;
00148 bestDim = dim;
00149 }
00150 }
00151 dimension = bestDim;
00152 value = bestVal;
00153 }
00154
/**
 * Hands out consecutive code words: 0 on the first call, then 1, 2, ...
 *
 * The counter is a function-local static, so the numbering continues across
 * all callers for the lifetime of the program. Increments are not
 * synchronised.
 *
 * Declared inline because the definition lives in a header; an inline
 * function's local static is a single object shared by all translation
 * units, so the numbering stays program-wide.
 *
 * @return the next unused code word
 */
inline int
GetCodeWord()
{
    static int sCodeWord = -1;
    return ++sCodeWord;
}
00164
00168 Feature::RandomTree*
00169 MakeRandomTree(const AnnotatedFeatureTable* data, bool* filter,
00170 int nrClasses, int maxDepth, int nrTrials)
00171 {
00172 ILOG_FUNCTION(MakeRandomTree);
00173 if(maxDepth <= 0)
00174 {
00175 int code = GetCodeWord();
00176 ILOG_PROGRESS("code word "<< code, 4.);
00177 return new RandomTree(code, Util::Count(filter, data->Size()));
00178 }
00179
00180
00181 int dimension;
00182 double value;
00183 FindSplit(dimension, value, data, filter, nrClasses, nrTrials);
00184
00185 bool* leftFilter=0;
00186 bool* rightFilter=0;
00187 SplitSet(leftFilter, rightFilter, dimension, value, data, filter);
00188
00189 RandomTree* left =
00190 MakeRandomTree(data, leftFilter, nrClasses, maxDepth-1, nrTrials);
00191 RandomTree* right =
00192 MakeRandomTree(data, rightFilter, nrClasses, maxDepth-1, nrTrials);
00193 delete leftFilter;
00194 delete rightFilter;
00195 return new RandomTree(dimension, value, left, right);
00196 }
00197
00198
00199 RandomTree*
00200 MakeRandomTree(const AnnotatedFeatureTable* data,
00201 int nrClasses, int maxDepth, int nrTrials)
00202 {
00203 int size = data->Size();
00204 bool* filter = new bool[size];
00205 for(int i=0 ; i<size ; ++i)
00206 {
00207 filter[i] = true;
00208 }
00209 RandomTree* tree = MakeRandomTree(data, filter, nrClasses, maxDepth, nrTrials);
00210 delete filter;
00211 return tree;
00212 }
00213
00214 }
00215 }
00216 }
00217
00218 #endif