00001 #ifndef Impala_Core_Training_ClassifierEvaluator_h
00002 #define Impala_Core_Training_ClassifierEvaluator_h
00003
00004 #include "Core/Table/Select.h"
00005 #include "Core/Training/ParameterEvaluator.h"
00006 #include "Core/Training/Svm.h"
00007 #include "Core/Training/Evaluation.h"
00008 #include "Core/Training/TrainDataSrc.h"
00009
00010 namespace Impala
00011 {
00012 namespace Core
00013 {
00014 namespace Training
00015 {
00016
00029 class ClassifierEvaluator : public ParameterEvaluator
00030 {
00031 public:
00032 typedef Table::AnnotationTable AnnotationTable;
00033
00034 ClassifierEvaluator(Classifier* classifier, TrainDataSrc* src,
00035 Evaluation* evaluator)
00036 {
00037 mClassifier = classifier;
00038 mDataSrc = src;
00039 mEvaluator = evaluator;
00040 }
00041
00042 virtual
00043 ~ClassifierEvaluator()
00044 {
00045 delete mClassifier;
00046 delete mDataSrc;
00047 delete mEvaluator;
00048 }
00049
00050 void
00051 SetRepetition(int repetition, int total)
00052 {
00053 ILOG_DEBUG_NODE("set repetition: "<< repetition);
00054 mRepetition = repetition;
00055 }
00056
00057 void
00058 SetFold(int fold, int total)
00059 {
00060 ILOG_DEBUG_NODE("set fold: "<< fold << ", "<< total);
00061 mFold = fold;
00062 mFoldCount = total;
00063 }
00064
00065 virtual double
00066 Evaluate(Util::PropertySet* parameters)
00067 {
00068 bool episode = parameters->GetBool("episode-constrained");
00069 mDataSrc->FilterTrainFold(mFold, mFoldCount, mRepetition, episode);
00070 ILOG_DEBUG_NODE("calling Classifier::Train, data size = "
00071 << mDataSrc->Size());
00072 mClassifier->Train(parameters, mDataSrc);
00073 int restrictTestFoldSet = parameters->GetInt("restrictTestFoldSet");
00074 mDataSrc->FilterTestFold(mFold, mFoldCount, mRepetition, episode,
00075 restrictTestFoldSet);
00076 ILOG_DEBUG_NODE("calling Classifier::Predict, data size = "
00077 << mDataSrc->Size());
00078 Table::ScoreTable* ranking = mClassifier->Predict(mDataSrc);
00079 double score = mEvaluator->Compute(ranking);
00080 delete ranking;
00081 return score;
00082 }
00083
00084 private:
00085 int mRepetition;
00086 int mFold;
00087 int mFoldCount;
00088 Evaluation* mEvaluator;
00089 Classifier* mClassifier;
00090 TrainDataSrc* mDataSrc;
00091 ILOG_VAR_DEC;
00092 };
00093
00094 ILOG_VAR_INIT(ClassifierEvaluator, Impala.Core.Training);
00095
00096 }
00097 }
00098 }
00099
00100 #endif