00001 #ifndef Impala_Core_Feature_RandomTree_h
00002 #define Impala_Core_Feature_RandomTree_h
00003
00004 #include "Core/Vector/Types.h"
00005 #include "Core/Feature/RandomTreeTable.h"
00006
00007 namespace Impala
00008 {
00009 namespace Core
00010 {
00011 namespace Feature
00012 {
00013
00017 class RandomTree
00018 {
00019 public:
00020
00021 RandomTree(int codeword)
00022 {
00023 if(codeword < 0)
00024 ILOG_ERROR("illegal codeword in RandomTree node c'tor");
00025 mCodeWord = codeword;
00026 mSplitDimension = -1;
00027 mSplitValue = -1;
00028 mLeft = 0;
00029 mRight = 0;
00030 mCount = -1;
00031 }
00032
00033 RandomTree(int codeword, int count)
00034 {
00035 if(codeword < 0)
00036 ILOG_ERROR("illegal codeword in RandomTree node c'tor");
00037 mCodeWord = codeword;
00038 mSplitDimension = -1;
00039 mSplitValue = -1;
00040 mLeft = 0;
00041 mRight = 0;
00042 mCount = count;
00043 }
00044
00045 RandomTree(int splitDimension, double splitValue,
00046 RandomTree* left, RandomTree* right)
00047 {
00048 mCodeWord = -1;
00049 mSplitDimension = splitDimension;
00050 mSplitValue = splitValue;
00051 mLeft = left;
00052 mRight = right;
00053 mCount = -2;
00054 }
00055
00056 ~RandomTree()
00057 {
00058 if(mCodeWord == -1)
00059 {
00060 delete mLeft;
00061 delete mRight;
00062 }
00063 }
00064
00065 int
00066 GetCodeWord(const Vector::VectorReal64& v)
00067 {
00068 ILOG_DEBUG("GetCodeWord mCW="<<mCodeWord<<" mD="<<mSplitDimension<<
00069 " mV="<<mSplitValue);
00070 if(mCodeWord >= 0)
00071 return mCodeWord;
00072 ILOG_DEBUG("left="<<(int*)mLeft<<" right="<<(int*)mRight);
00073 if(v[mSplitDimension] < mSplitValue)
00074 return mLeft->GetCodeWord(v);
00075 return mRight->GetCodeWord(v);
00076 }
00077
00078 void
00079 GetSplit(int& dimension, double& value)
00080 {
00081 dimension = mSplitDimension;
00082 value = mSplitValue;
00083 }
00084
00085 void
00086 Dump(std::ostream& os, int nrNodes)
00087 {
00088 Dump(os, 0, nrNodes);
00089 }
00090
00091 void
00092 DumpCount(std::ostream& os)
00093 {
00094 if(mCodeWord >= 0)
00095 os << mCodeWord <<": "<< mCount << "\n";
00096 else
00097 {
00098 mLeft->DumpCount(os);
00099 mRight->DumpCount(os);
00100 }
00101 }
00102
00103 bool
00104 operator ==(const RandomTree& that)
00105 {
00106 if(mCodeWord != that.mCodeWord)
00107 return false;
00108 if(mCodeWord >= 0)
00109 return true;
00110 return mSplitValue == that.mSplitValue &&
00111 mSplitDimension == that.mSplitDimension &&
00112 *mLeft == *(that.mLeft) &&
00113 *mRight == *(that.mRight);
00114 }
00115
00116 bool
00117 operator !=(const RandomTree& that)
00118 {
00119 return ! operator ==(that);
00120 }
00121
00122
00123 private:
00124 void
00125 Dump(std::ostream& os, int level, int& nrNodes)
00126 {
00127 if(nrNodes == 0)
00128 return;
00129 --nrNodes;
00130 for(int i=0 ; i<level ; ++i)
00131 os << " ";
00132 if(mCodeWord >= 0)
00133 os <<"w:"<< mCodeWord << std::endl;
00134 else
00135 {
00136 os <<"d:"<< mSplitDimension <<" v:"<< mSplitValue << std::endl;
00137 mLeft->Dump(os, level+1, nrNodes);
00138 mRight->Dump(os, level+1, nrNodes);
00139 }
00140 }
00141
00142
00143 int mCodeWord;
00144 int mCount;
00145
00146 int mSplitDimension;
00147 double mSplitValue;
00148 RandomTree* mLeft;
00149 RandomTree* mRight;
00150 friend void Write(RandomTree*, RandomTreeTable*);
00151 ILOG_VAR_DECL;
00152 };
00153 ILOG_VAR_INIT(RandomTree, Impala.Core.Feature);
00154
00155
00157 void
00158 Write(RandomTree* tree, RandomTreeTable* table)
00159 {
00160 table->Add(tree->mCodeWord, tree->mSplitDimension, tree->mSplitValue);
00161 if(tree->mCodeWord == -1)
00162 {
00163 Write(tree->mLeft, table);
00164 Write(tree->mRight, table);
00165 }
00166 }
00167
00173 RandomTree*
00174 Read(RandomTreeTable* table, int& index)
00175 {
00176 ILOG_VAR(Impala.Core.Feature.RandomTree.Read);
00177 if(index > table->Size())
00178 ILOG_ERROR("parse error in Read(RandomTreeTable*,int)");
00179 int word = table->Get1(index);
00180 int dimension = table->Get2(index);
00181 double value = table->Get3(index);
00182 ++index;
00183 if(word>=0)
00184 return new RandomTree(word);
00185 else
00186 {
00187 RandomTree* l = Read(table, index);
00188 RandomTree* r = Read(table, index);
00189 return new RandomTree(dimension, value, l, r);
00190 }
00191 }
00192
00193 }
00194 }
00195 }
00196
00197 #endif