00001 #ifndef Impala_Core_Training_ClassifierEvaluator_h
00002 #define Impala_Core_Training_ClassifierEvaluator_h
00003
00004 #include "Core/Table/Select.h"
00005 #include "Core/Training/ParameterEvaluator.h"
00006 #include "Core/Training/Svm.h"
00007 #include "Core/Training/Evaluation.h"
00008 #include "Core/Training/TrainDataSrc.h"
00009
00010 namespace Impala
00011 {
00012 namespace Core
00013 {
00014 namespace Training
00015 {
00016
00029 class ClassifierEvaluator : public ParameterEvaluator
00030 {
00031 public:
00032 typedef Table::AnnotationTable AnnotationTable;
00033
00034 ClassifierEvaluator(Classifier* classifier, TrainDataSrc* src,
00035 Evaluation* evaluator)
00036 {
00037 mClassifier = classifier;
00038 mDataSrc = src;
00039 mEvaluator = evaluator;
00040 }
00041
00042 virtual
00043 ~ClassifierEvaluator()
00044 {
00045 delete mClassifier;
00046 delete mDataSrc;
00047 delete mEvaluator;
00048 }
00049
00050 void
00051 SetRepetition(int repetition, int total)
00052 {
00053 ILOG_DEBUG_NODE("set repetition: "<< repetition);
00054 mRepetition = repetition;
00055 }
00056
00057 void
00058 SetFold(int fold, int total)
00059 {
00060 ILOG_DEBUG_NODE("set fold: "<< fold << ", "<< total);
00061 mFold = fold;
00062 mFoldCount = total;
00063 }
00064
00065 virtual double
00066 Evaluate(Util::PropertySet* parameters)
00067 {
00068 bool episode = parameters->GetBool("episode-constrained");
00069 ILOG_DEBUG_NODE("calling TrainDataSrc::FilterTrainFold");
00070 mDataSrc->FilterTrainFold(mFold, mFoldCount, mRepetition, episode);
00071 ILOG_DEBUG_NODE("calling Classifier::Train, data size = "
00072 << mDataSrc->Size());
00073 mClassifier->Train(parameters, mDataSrc);
00074 ILOG_DEBUG_NODE("calling TrainDataSrc::FilterTestFold");
00075 mDataSrc->FilterTestFold(mFold, mFoldCount, mRepetition, episode);
00076 ILOG_DEBUG_NODE("calling Classifier::Predict, data size = "
00077 << mDataSrc->Size());
00078 Table::ScoreTable* ranking = mClassifier->Predict(mDataSrc);
00079 ILOG_DEBUG_NODE("calling Evaluation::Compute");
00080 double score = mEvaluator->Compute(ranking);
00081 delete ranking;
00082 return score;
00083 }
00084
00085 private:
00086 int mRepetition;
00087 int mFold;
00088 int mFoldCount;
00089 Evaluation* mEvaluator;
00090 Classifier* mClassifier;
00091 TrainDataSrc* mDataSrc;
00092 ILOG_VAR_DEC;
00093 };
00094
// Logger initialization for ClassifierEvaluator under the
// "Impala.Core.Training" category; presumably the definition that pairs with
// the ILOG_VAR_DEC declaration inside the class (TODO confirm against the ILOG
// macro definitions).
// NOTE(review): this appears at namespace scope in a header. If the macro
// expands to a non-inline definition, including this header from more than
// one translation unit would cause duplicate-symbol link errors. Verify.
ILOG_VAR_INIT(ClassifierEvaluator, Impala.Core.Training);
00096
00097 }
00098 }
00099 }
00100
00101 #endif