以下为部分代码,仅供参考。文中很多变量类型为自定义的数据结构。
头文件(Svm_c.h):
#ifndef SVM_C_H
#define SVM_C_H
#include"Process.h"
// Public interface of the SVM module: a dlib-based pipeline
// (Label / Train / Test / Classify / ParamsSelection) and an OpenCV CvSVM-based
// pipeline (Train_opencv_* / Classify_opencv / SVM_Params_*).
// NOTE(review): std::vector and sample_type are assumed to be provided via Process.h.

// Builds the training set (AllSamples / All_labels) from positive (+1) and negative (-1) samples.
extern void Label(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample);
// Cross-validates gamma/C, trains the dlib model and writes it to "saved_function.dat".
extern void Train();
// Sets fixed dlib trainer parameters (gamma = 0.00625, C = 5).
extern void Test();
// Loads "saved_function.dat" and prints the decision value of each sample in CSample.
extern void Classify();
// Estimates an RBF gamma and prints a feature ranking for the current training set.
extern void ParamsSelection(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample);
// Trains an OpenCV CvSVM with fixed parameters and saves it to "svm_ori.xml".
extern void Train_opencv_ori();
// Trains an OpenCV CvSVM with CvSVM::train_auto and saves the resulting model.
extern void Train_opencv_opt();
// Classifies PredictingDatas with a saved OpenCV model (1: svm_ori.xml, 2: train_auto model).
extern void Classify_opencv(int mode);
// Fill the global SVM_params used by Train_opencv_ori() / Train_opencv_opt() respectively.
extern void SVM_Params_ori();
extern void SVM_Params_opt();
#endif
cpp文件:
#include "StdAfx.h"
#include"Svm_c.h"
// ---- dlib-based classifier state ----
dlib::svm_c_trainer<kernel_type> trainer;          // configured by Train() / Test()
std::vector<sample_type> AllSamples;               // training samples, filled and normalized by Label()
std::vector<double> All_labels;                    // +1 / -1 labels parallel to AllSamples
funct_type learned_function;                       // trained decision function together with its normalizer
dlib::vector_normalizer<sample_type> normalizer;   // fitted on AllSamples in Label()
dlib::rand rnd;                                    // not used by the code shown here
//cv::Mat Classes;
// ---- OpenCV CvSVM-based classifier state ----
CvSVMParams SVM_params;                            // filled by SVM_Params_ori() / SVM_Params_opt()
CvSVM svm;
int respones = 0;                                  // last prediction in Classify_opencv()
int PrePoNum = 0;                                  // number of rows predicted as positive
int PreNgNum = 0;                                  // number of rows predicted as negative
std::vector<int> PSIndex;                          // indices of samples classified as positive
std::vector<int> NGIndex;                          // indices of samples classified as negative
CvParamGrid nuGrid;                                // extra grids passed to CvSVM::train_auto
CvParamGrid coeffGrid;
CvParamGrid degreeGrid;
void Label(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample)
{
std::cout<<"labeling..."<<std::endl;
AllSamples.clear();
int PSnum=PSample.size();
int NGnum=NSample.size();
int Num=0;
Num=PSnum+NGnum;
if(PSnum>0&&NGnum>0)
{
for(int i=0;i<Num;++i)
{
if(i<PSnum)
{
AllSamples.push_back(PSample[i]);
All_labels.push_back(1);
}
else
{
AllSamples.push_back(NSample[i-PSnum]);
All_labels.push_back(-1);
}
}
normalizer.train(AllSamples);
for(unsigned long i=0;i<AllSamples.size();++i)
{
AllSamples[i]=normalizer(AllSamples[i]);
std::cout<<AllSamples[i](0)<<" "<<AllSamples[i](1)<<" "<<AllSamples[i](2)<<" "<<AllSamples[i](3)<<std::endl;
}
dlib::randomize_samples(AllSamples,All_labels);//可要可不要;
}
else
std::cout<<"训练样本无效!"<<std::endl;
}
void Train()
{
if(AllSamples.size()>0)
{
std::cout<<"Doing cross calidation"<<std::endl;
for(double gamma=0.00001;gamma<=1;gamma*=5)
{
for(double C=1;C<100000;C*=5)
{
trainer.set_kernel(kernel_type(gamma));
trainer.set_c(C);
std::cout<<"gamma: "<<gamma<<" C: "<<C;
std::cout<<" cross validation accuracy: "<<dlib::cross_validate_trainer(trainer,AllSamples,All_labels,10);
}
}
learned_function.normalizer=normalizer;
learned_function.function=trainer.train(AllSamples,All_labels);
std::cout<<"\nnumber of support vector in our learned_functions are "<<learned_function.function.basis_vectors.size()<<std::endl;
dlib::serialize("saved_function.dat")<<learned_function;
}
else
std::cout<<"无法训练!"<<std::endl;
}
/// Configures the global dlib trainer with a fixed parameter pair:
/// RBF gamma = 0.00625 and C = 5. Does not train or evaluate anything.
void Test()
{
    const double fixed_gamma=0.00625;
    const double fixed_c=5;
    trainer.set_kernel(kernel_type(fixed_gamma));
    trainer.set_c(fixed_c);
}
/// Loads the serialized dlib model from "saved_function.dat" into the global
/// learned_function. It then prints the decision value of every sample in CSample.
/// A positive value corresponds to the +1 class used in Label().
/// dlib::deserialize throws if the file cannot be read; this is not caught here.
void Classify()
{
    dlib::deserialize("saved_function.dat")>>learned_function;
    if(CSample.empty())
    {
        std::cout<<"无测试样本。"<<std::endl;
    }
    else
    {
        for(std::size_t i=0;i<CSample.size();++i)
            std::cout<<"分类结果: "<<learned_function(CSample[i])<<std::endl;
    }
    std::cout<<"分类结束!"<<std::endl;
}
void ParamsSelection(cv::vector<sample_type> &Features_All_p,cv::vector<sample_type> &Features_All_n)
{
std::cout<<"开始选择参数..."<<std::endl;
//double gamma=1.0/(dlib::compute_mean_squared_distance(dlib::randomly_subsample(All_samples,20)));
const double gamma=dlib::verbose_find_gamma_with_big_centroid_gap(AllSamples,All_labels);
dlib::kcentroid<kernel_type> kc(kernel_type(gamma),0.001,40);//最后一个参数可以调整;
std::cout<<dlib::rank_features(kc,AllSamples,All_labels)<<std::endl;
}
/// Trains the global OpenCV CvSVM on trainingdatas / Classes using the fixed
/// parameters set by SVM_Params_ori(), then saves the model to "svm_ori.xml".
/// The model is only reported as trained and saved when CvSVM::train succeeds.
void Train_opencv_ori()
{
    SVM_Params_ori();
    std::cout<<"开始训练"<<std::endl;
    const bool trained=svm.train(trainingdatas,Classes,cv::Mat(),cv::Mat(),SVM_params);
    if(!trained)
    {
        std::cout<<"SVM分类器训练失败。"<<std::endl;
        return;
    }
    std::cout<<"SVM分类器训练完毕。"<<std::endl;
    svm.save("svm_ori.xml");
    std::cout<<"模型保存完毕。"<<std::endl;
}
void Train_opencv_opt()
{
SVM_Params_opt();
std::cout<<"开始训练"<<std::endl;
svm.train_auto(trainingdatas,Classes,cv::Mat(),cv::Mat(),SVM_params,10,svm.get_default_grid(CvSVM::C),svm.get_default_grid(CvSVM::GAMMA),svm.get_default_grid(CvSVM::P),nuGrid,coeffGrid,degreeGrid);
CvSVMParams SVM_params_return=svm.get_params();
std::cout<<"SVM分类器训练完毕。"<<std::endl;
svm.save("svm_opt_15_20170925_16features.xml");
std::cout<<"模型保存完毕。"<<std::endl;
}
void SVM_Params_opt()
{
SVM_params.svm_type=CvSVM::C_SVC;
SVM_params.kernel_type=CvSVM::RBF;
SVM_params.C=1;
SVM_params.gamma=0.0001;
SVM_params.term_crit=cvTermCriteria(CV_TERMCRIT_ITER,15000,0.001);
CvParamGrid nuGrid=CvParamGrid(1,1,0.0);
CvParamGrid coeffGrid=CvParamGrid(1,1,0.0);
CvParamGrid degreeGrid=CvParamGrid(1,1,0.0);
}
void SVM_Params_ori()//设置SVM参数;
{
SVM_params.svm_type=CvSVM::C_SVC;
SVM_params.kernel_type=CvSVM::RBF;
SVM_params.degree=0;
SVM_params.gamma=1;
SVM_params.coef0=0;
SVM_params.C=1;
SVM_params.nu=0;
SVM_params.p=0;
SVM_params.term_crit=cvTermCriteria(CV_TERMCRIT_ITER,100000,0.01);
}
void Classify_opencv(int mode)
{
string modelpath;
if(mode==1)
modelpath="svm_ori.xml";
if(mode==2)
modelpath="svm_opt_15_20170925_16features.xml";
FileStorage svm_fs(modelpath,FileStorage::READ);
if(svm_fs.isOpened())
{
respones=0;
svm.load(modelpath.c_str());
std::cout<<std::endl;
PSIndex.clear();
NGIndex.clear();
std::cout<<"开始分类"<<std::endl;
//std::cout<<PredictingDatas<<std::endl;
for(int i=0;i<PredictingDatas.rows;i++)
{
Mat classMat=PredictingDatas.rowRange(i,i+1);
classMat=classMat.reshape(1,1);
respones=(int)svm.predict(classMat);
if(respones==1)
{
PrePoNum++;
PSIndex.push_back(i);
}
else
{
PreNgNum++;
NGIndex.push_back(i);
}
}
//std::cout<<"正样本路径: "<<std::endl;
PrintFileName(PSIndex,1);
// std::cout<<"负样本路径: "<<std::endl;
PrintFileName(NGIndex,2);
std::cout<<"分类结束!正样本数: "<<PrePoNum<<" 负样本数: "<<PreNgNum<<std::endl;
}
}
代码数据结构参考:
#ifndef DATASTRUCT_H
#define DATASTRUCT_H
#include<iostream>
#include<dlib/svm.h>
#include<vector>
#include<dlib/rand.h>
#include<opencv2\highgui\highgui.hpp>
#include<opencv2\imgproc\imgproc.hpp>
#include<opencv2\core\core.hpp>
#include<opencv2\ml\ml.hpp>
#include<opencv\cv.hpp>
#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<fstream>
#include<io.h>
#include<cassert>
#include<iterator>
#include<functional>
#include<algorithm>
#include<opencv2/opencv.hpp>
// Number of features per sample (presumably matches the 4 rows of sample_type below).
#define FEATURESNUM 4
// 2-D integer table stored as nested std::vector (used for Vec_hor / Vec_ver / Vec_45 / Vec_135).
typedef std::vector< std::vector<int> > Vec2D;
/// Four GLCM texture features for one direction; all start at 0.0.
/// (idmoment is presumably the inverse difference moment.)
struct _GLMCFeatures
{
    _GLMCFeatures()
        : energy(0.0)
        , entropy(0.0)
        , contrast(0.0)
        , idmoment(0.0)
    {
    }
    double energy;
    double entropy;
    double contrast;
    double idmoment;
};
typedef _GLMCFeatures GLCMFeatures;
/// Per-feature training-set statistics for one GLCM direction: mean_train_* and
/// sigma_train_* (presumably mean and standard deviation). All start at 0.0.
struct _StandValue
{
    _StandValue()
        : mean_train_energy(0.0)
        , mean_train_entropy(0.0)
        , mean_train_contrast(0.0)
        , mean_train_idmoment(0.0)
        , sigma_train_energy(0.0)
        , sigma_train_entropy(0.0)
        , sigma_train_contrast(0.0)
        , sigma_train_idmoment(0.0)
    {
    }
    double mean_train_energy;
    double mean_train_entropy;
    double mean_train_contrast;
    double mean_train_idmoment;
    double sigma_train_energy;
    double sigma_train_entropy;
    double sigma_train_contrast;
    double sigma_train_idmoment;
};
typedef _StandValue StandValue;
/// Accumulators for GLCM feature statistics in four directions
/// (hor, ver, 45, 135) and four features (energy, entropy, contrast, idmoment).
/// Each combination has mean_*, sum_*, pow_* and spow_* fields. These are
/// presumably the mean, running sum, squared term and sum of squares used for
/// standardization; verify against the code that fills them.
/// Every field starts at 0.0.
typedef struct _NormaData
{
    // Initializers are listed in member declaration order (avoids -Wreorder warnings).
    _NormaData():
        mean_energy_hor(0.0),mean_entropy_hor(0.0),mean_contrast_hor(0.0),mean_idmoment_hor(0.0),
        sum_entropy_hor(0.0),sum_energy_hor(0.0),sum_contrast_hor(0.0),sum_idmoment_hor(0.0),
        pow_entropy_hor(0.0),pow_energy_hor(0.0),pow_contrast_hor(0.0),pow_idmoment_hor(0.0),
        spow_entropy_hor(0.0),spow_energy_hor(0.0),spow_contrast_hor(0.0),spow_idmoment_hor(0.0),
        mean_energy_ver(0.0),mean_entropy_ver(0.0),mean_contrast_ver(0.0),mean_idmoment_ver(0.0),
        sum_entropy_ver(0.0),sum_energy_ver(0.0),sum_contrast_ver(0.0),sum_idmoment_ver(0.0),
        pow_entropy_ver(0.0),pow_energy_ver(0.0),pow_contrast_ver(0.0),pow_idmoment_ver(0.0),
        spow_entropy_ver(0.0),spow_energy_ver(0.0),spow_contrast_ver(0.0),spow_idmoment_ver(0.0),
        mean_energy_45(0.0),mean_entropy_45(0.0),mean_contrast_45(0.0),mean_idmoment_45(0.0),
        sum_entropy_45(0.0),sum_energy_45(0.0),sum_contrast_45(0.0),sum_idmoment_45(0.0),
        pow_entropy_45(0.0),pow_energy_45(0.0),pow_contrast_45(0.0),pow_idmoment_45(0.0),
        spow_entropy_45(0.0),spow_energy_45(0.0),spow_contrast_45(0.0),spow_idmoment_45(0.0),
        mean_energy_135(0.0),mean_entropy_135(0.0),mean_contrast_135(0.0),mean_idmoment_135(0.0),
        sum_entropy_135(0.0),sum_energy_135(0.0),sum_contrast_135(0.0),sum_idmoment_135(0.0),
        pow_entropy_135(0.0),pow_energy_135(0.0),pow_contrast_135(0.0),pow_idmoment_135(0.0),
        spow_entropy_135(0.0),spow_energy_135(0.0),spow_contrast_135(0.0),spow_idmoment_135(0.0)
    {}
    // Horizontal direction.
    double mean_energy_hor;
    double mean_entropy_hor;
    double mean_contrast_hor;
    double mean_idmoment_hor;
    double sum_entropy_hor;
    double sum_energy_hor;
    double sum_contrast_hor;
    double sum_idmoment_hor;
    double pow_entropy_hor;
    double pow_energy_hor;
    double pow_contrast_hor;
    double pow_idmoment_hor;
    double spow_entropy_hor;
    double spow_energy_hor;
    double spow_contrast_hor;
    double spow_idmoment_hor;
    // Vertical direction.
    double mean_energy_ver;
    double mean_entropy_ver;
    double mean_contrast_ver;
    double mean_idmoment_ver;
    double sum_entropy_ver;
    double sum_energy_ver;
    double sum_contrast_ver;
    double sum_idmoment_ver;
    double pow_entropy_ver;
    double pow_energy_ver;
    double pow_contrast_ver;
    double pow_idmoment_ver;
    double spow_entropy_ver;
    double spow_energy_ver;
    double spow_contrast_ver;
    double spow_idmoment_ver;
    // 45-degree direction.
    double mean_energy_45;
    double mean_entropy_45;
    double mean_contrast_45;
    double mean_idmoment_45;
    double sum_entropy_45;
    double sum_energy_45;
    double sum_contrast_45;
    double sum_idmoment_45;
    double pow_entropy_45;
    double pow_energy_45;
    double pow_contrast_45;
    double pow_idmoment_45;
    double spow_entropy_45;
    double spow_energy_45;
    double spow_contrast_45;
    double spow_idmoment_45;
    // 135-degree direction.
    double mean_energy_135;
    double mean_entropy_135;
    double mean_contrast_135;
    double mean_idmoment_135;
    double sum_entropy_135;
    double sum_energy_135;
    double sum_contrast_135;
    double sum_idmoment_135;
    double pow_entropy_135;
    double pow_energy_135;
    double pow_contrast_135;
    double pow_idmoment_135;
    double spow_entropy_135;
    double spow_energy_135;
    double spow_contrast_135;
    double spow_idmoment_135;
}NormaData;
//dlib库的相关变量定义;
// dlib type aliases: 4x1 column-vector samples, an RBF kernel over them, the
// trained decision function, and that function wrapped with its input normalizer.
// (kernel_type was previously typedef'd twice; one definition is enough.)
typedef dlib::matrix<double,4,1>sample_type;
typedef dlib::radial_basis_kernel<sample_type>kernel_type;
typedef dlib::decision_function<kernel_type>dec_funct_type;
typedef dlib::normalized_function<dec_funct_type>funct_type;
//灰度共生矩阵相关定义;
// ---- GLCM data per direction (hor / ver / 45 / 135) ----
extern Vec2D Vec_hor;
extern Vec2D Vec_ver;
extern Vec2D Vec_45;
extern Vec2D Vec_135;
extern StandValue standValue_hor;
extern StandValue standValue_ver;
extern StandValue standValue_45;
extern StandValue standValue_135;
extern GLCMFeatures features_hor;
extern GLCMFeatures features_ver;
extern GLCMFeatures features_45;
extern GLCMFeatures features_135;
// ---- dlib classifier state ----
extern dlib::svm_c_trainer<kernel_type> trainer;
extern std::vector<sample_type> PSample;        // presumably positive samples
extern std::vector<sample_type> NSample;        // presumably negative samples
extern std::vector<sample_type> CSample;        // samples to classify
extern std::vector<sample_type> AllSamples;
extern std::vector<double> All_labels;
extern funct_type learned_function;
extern dlib::rand rnd;
// ---- input / output paths ----
extern std::string RootFileName_P;
extern std::string RootFileName_N;
extern std::string RootFileName_C;
extern std::string RootSavePath_PS;
extern std::string RootSavePath_NG;
extern std::string RootPath_glcm;
extern std::string RootProjectPath;
//extern std::string SavePath_glcm_n;
// ---- file lists ----
extern std::vector<std::string> Vec_ImageFiles_p;
extern std::vector<std::string> Vec_ImageFiles_n;
extern std::vector<std::string> Vec_ImageFiles_c;
extern std::vector<std::string> Vec_RoiFiles_p;
extern std::vector<std::string> Vec_RoiFiles_n;
extern std::vector<std::string> Vec_RoiFiles_c;
extern std::vector<std::string> Vec_RoiFiles_A;
// ---- images and regions of interest ----
extern std::vector<cv::Mat> BMPImages_p;
extern std::vector<cv::Mat> BMPImages_n;
extern std::vector<cv::Mat> BMPImages_c;
extern std::vector<cv::Mat> BMPclass_p;
extern std::vector<cv::Mat> BMPclass_n;
extern std::vector<cv::Mat> ROI_p;
extern std::vector<cv::Mat> ROI_n;
extern std::vector<cv::Mat> ROI_c;
extern std::string tempSave;
// ---- OpenCV training / prediction data ----
extern cv::Mat Classes;
extern std::vector<int> trainLabels;
extern cv::Mat trainAlldatas;
extern cv::Mat trainingdatas;
extern cv::Mat PreDatas;
extern cv::Mat PredictingDatas;
// ---- OpenCV CvSVM state ----
extern CvSVMParams SVM_params;
extern CvSVM svm;
extern int respones;
extern int PrePoNum;
extern int PreNgNum;
extern std::vector<int> PSIndex;   // indices of samples classified as positive
extern std::vector<int> NGIndex;   // indices of samples classified as negative
extern int savenum;
#endif