00001 #ifndef Impala_Core_Training_ParameterSearcher_h
00002 #define Impala_Core_Training_ParameterSearcher_h
00003
00004 #include "Util/PropertySet.h"
00005 #include "Util/RangeIterator.h"
00006
00007 #include "Link/Mpi/MpiFuncs.h"
00008 #include "Util/Mpi/Reduce.h"
00009 #include "Util/Mpi/JobManager.h"
00010 #include "Core/Training/ParameterEvaluator.h"
00011
00012 #include <algorithm>
00013 #include <sstream>
00014
00015 namespace Impala
00016 {
00017 namespace Core
00018 {
00019 namespace Training
00020 {
00021
00026 class ParameterSearcher
00027 {
00028 public:
00029 ParameterSearcher(Util::PropertySet* properties, ParameterEvaluator* evaluator)
00030 {
00031 #ifdef MPI_USED
00032 mParallelMode = true;
00033 #else
00034 mParallelMode = false;
00035 #endif
00036 mEvaluator = evaluator;
00037 mProperties = *properties;
00038 mBestScore = 0;
00039
00040 int i;
00041 for(i=0 ; i<mProperties.Size() ; ++i)
00042 {
00043 std::string val = mProperties.GetValue(i);
00044 ILOG_DEBUG("prop found: "<< val);
00045 if(Util::IsRangeDefinition(val))
00046 {
00047 mIterator.AddRange(mProperties.GetName(i), val);
00048 }
00049 }
00050 int r = mProperties.GetInt("repetitions");
00051 int f = mProperties.GetInt("folds");
00052 mScores.assign(mIterator.GetTotalIterations() * f * r, 0);
00053 }
00054
00055 virtual
00056 ~ParameterSearcher()
00057 {
00058 delete mEvaluator;
00059 }
00060
00061 void
00062 OverrideParallelMode(bool mode)
00063 {
00064 mParallelMode = mode;
00065 }
00066
00067 Util::PropertySet
00068 Search()
00069 {
00070 int r = mProperties.GetInt("repetitions");
00071 int f = mProperties.GetInt("folds");
00072 mScores.assign(mIterator.GetTotalIterations() * f * r, 0);
00073 mBestScore = 0;
00074
00075 if(mParallelMode)
00076 SearchParallel();
00077 else
00078 SearchSequential();
00079
00080
00081 std::vector<double> scores;
00082 int iters = mIterator.GetTotalIterations();
00083 scores.assign(iters, 0);
00084 int i;
00085 for(i=0 ; i<mScores.size() ; ++i)
00086 {
00087 scores[i%iters] += mScores[i];
00088 }
00089 double d = f*r;
00090 for(i=0 ; i<scores.size() ; ++i)
00091 {
00092 scores[i] /= d;
00093 }
00094
00095
00096 std::vector<double>::iterator best =
00097 std::max_element(scores.begin(), scores.end());
00098 int bestIteration = best - scores.begin();
00099 mIterator.SetIteration(bestIteration);
00100 mIterator.GetParameters(&mProperties);
00101 mBestScore = *best;
00102 ILOG_INFO_HEADNODE("best score = " << mBestScore <<
00103 " @ params: " << mProperties.GetDescription());
00104 return mProperties;
00105 }
00106
00107 double
00108 GetBestScore()
00109 {
00110 return mBestScore;
00111 }
00112
00113 std::vector<Util::PropertySet*>
00114 GetAllScores()
00115 {
00116 std::vector<Util::PropertySet*> res;
00117 int index = 0;
00118 for (int r=0 ; r<mProperties.GetInt("repetitions") ; r++)
00119 {
00120 for (int f=0 ; f<mProperties.GetInt("folds") ; f++)
00121 {
00122 for (int i=0 ; i<mIterator.GetTotalIterations() ; i++)
00123 {
00124 Util::PropertySet* props = new Util::PropertySet;
00125 mIterator.SetIteration(i);
00126 mIterator.GetParameters(props);
00127 props->Add("score", mScores[index]);
00128 index++;
00129 props->Add("repetition", r);
00130 props->Add("fold", f);
00131 res.push_back(props);
00132 }
00133 }
00134 }
00135 return res;
00136 }
00137
00138 private:
00139
00140 double
00141 CallEvaluator(int iteration, int index)
00142 {
00143 mIterator.SetIteration(iteration);
00144 mIterator.GetParameters(&mProperties);
00145 ILOG_DEBUG("calling evaluator::Evaluate");
00146 double score = mEvaluator->Evaluate(&mProperties);
00147 ILOG_DEBUG("evaluator::Evaluate returned");
00148 mScores[index] += score;
00149 return score;
00150 }
00151
00152 void
00153 SearchSequential()
00154 {
00155 if(mEvaluator == 0)
00156 {
00157 ILOG_ERROR("no evaluator set");
00158 return;
00159 }
00160
00161 int index = 0;
00162 int repetition, repetitionCount;
00163 repetitionCount = mProperties.GetInt("repetitions");
00164 ILOG_INFO("#repetitions " << repetitionCount)
00165 for(repetition=0 ; repetition<repetitionCount ; ++repetition)
00166 {
00167 ILOG_INFO("repetition #" << repetition)
00168 mEvaluator->SetRepetition(repetition, repetitionCount);
00169 int fold, foldCount;
00170 foldCount = mProperties.GetInt("folds");
00171 for(fold=0 ; fold<foldCount ; ++fold)
00172 {
00173 ILOG_INFO("fold #" << fold);
00174 mEvaluator->SetFold(fold, foldCount);
00175 int i;
00176 for(i=0 ; i<mIterator.GetTotalIterations() ; ++i)
00177 {
00178 double score = CallEvaluator(i, index);
00179 ++index;
00180 ILOG_INFO("it " << i << ", " << mProperties.GetDescription() <<
00181 " : " << score);
00182 }
00183 }
00184 }
00185 }
00186
00187
00188 void
00189 SearchParallel()
00190 {
00191 if(Link::Mpi::MyId() == 0)
00192 SearchServer();
00193 else
00194 SearchClient();
00195 Util::Mpi::Reduce(mScores);
00196 }
00197
00198 void
00199 SearchServer()
00200 {
00201 ILOG_DEBUG_NODE("I am the server");
00202
00203 Util::Mpi::JobManager jobManager;
00204 int f;
00205 for(f=0 ; f<mProperties.GetInt("folds") ; ++f)
00206 {
00207 int r;
00208 for(r=0 ; r<mProperties.GetInt("repetitions") ; ++r)
00209 {
00210 std::string id("r0f0");
00211 id[1] = '0'+r;
00212 id[3] = '0'+f;
00213 Util::PropertySet ps;
00214 ps.Add("repetition", r);
00215 ps.Add("fold", f);
00216
00217
00218 jobManager.CreateGroup(ps.GetDescription(), id,
00219 mIterator.GetTotalIterations());
00220 }
00221 }
00222
00223
00224 bool done=false;
00225 int runningJobs=0;
00226 while(true)
00227 {
00228
00229 int source;
00230 std::string message = Link::Mpi::ReceiveString(source);
00231 ILOG_DEBUG_NODE("server got mssg: " << message);
00232 Util::PropertySet job(message);
00233 if(job.GetString("JobManager::job-id") != "-1")
00234 --runningJobs;
00235
00236 jobManager.GetJob(&job);
00237
00238 if(job.GetString("JobManager::job-id") != "-1")
00239 ++runningJobs;
00240 std::ostringstream oss;
00241 job.Print(oss);
00242 ILOG_DEBUG_NODE("SERVER: sending job assignment " <<
00243 job.GetString("JobManager::job-id") << "," <<
00244 job.GetString("JobManager::group-id") << " to " << source);
00245 Link::Mpi::SendString(oss.str(), source);
00246 if(runningJobs == 0)
00247 break;
00248 }
00249 }
00250
00251 void
00252 SearchClient()
00253 {
00254 ILOG_DEBUG_NODE("I am a client");
00255 int id = Link::Mpi::MyId();
00256 int lastRepetition = -1;
00257 int lastFold = -1;
00258 Util::PropertySet job;
00259 job.Add("JobManager::job-id", "-1");
00260 while(true)
00261 {
00262 ILOG_DEBUG_NODE("client sends request for work");
00263
00264 Link::Mpi::SendString(job.GetDescription(),0);
00265
00266 int source;
00267 std::string message = Link::Mpi::ReceiveString(source);
00268 job.Parse(message);
00269 int repetition = job.GetInt("repetition", -1);
00270 int fold = job.GetInt("fold", -1);
00271 int iteration = job.GetInt("JobManager::job-id", -1);
00272
00273 ILOG_DEBUG_NODE("got assignment " << iteration << " from group " <<
00274 job.GetString("JobManager::group-id"));
00275 if(iteration == -1)
00276 break;
00277 if(repetition != lastRepetition)
00278 {
00279 lastFold = -1;
00280 mEvaluator->SetRepetition(repetition,
00281 mProperties.GetInt("repetitions"));
00282 lastRepetition = repetition;
00283 }
00284 if(fold != lastFold)
00285 {
00286 mEvaluator->SetFold(fold, mProperties.GetInt("folds"));
00287 lastFold = fold;
00288 }
00289 int iters = mIterator.GetTotalIterations();
00290 int folds = mProperties.GetInt("folds");
00291 int index = (((repetition * folds) + fold) * iters) + iteration;
00292 double score = CallEvaluator(iteration, index);
00293 ILOG_INFO("r" << repetition << " f" << fold << " i" << iteration <<
00294 ", " << mProperties.GetDescription() << " -> " << score);
00295 }
00296 }
00297
00298
00299 ParameterSearcher* operator=(ParameterSearcher&);
00300 ParameterSearcher(ParameterSearcher&);
00301
00302 ParameterEvaluator *mEvaluator;
00303 Util::PropertySet mProperties;
00304 Util::RangeIterator mIterator;
00305 std::vector<double> mScores;
00306 double mBestScore;
00307 int mScoreIndex;
00308 bool mParallelMode;
00309
00310 ILOG_VAR_DEC;
00311 };
00312
// Out-of-class counterpart of the ILOG_VAR_DEC member declared inside
// ParameterSearcher.
// NOTE(review): presumably defines the class logger under the category
// "Impala.Core.Training" — confirm against the ILOG macro definitions.
00313 ILOG_VAR_INIT(ParameterSearcher, Impala.Core.Training);
00314
00315
00316
00317 }
00318 }
00319 }
00320
00321 #endif