Home || Visual Search || Applications || Architecture || Important Messages || OGL || Src

TestMakeRandomTree.h

Go to the documentation of this file.
00001 #ifndef Impala_Core_Feature_Test_TestMakeRandomTree_h
00002 #define Impala_Core_Feature_Test_TestMakeRandomTree_h
00003 
00004 #include <cppunit/extensions/HelperMacros.h>
00005 #include <algorithm>
00006 
00007 #include "Core/Feature/MakeRandomTree.h"
00008 
00009 namespace Impala
00010 {
00011 namespace Core
00012 {
00013 namespace Feature
00014 {
00015 
00016 class TestMakeRandomTree : public CPPUNIT_NS::TestFixture
00017 {
00018     CPPUNIT_TEST_SUITE(TestMakeRandomTree);
00019     CPPUNIT_TEST(testMakeHistogram);
00020     CPPUNIT_TEST(testGain);
00021     CPPUNIT_TEST(testSplitSet);
00022     CPPUNIT_TEST(testTryRandomSplit);
00023     CPPUNIT_TEST(testFindSplit);
00024     CPPUNIT_TEST(testTopLevel);
00025     CPPUNIT_TEST(testReferenceMatlab);
00026     CPPUNIT_TEST_SUITE_END();
00027     
00028 public:
00029 
00030     void
00031     setUp()
00032     {
00033         mNumberClasses = 4;
00034         mFakeData = new AnnotatedFeatureTable
00035             (Vector::ColumnVectorSet(true, 3, 0), Column::ColumnInt32(0));
00036         mFakeData->Add(Vector::VectorReal64(0.00,-2.00, 6.00), 0);
00037         mFakeData->Add(Vector::VectorReal64(0.01,-2.01, 6.01), 0);
00038         mFakeData->Add(Vector::VectorReal64(0.02,-2.02, 6.02), 0);
00039         mFakeData->Add(Vector::VectorReal64(1.03, 0.03, 1.03), 1);
00040         mFakeData->Add(Vector::VectorReal64(2.04, 1.04, 2.04), 1);
00041         mFakeData->Add(Vector::VectorReal64(3.05, 0.05, 3.05), 1);
00042         mFakeData->Add(Vector::VectorReal64(3.06, 1.06, 1.06), 2);
00043         mFakeData->Add(Vector::VectorReal64(2.07, 0.07, 2.07), 2);
00044         mFakeData->Add(Vector::VectorReal64(1.08, 1.08, 3.08), 2);
00045         mFakeData->Add(Vector::VectorReal64(2.09, 0.09, 6.09), 3);
00046         mFakeData->Add(Vector::VectorReal64(2.10, 1.10, 6.10), 3);
00047         mFakeData->Add(Vector::VectorReal64(2.11, 0.11, 6.11), 3);
00048 
00049         for(int i=0 ; i<12 ; ++i)
00050         {
00051             mFilterAll[i] = true;
00052             mFilterHalf[i] = i<6;
00053             mFilterHalfComp[i] = i>=6;
00054             mFilterQuarter[i] = i<3;
00055             mFilterQuarterComp[i] = i>=3;
00056             mFilterOdd[i] = i%2 == 0;
00057             mFilterEven[i] = i%2 == 1;
00058         }
00059     }
00060 
00061     void
00062     tearDown()
00063     {
00064         delete mFakeData;
00065     }
00066 
00067     void
00068     testMakeHistogram()
00069     {
00070         Histogram::Histogram1dTem<int>* histAll =
00071             MakeHistogram(mFakeData, mNumberClasses, mFilterAll);
00072         Histogram::Histogram1dTem<int>* histHalf =
00073             MakeHistogram(mFakeData, mNumberClasses, mFilterHalf);
00074         Histogram::Histogram1dTem<int>* histOdd =
00075             MakeHistogram(mFakeData, mNumberClasses, mFilterOdd);
00076         CPPUNIT_ASSERT_EQUAL(mNumberClasses, histAll->Size());
00077         CPPUNIT_ASSERT_EQUAL(mNumberClasses, histHalf->Size());
00078         CPPUNIT_ASSERT_EQUAL(mNumberClasses, histOdd->Size());
00079         // this loop iterates over the four elements in the histograms and does
00080         // the appropriate tests.
00081         for(int i=0 ; i<4 ; ++i)
00082         {
00083             CPPUNIT_ASSERT_EQUAL(3, histAll->Elem(i));
00084             CPPUNIT_ASSERT_EQUAL((i<2) ? 3 : 0, histHalf->Elem(i));
00085             CPPUNIT_ASSERT_EQUAL((i%2) ? 1 : 2, histOdd->Elem(i));
00086         }
00087         delete histAll;
00088         delete histHalf;
00089         delete histOdd;
00090     }
00091 
00092     void
00093     testGain()
00094     {
00095         double gainHalf = Gain(mFakeData, mNumberClasses,
00096                                mFilterHalf, mFilterHalfComp);
00097         double gainQuarter = Gain(mFakeData, mNumberClasses,
00098                                   mFilterQuarter, mFilterQuarterComp);
00099         double gainOdd = Gain(mFakeData, mNumberClasses,
00100                               mFilterOdd, mFilterEven);
00101         CPPUNIT_ASSERT(gainHalf > gainQuarter);
00102         CPPUNIT_ASSERT(gainQuarter > gainOdd);
00103     }
00104     
00105     void
00106     testSplitSet()
00107     {
00108         bool* left;
00109         bool* right;
00110         SplitSet(left, right, 1, -1., mFakeData, mFilterAll);
00111         CPPUNIT_ASSERT(Util::Equal(left, mFilterQuarter, 12));
00112         CPPUNIT_ASSERT(Util::Equal(right, mFilterQuarterComp, 12));
00113         delete left;
00114         delete right;
00115         SplitSet(left, right, 2, -3., mFakeData, mFilterOdd);
00116         CPPUNIT_ASSERT_EQUAL(0, Util::Count(left, 12));
00117         CPPUNIT_ASSERT(Util::Equal(right, mFilterOdd, 12));
00118     }
00119 
00120     void
00121     testTryRandomSplit()
00122     {
00123         int dim;
00124         double val;
00125         double gain;
00126         bool filter[12] = {false};
00127         filter[3] = true;
00128         Util::Random rng;
00129         TryRandomSplit(dim, val, gain, mFakeData, filter, mNumberClasses, rng);
00130         CPPUNIT_ASSERT_EQUAL(mFakeData->Get1(3)[dim], val);
00131         filter[3] = false;
00132         filter[7] = true;
00133         TryRandomSplit(dim, val, gain, mFakeData, filter, mNumberClasses, rng);
00134         CPPUNIT_ASSERT_EQUAL(mFakeData->Get1(7)[dim], val);
00135     }
00136 
00137     void
00138     testFindSplitWithSeed(int seed)
00139     {
00140         // get the gain of the first random split
00141         int dim;
00142         double val;
00143         double gain;
00144         Util::Random rng;
00145         rng.SetSeed(seed);
00146         TryRandomSplit(dim, val, gain, mFakeData, mFilterAll, mNumberClasses, rng);
00147         rng.SetSeed(seed);
00148         FindSplit(dim, val, mFakeData, mFilterAll, mNumberClasses, 20, rng);
00149         bool* left;
00150         bool* right;
00151         SplitSet(left, right, dim, val, mFakeData, mFilterAll);
00152         // Because we reset the random seed the gain of SplitSet must be greater
00153         // or equal to the previously computed gain.
00154         CPPUNIT_ASSERT(gain <= Gain(mFakeData, mNumberClasses, left, right));
00155     }
00156 
00157     void
00158     testFindSplit()
00159     {
00160         testFindSplitWithSeed(1);
00161         testFindSplitWithSeed(2);
00162         testFindSplitWithSeed(3);
00163     }
00164 
00165     void
00166     testTopLevel()
00167     {
00168         Util::Random rng;
00169         RandomTree* tree = MakeRandomTree(mFakeData, mNumberClasses, 2, 1000, rng);
00170         //we'll assume 50 tries leads to optimal results (only 12 data points)
00171         int dim;
00172         double val;
00173         tree->GetSplit(dim, val);
00174         ILOG_DEBUG("dim:"<< dim <<" val:"<< val);
00175         CPPUNIT_ASSERT_EQUAL(2, dim);
00176         CPPUNIT_ASSERT_EQUAL(6.00, val);
00177     }
00178 
00179     void
00180     testReferenceMatlab()
00181     {
00182         for(int i=0 ; i<3 ; ++i)
00183             ReferenceTest();
00184     }
00185     
00186     void
00187     ReferenceTest()
00188     {
00189         /* problem and outcome copied from reference implementation of Jasper
00190 
00191            data = ...
00192            [1 1;
00193             1 1;
00194             0 0;
00195             0 0;
00196             0 0;
00197             0 1;
00198             0 1;
00199             0 1;
00200             0 1;
00201             1 0;
00202             1 0;
00203             1 0;
00204             1 0;
00205             1 0];
00206 
00207            class = [1;1;2;2;2;3;3;3;3;4;4;4;4;4];
00208 
00209            % depth = 2, nTrial = 25
00210            [maps boundaries counts] = RandomEntropyTreeIdxTest(data, class, 2, 25);
00211   
00212            === Output ===
00213            gain = -0.9242
00214            gain = 0
00215            gain = 0
00216          */
00217 
00218         AnnotatedFeatureTable data(Vector::ColumnVectorSet(true, 2, 0),
00219                                    Column::ColumnInt32(0));
00220         data.Add(Vector::VectorReal64(1, 1), 1);
00221         data.Add(Vector::VectorReal64(1, 1), 1);
00222         data.Add(Vector::VectorReal64(0, 0), 2);
00223         data.Add(Vector::VectorReal64(0, 0), 2);
00224         data.Add(Vector::VectorReal64(0, 0), 2);
00225         data.Add(Vector::VectorReal64(0, 1), 3);
00226         data.Add(Vector::VectorReal64(0, 1), 3);
00227         data.Add(Vector::VectorReal64(0, 1), 3);
00228         data.Add(Vector::VectorReal64(0, 1), 3);
00229         data.Add(Vector::VectorReal64(1, 0), 4);
00230         data.Add(Vector::VectorReal64(1, 0), 4);
00231         data.Add(Vector::VectorReal64(1, 0), 4);
00232         data.Add(Vector::VectorReal64(1, 0), 4);
00233         data.Add(Vector::VectorReal64(1, 0), 4);
00234         RandomTree* tree = MakeRandomTree(&data, 5, 2, 25, Util::Random::GetGlobal());
00235     }
00236     
00237 private:
00238     int mNumberClasses;
00239     AnnotatedFeatureTable* mFakeData;
00240     bool mFilterAll[12];
00241     bool mFilterHalf[12];
00242     bool mFilterHalfComp[12];
00243     bool mFilterQuarter[12];
00244     bool mFilterQuarterComp[12];
00245     bool mFilterOdd[12];
00246     bool mFilterEven[12];
00247     ILOG_CLASS;
00248 };
00249 
00250 ILOG_CLASS_INIT(TestMakeRandomTree, Impala.Core.Feature);
00251 
00252 CPPUNIT_TEST_SUITE_REGISTRATION( TestMakeRandomTree );
00253     
00254 } // namespace Feature
00255 } // namespace Core
00256 } // namespace Impala
00257 
00258 #endif
00259     

Generated on Thu Jan 13 09:04:26 2011 for ImpalaSrc by  doxygen 1.5.1