
TestMakeRandomTree.h

#ifndef Impala_Core_Feature_Test_TestMakeRandomTree_h
#define Impala_Core_Feature_Test_TestMakeRandomTree_h

#include <cppunit/extensions/HelperMacros.h>
#include <algorithm>

#include "Core/Feature/MakeRandomTree.h"

namespace Impala
{
namespace Core
{
namespace Feature
{

class TestMakeRandomTree : public CPPUNIT_NS::TestFixture
{
    CPPUNIT_TEST_SUITE(TestMakeRandomTree);
    CPPUNIT_TEST(testMakeHistogram);
    CPPUNIT_TEST(testGain);
    CPPUNIT_TEST(testSplitSet);
    CPPUNIT_TEST(testTryRandomSplit);
    CPPUNIT_TEST(testFindSplit);
    CPPUNIT_TEST(testTopLevel);
    CPPUNIT_TEST(testReferenceMatlab);
    CPPUNIT_TEST_SUITE_END();

public:

    void
    setUp()
    {
        mNumberClasses = 4;
        mFakeData = new AnnotatedFeatureTable
            (Vector::ColumnVectorSet(true, 3, 0), Column::ColumnInt32(0));
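        // Twelve 3-dimensional samples, three per class (labels 0-3). Class 0
        // is the only class with a negative second component, and classes 0
        // and 3 are the only classes whose third component is around 6.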
        mFakeData->Add(Vector::VectorReal64(0.00,-2.00, 6.00), 0);
        mFakeData->Add(Vector::VectorReal64(0.01,-2.01, 6.01), 0);
        mFakeData->Add(Vector::VectorReal64(0.02,-2.02, 6.02), 0);
        mFakeData->Add(Vector::VectorReal64(1.03, 0.03, 1.03), 1);
        mFakeData->Add(Vector::VectorReal64(2.04, 1.04, 2.04), 1);
        mFakeData->Add(Vector::VectorReal64(3.05, 0.05, 3.05), 1);
        mFakeData->Add(Vector::VectorReal64(3.06, 1.06, 1.06), 2);
        mFakeData->Add(Vector::VectorReal64(2.07, 0.07, 2.07), 2);
        mFakeData->Add(Vector::VectorReal64(1.08, 1.08, 3.08), 2);
        mFakeData->Add(Vector::VectorReal64(2.09, 0.09, 6.09), 3);
        mFakeData->Add(Vector::VectorReal64(2.10, 1.10, 6.10), 3);
        mFakeData->Add(Vector::VectorReal64(2.11, 0.11, 6.11), 3);

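        // Filters selecting subsets of the 12 samples: all of them, the first
        // half (classes 0 and 1), the first quarter (class 0 only), their
        // complements, and the samples at alternating positions (mFilterOdd
        // selects the even 0-based indices, i.e. the odd 1-based positions).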
        for(int i=0 ; i<12 ; ++i)
        {
            mFilterAll[i] = true;
            mFilterHalf[i] = i<6;
            mFilterHalfComp[i] = i>=6;
            mFilterQuarter[i] = i<3;
            mFilterQuarterComp[i] = i>=3;
            mFilterOdd[i] = i%2 == 0;
            mFilterEven[i] = i%2 == 1;
        }
    }

    void
    tearDown()
    {
        delete mFakeData;
    }

    void
    testMakeHistogram()
    {
        Histogram::Histogram1dTem<int>* histAll =
            MakeHistogram(mFakeData, mNumberClasses, mFilterAll);
        Histogram::Histogram1dTem<int>* histHalf =
            MakeHistogram(mFakeData, mNumberClasses, mFilterHalf);
        Histogram::Histogram1dTem<int>* histOdd =
            MakeHistogram(mFakeData, mNumberClasses, mFilterOdd);
        CPPUNIT_ASSERT_EQUAL(mNumberClasses, histAll->Size());
        CPPUNIT_ASSERT_EQUAL(mNumberClasses, histHalf->Size());
        CPPUNIT_ASSERT_EQUAL(mNumberClasses, histOdd->Size());
        // Each histogram has one bin per class: with all samples selected every
        // bin holds 3, the first-half filter keeps only classes 0 and 1, and the
        // odd-position filter keeps 2 samples of classes 0 and 2 and 1 sample of
        // classes 1 and 3.
        for(int i=0 ; i<4 ; ++i)
        {
            CPPUNIT_ASSERT_EQUAL(3, histAll->Elem(i));
            CPPUNIT_ASSERT_EQUAL((i<2) ? 3 : 0, histHalf->Elem(i));
            CPPUNIT_ASSERT_EQUAL((i%2) ? 1 : 2, histOdd->Elem(i));
        }
        delete histAll;
        delete histHalf;
        delete histOdd;
    }

    void
    testGain()
    {
        double gainHalf = Gain(mFakeData, mNumberClasses,
                               mFilterHalf, mFilterHalfComp);
        double gainQuarter = Gain(mFakeData, mNumberClasses,
                                  mFilterQuarter, mFilterQuarterComp);
        double gainOdd = Gain(mFakeData, mNumberClasses,
                              mFilterOdd, mFilterEven);
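        // The half/half split separates classes {0,1} from {2,3} completely,
        // the quarter split only isolates class 0, and the odd/even split
        // leaves every class represented on both sides, so the gains are
        // expected to decrease in that order.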
        CPPUNIT_ASSERT(gainHalf > gainQuarter);
        CPPUNIT_ASSERT(gainQuarter > gainOdd);
    }

    void
    testSplitSet()
    {
        bool* left;
        bool* right;
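        // Splitting on dimension 1 at -1.0 should isolate the first three
        // samples (class 0, the only class with a negative second component),
        // i.e. reproduce the quarter filter. Splitting on dimension 2 at -3.0
        // puts nothing below the threshold, so the left set is empty and the
        // right set equals the input filter.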
        SplitSet(left, right, 1, -1., mFakeData, mFilterAll);
        CPPUNIT_ASSERT(Util::Equal(left, mFilterQuarter, 12));
        CPPUNIT_ASSERT(Util::Equal(right, mFilterQuarterComp, 12));
        delete left;
        delete right;
        SplitSet(left, right, 2, -3., mFakeData, mFilterOdd);
        CPPUNIT_ASSERT_EQUAL(0, Util::Count(left, 12));
        CPPUNIT_ASSERT(Util::Equal(right, mFilterOdd, 12));
        delete left;
        delete right;
    }

    void
    testTryRandomSplit()
    {
        int dim;
        double val;
        double gain;
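        // With only a single sample enabled in the filter, the test expects the
        // returned threshold to equal that sample's value in whichever dimension
        // the random split happens to pick.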
        bool filter[12] = {false};
        filter[3] = true;
        TryRandomSplit(dim, val, gain, mFakeData, filter, mNumberClasses);
        CPPUNIT_ASSERT_EQUAL(mFakeData->Get1(3)[dim], val);
        filter[3] = false;
        filter[7] = true;
        TryRandomSplit(dim, val, gain, mFakeData, filter, mNumberClasses);
        CPPUNIT_ASSERT_EQUAL(mFakeData->Get1(7)[dim], val);
    }

    void
    testFindSplitWithSeed(int seed)
    {
        Util::SetRandomSeed(seed);
        // get the gain of the first random split
        int dim;
        double val;
        double gain;
        TryRandomSplit(dim, val, gain, mFakeData, mFilterAll, mNumberClasses);
        Util::SetRandomSeed(seed);
        FindSplit(dim, val, mFakeData, mFilterAll, mNumberClasses, 20);
        bool* left;
        bool* right;
        SplitSet(left, right, dim, val, mFakeData, mFilterAll);
        // Because the random seed was reset, FindSplit evaluates the same first
        // candidate split, so the gain of the split it returns must be at least
        // the gain computed above.
        CPPUNIT_ASSERT(gain <= Gain(mFakeData, mNumberClasses, left, right));
        delete left;
        delete right;
    }

    void
    testFindSplit()
    {
        testFindSplitWithSeed(1);
        testFindSplitWithSeed(2);
        testFindSplitWithSeed(3);
    }

    void
    testTopLevel()
    {
        RandomTree* tree = MakeRandomTree(mFakeData, mNumberClasses, 2, 1000);
        // with only 12 data points we assume 1000 trials are enough to find the
        // optimal split
        int dim;
        double val;
        tree->GetSplit(dim, val);
        ILOG_DEBUG("dim:"<< dim <<" val:"<< val);
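        // The highest-gain split on this data is on the third dimension at
        // 6.00: classes 1 and 2 have values between 1 and 3.1 there, while
        // classes 0 and 3 have values of 6.00 or more.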
        CPPUNIT_ASSERT_EQUAL(2, dim);
        CPPUNIT_ASSERT_EQUAL(6.00, val);
        delete tree;
    }

    void
    testReferenceMatlab()
    {
        for(int i=0 ; i<3 ; ++i)
            ReferenceTest();
    }

    void
    ReferenceTest()
    {
        /* Problem and expected outcome copied from Jasper's reference Matlab
           implementation.

           data = ...
           [1 1;
            1 1;
            0 0;
            0 0;
            0 0;
            0 1;
            0 1;
            0 1;
            0 1;
            1 0;
            1 0;
            1 0;
            1 0;
            1 0];

           class = [1;1;2;2;2;3;3;3;3;4;4;4;4;4];

           % depth = 2, nTrial = 25
           [maps boundaries counts] = RandomEntropyTreeIdxTest(data, class, 2, 25);

           === Output ===
           gain = -0.9242
           gain = 0
           gain = 0
         */

        AnnotatedFeatureTable data(Vector::ColumnVectorSet(true, 2, 0),
                                   Column::ColumnInt32(0));
        data.Add(Vector::VectorReal64(1, 1), 1);
        data.Add(Vector::VectorReal64(1, 1), 1);
        data.Add(Vector::VectorReal64(0, 0), 2);
        data.Add(Vector::VectorReal64(0, 0), 2);
        data.Add(Vector::VectorReal64(0, 0), 2);
        data.Add(Vector::VectorReal64(0, 1), 3);
        data.Add(Vector::VectorReal64(0, 1), 3);
        data.Add(Vector::VectorReal64(0, 1), 3);
        data.Add(Vector::VectorReal64(0, 1), 3);
        data.Add(Vector::VectorReal64(1, 0), 4);
        data.Add(Vector::VectorReal64(1, 0), 4);
        data.Add(Vector::VectorReal64(1, 0), 4);
        data.Add(Vector::VectorReal64(1, 0), 4);
        data.Add(Vector::VectorReal64(1, 0), 4);
        // Building the tree on the reference problem serves as a smoke test;
        // the gains reported by the Matlab implementation above are not
        // checked programmatically here.
        RandomTree* tree = MakeRandomTree(&data, 5, 2, 25);
        delete tree;
    }

private:
    int mNumberClasses;
    AnnotatedFeatureTable* mFakeData;
    bool mFilterAll[12];
    bool mFilterHalf[12];
    bool mFilterHalfComp[12];
    bool mFilterQuarter[12];
    bool mFilterQuarterComp[12];
    bool mFilterOdd[12];
    bool mFilterEven[12];
    ILOG_CLASS;
};

ILOG_CLASS_INIT(TestMakeRandomTree, Impala.Core.Feature);

CPPUNIT_TEST_SUITE_REGISTRATION( TestMakeRandomTree );
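// The registration macro adds the fixture to CppUnit's global
// TestFactoryRegistry, so a test runner that builds its suite from the
// registry picks up these tests automatically.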

} // namespace Feature
} // namespace Core
} // namespace Impala

#endif
