00001 #ifndef Impala_Core_Training_FikSvm_h
00002 #define Impala_Core_Training_FikSvm_h
00003
00004 #include "Link/FikSvm/LinkFikSvm.h"
00005 #include "Core/Feature/WeightedFeatureList.h"
00006 #include "Core/Matrix/Mat.h"
00007 #include "Core/Matrix/SetRow.h"
00008 #include "Core/Array/PrintData.h"
00009 #include "Core/Array/Diff.h"
00010 #include "Core/Vector/Diff.h"
00011
00012 namespace Impala
00013 {
00014 namespace Core
00015 {
00016 namespace Training
00017 {
00018
00019
00020 class FikSvm
00021 {
00022 public:
00023
00024 typedef Link::FikSvm::fiksvm_approx_classifier FAC;
00025 typedef Core::Feature::WeightedFeatureList WeightedFeatureList;
00026 typedef Matrix::Mat Mat;
00027 typedef Vector::VectorTem<Real64> VectorReal64;
00028
00029 FikSvm(const WeightedFeatureList& features, int nrBins, int nrSV,
00030 Real64 rho, Real64 probA, Real64 probB, int probIndex)
00031 {
00032 mFeatures = features;
00033 mNrBins = nrBins;
00034 mNrSV = nrSV;
00035 mRho = rho;
00036 mProbA = probA;
00037 mProbB = probB;
00038 mProbIndex = probIndex;
00039 }
00040
00041 virtual
00042 ~FikSvm()
00043 {
00044 mFeatures.Clear();
00045 for (int i=0 ; i<mFikSvm.size() ; i++)
00046 delete mFikSvm[i];
00047 mFikSvm.clear();
00048 }
00049
00050 WeightedFeatureList*
00051 GetFeatureList()
00052 {
00053 return &mFeatures;
00054 }
00055
00056 int
00057 GetNrFeatures() const
00058 {
00059 return mFeatures.Size();
00060 }
00061
00062 String
00063 GetFeature(int idx) const
00064 {
00065 return mFeatures.GetFeature(idx);
00066 }
00067
00068 Real64
00069 GetWeight(int idx) const
00070 {
00071 return mFeatures.GetWeight(idx);
00072 }
00073
00074 int
00075 GetNrBins() const
00076 {
00077 return mNrBins;
00078 }
00079
00080 int
00081 GetNrSV() const
00082 {
00083 return mNrSV;
00084 }
00085
00086 Real64
00087 GetRho() const
00088 {
00089 return mRho;
00090 }
00091
00092 Real64
00093 GetProbA() const
00094 {
00095 return mProbA;
00096 }
00097
00098 Real64
00099 GetProbB() const
00100 {
00101 return mProbB;
00102 }
00103
00104 int
00105 GetProbIndex() const
00106 {
00107 return mProbIndex;
00108 }
00109
00110 void
00111 CreateSvm(int featDim, Real64* coefficients, Real64* supportVectors)
00112 {
00113 FAC* c =
00114 new FAC(featDim, mNrSV, coefficients, supportVectors, mRho, mNrBins);
00115 mFikSvm.push_back(c);
00116 }
00117
00118 void
00119 ImportSvm(Mat* h, Mat* a, Mat* b)
00120 {
00121 int featDim = h->CH();
00122 FAC* c = new FAC(featDim, mNrBins);
00123 c->rho = mRho;
00124 mFikSvm.push_back(c);
00125 int idx = mFikSvm.size() - 1;
00126 for (int i=0 ; i<featDim ; i++)
00127 {
00128 Mat* row = GetRowFromHMat(idx, i);
00129 Matrix::SetRow(row, 0, h, i);
00130 delete row;
00131 }
00132 Mat* dstA = GetAMat(idx);
00133 Matrix::SetRow(dstA, 0, a, 0);
00134 delete dstA;
00135 Mat* dstB = GetBMat(idx);
00136 Matrix::SetRow(dstB, 0, b, 0);
00137 delete dstB;
00138 }
00139
00140 int
00141 GetFeatDim(int idx) const
00142 {
00143 return mFikSvm[idx]->feat_dim;
00144 }
00145
00146 int
00147 GetNrSV(int idx) const
00148 {
00149 return mFikSvm[idx]->num_sv;
00150 }
00151
00152
00153 Mat*
00154 GetH(int idx)
00155 {
00156 Mat* h = Matrix::MatCreate<Mat>(GetFeatDim(idx), GetNrBins()+1);
00157 for (int i=0 ; i<GetFeatDim(idx) ; i++)
00158 {
00159 Matrix::SetRow(h, i, GetRowFromH(idx, i));
00160 }
00161 return h;
00162 }
00163
00164 VectorReal64
00165 GetA(int idx)
00166 {
00167 return VectorReal64(GetFeatDim(idx), mFikSvm[idx]->a, true);
00168
00169 }
00170
00171 Mat*
00172 GetAMat(int idx)
00173 {
00174 return Matrix::MatCreate<Mat>(1, GetFeatDim(idx), mFikSvm[idx]->a, true);
00175
00176 }
00177
00178 VectorReal64
00179 GetB(int idx)
00180 {
00181 return VectorReal64(GetFeatDim(idx), mFikSvm[idx]->b, true);
00182
00183 }
00184
00185 Mat*
00186 GetBMat(int idx)
00187 {
00188 return Matrix::MatCreate<Mat>(1, GetFeatDim(idx), mFikSvm[idx]->b, true);
00189
00190 }
00191
00192 Real64
00193 Apply(int idx, const VectorReal64& vec)
00194 {
00195 return mFikSvm[idx]->pwl_predict(vec.GetData());
00196 }
00197
00198
00199 Real64
00200 PredictProbability(Real64 decValue)
00201 {
00202 Real64 min_prob = 1e-7;
00203 int nr_class = 2;
00204 double prob_estimates[2];
00205 double** pairwise_prob=Malloc(double *, nr_class);
00206 for (int i=0 ; i<nr_class ; i++)
00207 pairwise_prob[i] = Malloc(double, nr_class);
00208 double pred = sigmoid_predict(decValue, mProbA, mProbB);
00209 pairwise_prob[0][1] = std::min(std::max(pred, min_prob), 1 - min_prob);
00210 pairwise_prob[1][0] = 1 - pairwise_prob[0][1];
00211 multiclass_probability(nr_class, pairwise_prob, prob_estimates);
00212 for (int i=0 ; i<nr_class ; i++)
00213 free(pairwise_prob[i]);
00214 free(pairwise_prob);
00215 return prob_estimates[mProbIndex];
00216 }
00217
00218 int
00219 Diff(FikSvm* arg)
00220 {
00221 int nDiff = GetFeatureList()->Diff(arg->GetFeatureList());
00222 if (nDiff > 0)
00223 {
00224 ILOG_ERROR("Features differ");
00225 }
00226 if (GetNrBins() != arg->GetNrBins())
00227 {
00228 ILOG_ERROR("NrBins differ");
00229 nDiff++;
00230 }
00231 if (GetNrSV() != arg->GetNrSV())
00232 {
00233 ILOG_ERROR("NrSV differ");
00234 nDiff++;
00235 }
00236 if (fabs(GetRho() - arg->GetRho()) > 0.00001)
00237 {
00238 ILOG_ERROR("Rho differs");
00239 nDiff++;
00240 }
00241 if (fabs(GetProbA() - arg->GetProbA()) > 0.00001)
00242 {
00243 ILOG_ERROR("ProbA differs");
00244 nDiff++;
00245 }
00246 if (fabs(GetProbB() - arg->GetProbB()) > 0.00001)
00247 {
00248 ILOG_ERROR("ProbB differs");
00249 nDiff++;
00250 }
00251 if (GetProbIndex() != arg->GetProbIndex())
00252 {
00253 ILOG_ERROR("ProbIndex differs");
00254 nDiff++;
00255 }
00256 if (nDiff > 0)
00257 return nDiff;
00258
00259 for (int i=0 ; i<GetNrFeatures() ; i++)
00260 {
00261 Mat* m1 = GetH(i);
00262 Mat* m2 = arg->GetH(i);
00263 int d = Array::Diff(m1, m2, 0.00001);
00264 if (d > 0)
00265 ILOG_ERROR("H[" << i << "] differs");
00266 nDiff += d;
00267 delete m1;
00268 delete m2;
00269 VectorReal64 v1 = GetA(i);
00270 VectorReal64 v2 = arg->GetA(i);
00271 d = Vector::Diff(v1, v2, 0.00001);
00272 if (d > 0)
00273 ILOG_ERROR("A[" << i << "] differs");
00274 nDiff += d;
00275 v1 = GetB(i);
00276 v2 = arg->GetB(i);
00277 d = Vector::Diff(v1, v2, 0.00001);
00278 if (d > 0)
00279 ILOG_ERROR("B[" << i << "] differs");
00280 nDiff += d;
00281 }
00282 if (nDiff > 0)
00283 ILOG_ERROR("Found " << nDiff << " differences in FAC's");
00284 return nDiff;
00285 }
00286
00287 void
00288 Dump(int idx, int cornerSize)
00289 {
00290 std::cout << "Dumping hiksvm" << std::endl;
00291 std::cout << "h = ";
00292 Mat* h = GetH(idx);
00293 Array::PrintDataCorners(h, cornerSize, cornerSize);
00294 delete h;
00295 std::cout << "a = " << GetA(idx).PrintE(cornerSize) << std::endl;
00296 std::cout << "b = " << GetB(idx).PrintE(cornerSize) << std::endl;
00297 }
00298
00299 private:
00300
00301 VectorReal64
00302 GetRowFromH(int idx, int row)
00303 {
00304 return VectorReal64(mNrBins+1, mFikSvm[idx]->h[row], true);
00305 }
00306
00307 Mat*
00308 GetRowFromHMat(int idx, int row)
00309 {
00310 return Matrix::MatCreate<Mat>(1, mNrBins+1, mFikSvm[idx]->h[row], true);
00311 }
00312
00313 WeightedFeatureList mFeatures;
00314 std::vector<FAC*> mFikSvm;
00315 int mNrBins;
00316 int mNrSV;
00317 Real64 mRho;
00318 Real64 mProbA;
00319 Real64 mProbB;
00320 int mProbIndex;
00321
00322 ILOG_VAR_DEC;
00323 };
00324
00325 ILOG_VAR_INIT(FikSvm, Impala.Core.Training);
00326
00327 }
00328 }
00329 }
00330
00331 #endif