00001 #ifndef Impala_Core_Feature_Test_TestMakeRandomTree_h
00002 #define Impala_Core_Feature_Test_TestMakeRandomTree_h
00003
00004 #include <cppunit/extensions/HelperMacros.h>
00005 #include <algorithm>
00006
00007 #include "Core/Feature/MakeRandomTree.h"
00008
00009 namespace Impala
00010 {
00011 namespace Core
00012 {
00013 namespace Feature
00014 {
00015
00016 class TestMakeRandomTree : public CPPUNIT_NS::TestFixture
00017 {
00018 CPPUNIT_TEST_SUITE(TestMakeRandomTree);
00019 CPPUNIT_TEST(testMakeHistogram);
00020 CPPUNIT_TEST(testGain);
00021 CPPUNIT_TEST(testSplitSet);
00022 CPPUNIT_TEST(testTryRandomSplit);
00023 CPPUNIT_TEST(testFindSplit);
00024 CPPUNIT_TEST(testTopLevel);
00025 CPPUNIT_TEST(testReferenceMatlab);
00026 CPPUNIT_TEST_SUITE_END();
00027
00028 public:
00029
00030 void
00031 setUp()
00032 {
00033 mNumberClasses = 4;
00034 mFakeData = new AnnotatedFeatureTable
00035 (Vector::ColumnVectorSet(true, 3, 0), Column::ColumnInt32(0));
00036 mFakeData->Add(Vector::VectorReal64(0.00,-2.00, 6.00), 0);
00037 mFakeData->Add(Vector::VectorReal64(0.01,-2.01, 6.01), 0);
00038 mFakeData->Add(Vector::VectorReal64(0.02,-2.02, 6.02), 0);
00039 mFakeData->Add(Vector::VectorReal64(1.03, 0.03, 1.03), 1);
00040 mFakeData->Add(Vector::VectorReal64(2.04, 1.04, 2.04), 1);
00041 mFakeData->Add(Vector::VectorReal64(3.05, 0.05, 3.05), 1);
00042 mFakeData->Add(Vector::VectorReal64(3.06, 1.06, 1.06), 2);
00043 mFakeData->Add(Vector::VectorReal64(2.07, 0.07, 2.07), 2);
00044 mFakeData->Add(Vector::VectorReal64(1.08, 1.08, 3.08), 2);
00045 mFakeData->Add(Vector::VectorReal64(2.09, 0.09, 6.09), 3);
00046 mFakeData->Add(Vector::VectorReal64(2.10, 1.10, 6.10), 3);
00047 mFakeData->Add(Vector::VectorReal64(2.11, 0.11, 6.11), 3);
00048
00049 for(int i=0 ; i<12 ; ++i)
00050 {
00051 mFilterAll[i] = true;
00052 mFilterHalf[i] = i<6;
00053 mFilterHalfComp[i] = i>=6;
00054 mFilterQuarter[i] = i<3;
00055 mFilterQuarterComp[i] = i>=3;
00056 mFilterOdd[i] = i%2 == 0;
00057 mFilterEven[i] = i%2 == 1;
00058 }
00059 }
00060
00061 void
00062 tearDown()
00063 {
00064 delete mFakeData;
00065 }
00066
00067 void
00068 testMakeHistogram()
00069 {
00070 Histogram::Histogram1dTem<int>* histAll =
00071 MakeHistogram(mFakeData, mNumberClasses, mFilterAll);
00072 Histogram::Histogram1dTem<int>* histHalf =
00073 MakeHistogram(mFakeData, mNumberClasses, mFilterHalf);
00074 Histogram::Histogram1dTem<int>* histOdd =
00075 MakeHistogram(mFakeData, mNumberClasses, mFilterOdd);
00076 CPPUNIT_ASSERT_EQUAL(mNumberClasses, histAll->Size());
00077 CPPUNIT_ASSERT_EQUAL(mNumberClasses, histHalf->Size());
00078 CPPUNIT_ASSERT_EQUAL(mNumberClasses, histOdd->Size());
00079
00080
00081 for(int i=0 ; i<4 ; ++i)
00082 {
00083 CPPUNIT_ASSERT_EQUAL(3, histAll->Elem(i));
00084 CPPUNIT_ASSERT_EQUAL((i<2) ? 3 : 0, histHalf->Elem(i));
00085 CPPUNIT_ASSERT_EQUAL((i%2) ? 1 : 2, histOdd->Elem(i));
00086 }
00087 delete histAll;
00088 delete histHalf;
00089 delete histOdd;
00090 }
00091
00092 void
00093 testGain()
00094 {
00095 double gainHalf = Gain(mFakeData, mNumberClasses,
00096 mFilterHalf, mFilterHalfComp);
00097 double gainQuarter = Gain(mFakeData, mNumberClasses,
00098 mFilterQuarter, mFilterQuarterComp);
00099 double gainOdd = Gain(mFakeData, mNumberClasses,
00100 mFilterOdd, mFilterEven);
00101 CPPUNIT_ASSERT(gainHalf > gainQuarter);
00102 CPPUNIT_ASSERT(gainQuarter > gainOdd);
00103 }
00104
00105 void
00106 testSplitSet()
00107 {
00108 bool* left;
00109 bool* right;
00110 SplitSet(left, right, 1, -1., mFakeData, mFilterAll);
00111 CPPUNIT_ASSERT(Util::Equal(left, mFilterQuarter, 12));
00112 CPPUNIT_ASSERT(Util::Equal(right, mFilterQuarterComp, 12));
00113 delete left;
00114 delete right;
00115 SplitSet(left, right, 2, -3., mFakeData, mFilterOdd);
00116 CPPUNIT_ASSERT_EQUAL(0, Util::Count(left, 12));
00117 CPPUNIT_ASSERT(Util::Equal(right, mFilterOdd, 12));
00118 }
00119
00120 void
00121 testTryRandomSplit()
00122 {
00123 int dim;
00124 double val;
00125 double gain;
00126 bool filter[12] = {false};
00127 filter[3] = true;
00128 TryRandomSplit(dim, val, gain, mFakeData, filter, mNumberClasses);
00129 CPPUNIT_ASSERT_EQUAL(mFakeData->Get1(3)[dim], val);
00130 filter[3] = false;
00131 filter[7] = true;
00132 TryRandomSplit(dim, val, gain, mFakeData, filter, mNumberClasses);
00133 CPPUNIT_ASSERT_EQUAL(mFakeData->Get1(7)[dim], val);
00134 }
00135
00136 void
00137 testFindSplitWithSeed(int seed)
00138 {
00139 Util::SetRandomSeed(seed);
00140
00141 int dim;
00142 double val;
00143 double gain;
00144 TryRandomSplit(dim, val, gain, mFakeData, mFilterAll, mNumberClasses);
00145 Util::SetRandomSeed(seed);
00146 FindSplit(dim, val, mFakeData, mFilterAll, mNumberClasses, 20);
00147 bool* left;
00148 bool* right;
00149 SplitSet(left, right, dim, val, mFakeData, mFilterAll);
00150
00151
00152 CPPUNIT_ASSERT(gain <= Gain(mFakeData, mNumberClasses, left, right));
00153 }
00154
00155 void
00156 testFindSplit()
00157 {
00158 testFindSplitWithSeed(1);
00159 testFindSplitWithSeed(2);
00160 testFindSplitWithSeed(3);
00161 }
00162
00163 void
00164 testTopLevel()
00165 {
00166 RandomTree* tree = MakeRandomTree(mFakeData, mNumberClasses, 2, 1000);
00167
00168 int dim;
00169 double val;
00170 tree->GetSplit(dim, val);
00171 ILOG_DEBUG("dim:"<< dim <<" val:"<< val);
00172 CPPUNIT_ASSERT_EQUAL(2, dim);
00173 CPPUNIT_ASSERT_EQUAL(6.00, val);
00174 }
00175
00176 void
00177 testReferenceMatlab()
00178 {
00179 for(int i=0 ; i<3 ; ++i)
00180 ReferenceTest();
00181 }
00182
00183 void
00184 ReferenceTest()
00185 {
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215 AnnotatedFeatureTable data(Vector::ColumnVectorSet(true, 2, 0),
00216 Column::ColumnInt32(0));
00217 data.Add(Vector::VectorReal64(1, 1), 1);
00218 data.Add(Vector::VectorReal64(1, 1), 1);
00219 data.Add(Vector::VectorReal64(0, 0), 2);
00220 data.Add(Vector::VectorReal64(0, 0), 2);
00221 data.Add(Vector::VectorReal64(0, 0), 2);
00222 data.Add(Vector::VectorReal64(0, 1), 3);
00223 data.Add(Vector::VectorReal64(0, 1), 3);
00224 data.Add(Vector::VectorReal64(0, 1), 3);
00225 data.Add(Vector::VectorReal64(0, 1), 3);
00226 data.Add(Vector::VectorReal64(1, 0), 4);
00227 data.Add(Vector::VectorReal64(1, 0), 4);
00228 data.Add(Vector::VectorReal64(1, 0), 4);
00229 data.Add(Vector::VectorReal64(1, 0), 4);
00230 data.Add(Vector::VectorReal64(1, 0), 4);
00231 RandomTree* tree = MakeRandomTree(&data, 5, 2, 25);
00232 }
00233
00234 private:
00235 int mNumberClasses;
00236 AnnotatedFeatureTable* mFakeData;
00237 bool mFilterAll[12];
00238 bool mFilterHalf[12];
00239 bool mFilterHalfComp[12];
00240 bool mFilterQuarter[12];
00241 bool mFilterQuarterComp[12];
00242 bool mFilterOdd[12];
00243 bool mFilterEven[12];
00244 ILOG_CLASS;
00245 };
00246
// Presumably defines the logger that ILOG_CLASS declares inside the fixture,
// under the "Impala.Core.Feature" category.
ILOG_CLASS_INIT(TestMakeRandomTree, Impala.Core.Feature);

// Adds the fixture's suite to CppUnit's default registry so a test runner
// picks it up without explicit wiring.
// NOTE(review): both macros expand to namespace-scope definitions inside a
// header; including this header from more than one translation unit would
// presumably duplicate them -- keep it included from a single test TU.
CPPUNIT_TEST_SUITE_REGISTRATION( TestMakeRandomTree );
00250
00251 }
00252 }
00253 }
00254
00255 #endif
00256