00001 #ifndef Impala_Core_Feature_Test_TestMakeRandomTree_h
00002 #define Impala_Core_Feature_Test_TestMakeRandomTree_h
00003
00004 #include <cppunit/extensions/HelperMacros.h>
00005 #include <algorithm>
00006
00007 #include "Core/Feature/MakeRandomTree.h"
00008
00009 namespace Impala
00010 {
00011 namespace Core
00012 {
00013 namespace Feature
00014 {
00015
00016 class TestMakeRandomTree : public CPPUNIT_NS::TestFixture
00017 {
00018 CPPUNIT_TEST_SUITE(TestMakeRandomTree);
00019 CPPUNIT_TEST(testMakeHistogram);
00020 CPPUNIT_TEST(testGain);
00021 CPPUNIT_TEST(testSplitSet);
00022 CPPUNIT_TEST(testTryRandomSplit);
00023 CPPUNIT_TEST(testFindSplit);
00024 CPPUNIT_TEST(testTopLevel);
00025 CPPUNIT_TEST(testReferenceMatlab);
00026 CPPUNIT_TEST_SUITE_END();
00027
00028 public:
00029
00030 void
00031 setUp()
00032 {
00033 mNumberClasses = 4;
00034 mFakeData = new AnnotatedFeatureTable
00035 (Vector::ColumnVectorSet(true, 3, 0), Column::ColumnInt32(0));
00036 mFakeData->Add(Vector::VectorReal64(0.00,-2.00, 6.00), 0);
00037 mFakeData->Add(Vector::VectorReal64(0.01,-2.01, 6.01), 0);
00038 mFakeData->Add(Vector::VectorReal64(0.02,-2.02, 6.02), 0);
00039 mFakeData->Add(Vector::VectorReal64(1.03, 0.03, 1.03), 1);
00040 mFakeData->Add(Vector::VectorReal64(2.04, 1.04, 2.04), 1);
00041 mFakeData->Add(Vector::VectorReal64(3.05, 0.05, 3.05), 1);
00042 mFakeData->Add(Vector::VectorReal64(3.06, 1.06, 1.06), 2);
00043 mFakeData->Add(Vector::VectorReal64(2.07, 0.07, 2.07), 2);
00044 mFakeData->Add(Vector::VectorReal64(1.08, 1.08, 3.08), 2);
00045 mFakeData->Add(Vector::VectorReal64(2.09, 0.09, 6.09), 3);
00046 mFakeData->Add(Vector::VectorReal64(2.10, 1.10, 6.10), 3);
00047 mFakeData->Add(Vector::VectorReal64(2.11, 0.11, 6.11), 3);
00048
00049 for(int i=0 ; i<12 ; ++i)
00050 {
00051 mFilterAll[i] = true;
00052 mFilterHalf[i] = i<6;
00053 mFilterHalfComp[i] = i>=6;
00054 mFilterQuarter[i] = i<3;
00055 mFilterQuarterComp[i] = i>=3;
00056 mFilterOdd[i] = i%2 == 0;
00057 mFilterEven[i] = i%2 == 1;
00058 }
00059 }
00060
00061 void
00062 tearDown()
00063 {
00064 delete mFakeData;
00065 }
00066
00067 void
00068 testMakeHistogram()
00069 {
00070 Histogram::Histogram1dTem<int>* histAll =
00071 MakeHistogram(mFakeData, mNumberClasses, mFilterAll);
00072 Histogram::Histogram1dTem<int>* histHalf =
00073 MakeHistogram(mFakeData, mNumberClasses, mFilterHalf);
00074 Histogram::Histogram1dTem<int>* histOdd =
00075 MakeHistogram(mFakeData, mNumberClasses, mFilterOdd);
00076 CPPUNIT_ASSERT_EQUAL(mNumberClasses, histAll->Size());
00077 CPPUNIT_ASSERT_EQUAL(mNumberClasses, histHalf->Size());
00078 CPPUNIT_ASSERT_EQUAL(mNumberClasses, histOdd->Size());
00079
00080
00081 for(int i=0 ; i<4 ; ++i)
00082 {
00083 CPPUNIT_ASSERT_EQUAL(3, histAll->Elem(i));
00084 CPPUNIT_ASSERT_EQUAL((i<2) ? 3 : 0, histHalf->Elem(i));
00085 CPPUNIT_ASSERT_EQUAL((i%2) ? 1 : 2, histOdd->Elem(i));
00086 }
00087 delete histAll;
00088 delete histHalf;
00089 delete histOdd;
00090 }
00091
00092 void
00093 testGain()
00094 {
00095 double gainHalf = Gain(mFakeData, mNumberClasses,
00096 mFilterHalf, mFilterHalfComp);
00097 double gainQuarter = Gain(mFakeData, mNumberClasses,
00098 mFilterQuarter, mFilterQuarterComp);
00099 double gainOdd = Gain(mFakeData, mNumberClasses,
00100 mFilterOdd, mFilterEven);
00101 CPPUNIT_ASSERT(gainHalf > gainQuarter);
00102 CPPUNIT_ASSERT(gainQuarter > gainOdd);
00103 }
00104
00105 void
00106 testSplitSet()
00107 {
00108 bool* left;
00109 bool* right;
00110 SplitSet(left, right, 1, -1., mFakeData, mFilterAll);
00111 CPPUNIT_ASSERT(Util::Equal(left, mFilterQuarter, 12));
00112 CPPUNIT_ASSERT(Util::Equal(right, mFilterQuarterComp, 12));
00113 delete left;
00114 delete right;
00115 SplitSet(left, right, 2, -3., mFakeData, mFilterOdd);
00116 CPPUNIT_ASSERT_EQUAL(0, Util::Count(left, 12));
00117 CPPUNIT_ASSERT(Util::Equal(right, mFilterOdd, 12));
00118 }
00119
00120 void
00121 testTryRandomSplit()
00122 {
00123 int dim;
00124 double val;
00125 double gain;
00126 bool filter[12] = {false};
00127 filter[3] = true;
00128 Util::Random rng;
00129 TryRandomSplit(dim, val, gain, mFakeData, filter, mNumberClasses, rng);
00130 CPPUNIT_ASSERT_EQUAL(mFakeData->Get1(3)[dim], val);
00131 filter[3] = false;
00132 filter[7] = true;
00133 TryRandomSplit(dim, val, gain, mFakeData, filter, mNumberClasses, rng);
00134 CPPUNIT_ASSERT_EQUAL(mFakeData->Get1(7)[dim], val);
00135 }
00136
00137 void
00138 testFindSplitWithSeed(int seed)
00139 {
00140
00141 int dim;
00142 double val;
00143 double gain;
00144 Util::Random rng;
00145 rng.SetSeed(seed);
00146 TryRandomSplit(dim, val, gain, mFakeData, mFilterAll, mNumberClasses, rng);
00147 rng.SetSeed(seed);
00148 FindSplit(dim, val, mFakeData, mFilterAll, mNumberClasses, 20, rng);
00149 bool* left;
00150 bool* right;
00151 SplitSet(left, right, dim, val, mFakeData, mFilterAll);
00152
00153
00154 CPPUNIT_ASSERT(gain <= Gain(mFakeData, mNumberClasses, left, right));
00155 }
00156
00157 void
00158 testFindSplit()
00159 {
00160 testFindSplitWithSeed(1);
00161 testFindSplitWithSeed(2);
00162 testFindSplitWithSeed(3);
00163 }
00164
00165 void
00166 testTopLevel()
00167 {
00168 Util::Random rng;
00169 RandomTree* tree = MakeRandomTree(mFakeData, mNumberClasses, 2, 1000, rng);
00170
00171 int dim;
00172 double val;
00173 tree->GetSplit(dim, val);
00174 ILOG_DEBUG("dim:"<< dim <<" val:"<< val);
00175 CPPUNIT_ASSERT_EQUAL(2, dim);
00176 CPPUNIT_ASSERT_EQUAL(6.00, val);
00177 }
00178
00179 void
00180 testReferenceMatlab()
00181 {
00182 for(int i=0 ; i<3 ; ++i)
00183 ReferenceTest();
00184 }
00185
00186 void
00187 ReferenceTest()
00188 {
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218 AnnotatedFeatureTable data(Vector::ColumnVectorSet(true, 2, 0),
00219 Column::ColumnInt32(0));
00220 data.Add(Vector::VectorReal64(1, 1), 1);
00221 data.Add(Vector::VectorReal64(1, 1), 1);
00222 data.Add(Vector::VectorReal64(0, 0), 2);
00223 data.Add(Vector::VectorReal64(0, 0), 2);
00224 data.Add(Vector::VectorReal64(0, 0), 2);
00225 data.Add(Vector::VectorReal64(0, 1), 3);
00226 data.Add(Vector::VectorReal64(0, 1), 3);
00227 data.Add(Vector::VectorReal64(0, 1), 3);
00228 data.Add(Vector::VectorReal64(0, 1), 3);
00229 data.Add(Vector::VectorReal64(1, 0), 4);
00230 data.Add(Vector::VectorReal64(1, 0), 4);
00231 data.Add(Vector::VectorReal64(1, 0), 4);
00232 data.Add(Vector::VectorReal64(1, 0), 4);
00233 data.Add(Vector::VectorReal64(1, 0), 4);
00234 RandomTree* tree = MakeRandomTree(&data, 5, 2, 25, Util::Random::GetGlobal());
00235 }
00236
00237 private:
00238 int mNumberClasses;
00239 AnnotatedFeatureTable* mFakeData;
00240 bool mFilterAll[12];
00241 bool mFilterHalf[12];
00242 bool mFilterHalfComp[12];
00243 bool mFilterQuarter[12];
00244 bool mFilterQuarterComp[12];
00245 bool mFilterOdd[12];
00246 bool mFilterEven[12];
00247 ILOG_CLASS;
00248 };
00249
// Presumably defines the logger that ILOG_CLASS declares inside
// TestMakeRandomTree, under the name Impala.Core.Feature -- confirm against
// the ILOG macro definitions.
ILOG_CLASS_INIT(TestMakeRandomTree, Impala.Core.Feature);

// Register the fixture's suite with CppUnit's default test-factory registry
// so a runner that collects tests from TestFactoryRegistry picks it up.
// NOTE(review): both macros appear to expand to definitions, so this header
// presumably must be included from only one translation unit -- verify.
CPPUNIT_TEST_SUITE_REGISTRATION( TestMakeRandomTree );
00253
00254 }
00255 }
00256 }
00257
00258 #endif
00259