Home || Architecture || Video Search || Visual Search || Scripts || Applications || Important Messages || OGL || Src

svm_model* svm_train ( const svm_problem prob,
const svm_parameter param 
)

Definition at line 1850 of file svm.cpp.

References decision_function::alpha, EPSILON_SVR, svm_model::free_sv, svm_model::l, svm_problem::l, svm_model::label, Malloc, svm_model::nr_class, svm_model::nSV, NU_SVR, ONE_CLASS, svm_model::param, svm_model::probA, svm_parameter::probability, svm_model::probB, decision_function::rho, svm_model::rho, svm_model::SV, svm_model::sv_coef, svm_svr_probability(), svm_train_one(), svm_parameter::svm_type, and svm_problem::x.

Referenced by svm_binary_svc_probability(), svm_cross_validation(), and Impala::Core::Training::Svm::Train().

01851 {
01852         svm_model *model = Malloc(svm_model,1);
01853         model->param = *param;
01854         model->free_sv = 0;     // XXX
01855 
01856         if(param->svm_type == ONE_CLASS ||
01857            param->svm_type == EPSILON_SVR ||
01858            param->svm_type == NU_SVR)
01859         {
01860                 // regression or one-class-svm
01861                 model->nr_class = 2;
01862                 model->label = NULL;
01863                 model->nSV = NULL;
01864                 model->probA = NULL; model->probB = NULL;
01865                 model->sv_coef = Malloc(double *,1);
01866 
01867                 if(param->probability && 
01868                    (param->svm_type == EPSILON_SVR ||
01869                     param->svm_type == NU_SVR))
01870                 {
01871                         model->probA = Malloc(double,1);
01872                         model->probA[0] = svm_svr_probability(prob,param);
01873                 }
01874 
01875                 decision_function f = svm_train_one(prob,param,0,0);
01876                 model->rho = Malloc(double,1);
01877                 model->rho[0] = f.rho;
01878 
01879                 int nSV = 0;
01880                 int i;
01881                 for(i=0;i<prob->l;i++)
01882                         if(fabs(f.alpha[i]) > 0) ++nSV;
01883                 model->l = nSV;
01884                 model->SV = Malloc(svm_node *,nSV);
01885                 model->sv_coef[0] = Malloc(double,nSV);
01886                 int j = 0;
01887                 for(i=0;i<prob->l;i++)
01888                         if(fabs(f.alpha[i]) > 0)
01889                         {
01890                                 model->SV[j] = prob->x[i];
01891                                 model->sv_coef[0][j] = f.alpha[i];
01892                                 ++j;
01893                         }               
01894 
01895                 free(f.alpha);
01896         }
01897         else
01898         {
01899                 // classification
01900                 // find out the number of classes
01901                 int l = prob->l;
01902                 int max_nr_class = 16;
01903                 int nr_class = 0;
01904                 int *label = Malloc(int,max_nr_class);
01905                 int *count = Malloc(int,max_nr_class);
01906                 int *index = Malloc(int,l);
01907 
01908                 int i;
01909                 for(i=0;i<l;i++)
01910                 {
01911                         int this_label = (int)prob->y[i];
01912                         int j;
01913                         for(j=0;j<nr_class;j++)
01914                                 if(this_label == label[j])
01915                                 {
01916                                         ++count[j];
01917                                         break;
01918                                 }
01919                         index[i] = j;
01920                         if(j == nr_class)
01921                         {
01922                                 if(nr_class == max_nr_class)
01923                                 {
01924                                         max_nr_class *= 2;
01925                                         label = (int *)realloc(label,max_nr_class*sizeof(int));
01926                                         count = (int *)realloc(count,max_nr_class*sizeof(int));
01927                                 }
01928                                 label[nr_class] = this_label;
01929                                 count[nr_class] = 1;
01930                                 ++nr_class;
01931                         }
01932                 }
01933 
01934                 // group training data of the same class
01935 
01936                 int *start = Malloc(int,nr_class);
01937                 start[0] = 0;
01938                 for(i=1;i<nr_class;i++)
01939                         start[i] = start[i-1]+count[i-1];
01940 
01941                 svm_node **x = Malloc(svm_node *,l);
01942                 
01943                 for(i=0;i<l;i++)
01944                 {
01945                         x[start[index[i]]] = prob->x[i];
01946                         ++start[index[i]];
01947                 }
01948                 
01949                 start[0] = 0;
01950                 for(i=1;i<nr_class;i++)
01951                         start[i] = start[i-1]+count[i-1];
01952 
01953                 // calculate weighted C
01954 
01955                 double *weighted_C = Malloc(double, nr_class);
01956                 for(i=0;i<nr_class;i++)
01957                         weighted_C[i] = param->C;
01958                 for(i=0;i<param->nr_weight;i++)
01959                 {       
01960                         int j;
01961                         for(j=0;j<nr_class;j++)
01962                                 if(param->weight_label[i] == label[j])
01963                                         break;
01964                         if(j == nr_class)
01965                                 fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
01966                         else
01967                                 weighted_C[j] *= param->weight[i];
01968                 }
01969 
01970                 // train k*(k-1)/2 models
01971                 
01972                 bool *nonzero = Malloc(bool,l);
01973                 for(i=0;i<l;i++)
01974                         nonzero[i] = false;
01975                 decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);
01976 
01977                 double *probA=NULL,*probB=NULL;
01978                 if (param->probability)
01979                 {
01980                         probA=Malloc(double,nr_class*(nr_class-1)/2);
01981                         probB=Malloc(double,nr_class*(nr_class-1)/2);
01982                 }
01983 
01984                 int p = 0;
01985                 for(i=0;i<nr_class;i++)
01986                         for(int j=i+1;j<nr_class;j++)
01987                         {
01988                                 svm_problem sub_prob;
01989                                 int si = start[i], sj = start[j];
01990                                 int ci = count[i], cj = count[j];
01991                                 sub_prob.l = ci+cj;
01992                                 sub_prob.x = Malloc(svm_node *,sub_prob.l);
01993                                 sub_prob.y = Malloc(double,sub_prob.l);
01994                                 int k;
01995                                 for(k=0;k<ci;k++)
01996                                 {
01997                                         sub_prob.x[k] = x[si+k];
01998                                         sub_prob.y[k] = +1;
01999                                 }
02000                                 for(k=0;k<cj;k++)
02001                                 {
02002                                         sub_prob.x[ci+k] = x[sj+k];
02003                                         sub_prob.y[ci+k] = -1;
02004                                 }
02005 
02006                                 if(param->probability)
02007                                         svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);
02008 
02009                                 f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
02010                                 for(k=0;k<ci;k++)
02011                                         if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
02012                                                 nonzero[si+k] = true;
02013                                 for(k=0;k<cj;k++)
02014                                         if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
02015                                                 nonzero[sj+k] = true;
02016                                 free(sub_prob.x);
02017                                 free(sub_prob.y);
02018                                 ++p;
02019                         }
02020 
02021                 // build output
02022 
02023                 model->nr_class = nr_class;
02024                 
02025                 model->label = Malloc(int,nr_class);
02026                 for(i=0;i<nr_class;i++)
02027                         model->label[i] = label[i];
02028                 
02029                 model->rho = Malloc(double,nr_class*(nr_class-1)/2);
02030                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
02031                         model->rho[i] = f[i].rho;
02032 
02033                 if(param->probability)
02034                 {
02035                         model->probA = Malloc(double,nr_class*(nr_class-1)/2);
02036                         model->probB = Malloc(double,nr_class*(nr_class-1)/2);
02037                         for(i=0;i<nr_class*(nr_class-1)/2;i++)
02038                         {
02039                                 model->probA[i] = probA[i];
02040                                 model->probB[i] = probB[i];
02041                         }
02042                 }
02043                 else
02044                 {
02045                         model->probA=NULL;
02046                         model->probB=NULL;
02047                 }
02048 
02049                 int total_sv = 0;
02050                 int *nz_count = Malloc(int,nr_class);
02051                 model->nSV = Malloc(int,nr_class);
02052                 for(i=0;i<nr_class;i++)
02053                 {
02054                         int nSV = 0;
02055                         for(int j=0;j<count[i];j++)
02056                                 if(nonzero[start[i]+j])
02057                                 {       
02058                                         ++nSV;
02059                                         ++total_sv;
02060                                 }
02061                         model->nSV[i] = nSV;
02062                         nz_count[i] = nSV;
02063                 }
02064                 
02065                 info("Total nSV = %d\n",total_sv);
02066 
02067                 model->l = total_sv;
02068                 model->SV = Malloc(svm_node *,total_sv);
02069                 p = 0;
02070                 for(i=0;i<l;i++)
02071                         if(nonzero[i]) model->SV[p++] = x[i];
02072 
02073                 int *nz_start = Malloc(int,nr_class);
02074                 nz_start[0] = 0;
02075                 for(i=1;i<nr_class;i++)
02076                         nz_start[i] = nz_start[i-1]+nz_count[i-1];
02077 
02078                 model->sv_coef = Malloc(double *,nr_class-1);
02079                 for(i=0;i<nr_class-1;i++)
02080                         model->sv_coef[i] = Malloc(double,total_sv);
02081 
02082                 p = 0;
02083                 for(i=0;i<nr_class;i++)
02084                         for(int j=i+1;j<nr_class;j++)
02085                         {
02086                                 // classifier (i,j): coefficients with
02087                                 // i are in sv_coef[j-1][nz_start[i]...],
02088                                 // j are in sv_coef[i][nz_start[j]...]
02089 
02090                                 int si = start[i];
02091                                 int sj = start[j];
02092                                 int ci = count[i];
02093                                 int cj = count[j];
02094                                 
02095                                 int q = nz_start[i];
02096                                 int k;
02097                                 for(k=0;k<ci;k++)
02098                                         if(nonzero[si+k])
02099                                                 model->sv_coef[j-1][q++] = f[p].alpha[k];
02100                                 q = nz_start[j];
02101                                 for(k=0;k<cj;k++)
02102                                         if(nonzero[sj+k])
02103                                                 model->sv_coef[i][q++] = f[p].alpha[ci+k];
02104                                 ++p;
02105                         }
02106                 
02107                 free(label);
02108                 free(probA);
02109                 free(probB);
02110                 free(count);
02111                 free(index);
02112                 free(start);
02113                 free(x);
02114                 free(weighted_C);
02115                 free(nonzero);
02116                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
02117                         free(f[i].alpha);
02118                 free(f);
02119                 free(nz_count);
02120                 free(nz_start);
02121         }
02122         return model;
02123 }

Here is the call graph for this function:


Generated on Fri Mar 19 10:17:12 2010 for ImpalaSrc by  doxygen 1.5.1