SVM分类器C++语言实现

以下为部分代码,仅供参考。文中很多变量的类型是作者自定义的数据结构,其定义见文末的数据结构参考部分。

 

头文件(Svm_c.h):

#ifndef SVM_C_H
#define SVM_C_H
#include"Process.h"

// Public interface of the SVM classifier module. Two back ends are offered:
// dlib (Label/Train/Test/Classify/ParamsSelection) and OpenCV CvSVM
// (Train_opencv_*/Classify_opencv/SVM_Params_*). The functions share state
// through global variables declared elsewhere in the project.

// Builds the labelled training set: positive samples get label +1,
// negative samples get label -1.
extern void Label(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample);
// dlib: cross-validation grid search over gamma and C, then trains the model
// and serializes it to "saved_function.dat".
extern void Train();
// dlib: sets a fixed gamma and C on the trainer (does not train).
extern void Test();
// dlib: loads "saved_function.dat" and prints a decision value per test sample.
extern void Classify();
// dlib: estimates an RBF gamma and ranks the features of the labelled training set.
extern void ParamsSelection(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample);
// OpenCV: trains with the fixed parameters from SVM_Params_ori() and saves the model.
extern void Train_opencv_ori();
// OpenCV: trains with CvSVM::train_auto (parameter search) and saves the model.
extern void Train_opencv_opt();
// OpenCV: loads a saved model (mode 1 = ori, mode 2 = opt) and classifies PredictingDatas.
extern void Classify_opencv(int mode);
// Fills the global SVM_params for Train_opencv_ori().
void SVM_Params_ori();
// Fills the global SVM_params for Train_opencv_opt().
void SVM_Params_opt();
#endif // SVM_C_H

cpp文件:

#include "StdAfx.h"
#include"Svm_c.h"

// ---- dlib-based classifier state ----
// C-SVM trainer; its kernel (gamma) and C are set in Train() and Test().
dlib::svm_c_trainer<kernel_type>trainer;
// Training samples assembled by Label(), stored after normalization.
std::vector<sample_type>AllSamples;

// Labels parallel to AllSamples: +1 for positive, -1 for negative samples.
std::vector<double>All_labels;
// Trained decision function plus the normalizer it applies to its input.
funct_type learned_function;
// Normalizer fitted on the training samples in Label().
dlib::vector_normalizer<sample_type>normalizer;
// Random generator (not used in this file).
dlib::rand rnd;

// ---- OpenCV-based classifier state ----
// NOTE(review): Classes is used below but not defined here; presumably it is
// defined in another translation unit (it is declared extern in the data header).
//cv::Mat Classes;
CvSVMParams SVM_params;
CvSVM svm;
int respones;  // last value returned by svm.predict() in Classify_opencv()
int PrePoNum=0;  // number of samples predicted as positive
int PreNgNum=0;  // number of samples predicted as negative
std::vector<int>PSIndex;// indices of samples classified as positive
std::vector<int>NGIndex;// indices of samples classified as negative
// Search grids passed to CvSVM::train_auto() in Train_opencv_opt().
CvParamGrid nuGrid;
CvParamGrid coeffGrid;
CvParamGrid degreeGrid;

void Label(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample)
{
    std::cout<<"labeling..."<<std::endl;
    AllSamples.clear();
    int PSnum=PSample.size();
    int NGnum=NSample.size();
    int Num=0;
    Num=PSnum+NGnum;
    if(PSnum>0&&NGnum>0)
    {
        for(int i=0;i<Num;++i)
        {
            if(i<PSnum)
            {
                AllSamples.push_back(PSample[i]);
                All_labels.push_back(1);
            }
            else
            {
                AllSamples.push_back(NSample[i-PSnum]);
                All_labels.push_back(-1);
            }
        }
        normalizer.train(AllSamples);
        for(unsigned long i=0;i<AllSamples.size();++i)
        {
            AllSamples[i]=normalizer(AllSamples[i]);
            std::cout<<AllSamples[i](0)<<" "<<AllSamples[i](1)<<" "<<AllSamples[i](2)<<" "<<AllSamples[i](3)<<std::endl;
        }
        dlib::randomize_samples(AllSamples,All_labels);//可要可不要;
    }
    else
        std::cout<<"训练样本无效!"<<std::endl;
}
void Train()
{
    if(AllSamples.size()>0)
    {
        std::cout<<"Doing cross calidation"<<std::endl;
        for(double gamma=0.00001;gamma<=1;gamma*=5)
        {
            for(double C=1;C<100000;C*=5)
            {
                trainer.set_kernel(kernel_type(gamma));
                trainer.set_c(C);
                std::cout<<"gamma: "<<gamma<<" C: "<<C;
                std::cout<<"  cross validation accuracy: "<<dlib::cross_validate_trainer(trainer,AllSamples,All_labels,10);
            }
        }
        learned_function.normalizer=normalizer;
        learned_function.function=trainer.train(AllSamples,All_labels);
        std::cout<<"\nnumber of support vector in our learned_functions are "<<learned_function.function.basis_vectors.size()<<std::endl;
        dlib::serialize("saved_function.dat")<<learned_function;
    }
    else
        std::cout<<"无法训练!"<<std::endl;
}
void Test()
{
    trainer.set_kernel(kernel_type(0.00625));
    trainer.set_c(5);
    
}
void Classify()
{
    dlib::deserialize("saved_function.dat")>>learned_function;
    if(CSample.size()>0)
    {
        for(int i=0;i<CSample.size();i++)
            std::cout<<"分类结果: "<<learned_function(CSample[i])<<std::endl;
    }
    else
    {
        std::cout<<"无测试样本。"<<std::endl;
    }
    std::cout<<"分类结束!"<<endl;
}
void ParamsSelection(cv::vector<sample_type> &Features_All_p,cv::vector<sample_type> &Features_All_n)
{
    std::cout<<"开始选择参数..."<<std::endl;
    //double gamma=1.0/(dlib::compute_mean_squared_distance(dlib::randomly_subsample(All_samples,20)));
    const double gamma=dlib::verbose_find_gamma_with_big_centroid_gap(AllSamples,All_labels);
    dlib::kcentroid<kernel_type> kc(kernel_type(gamma),0.001,40);//最后一个参数可以调整;
    std::cout<<dlib::rank_features(kc,AllSamples,All_labels)<<std::endl;
}
void Train_opencv_ori()
{
    SVM_Params_ori();
    std::cout<<"开始训练"<<std::endl;
    svm.train(trainingdatas,Classes,cv::Mat(),cv::Mat(),SVM_params);
    std::cout<<"SVM分类器训练完毕。"<<std::endl;
    svm.save("svm_ori.xml");
    std::cout<<"模型保存完毕。"<<std::endl;
}
void Train_opencv_opt()
{
    SVM_Params_opt();
    std::cout<<"开始训练"<<std::endl;
    svm.train_auto(trainingdatas,Classes,cv::Mat(),cv::Mat(),SVM_params,10,svm.get_default_grid(CvSVM::C),svm.get_default_grid(CvSVM::GAMMA),svm.get_default_grid(CvSVM::P),nuGrid,coeffGrid,degreeGrid);
    CvSVMParams SVM_params_return=svm.get_params();
    std::cout<<"SVM分类器训练完毕。"<<std::endl;
    svm.save("svm_opt_15_20170925_16features.xml");
    std::cout<<"模型保存完毕。"<<std::endl;
}
void SVM_Params_opt()
{
    SVM_params.svm_type=CvSVM::C_SVC;
    SVM_params.kernel_type=CvSVM::RBF;
    SVM_params.C=1;
    SVM_params.gamma=0.0001;
    SVM_params.term_crit=cvTermCriteria(CV_TERMCRIT_ITER,15000,0.001);
    CvParamGrid nuGrid=CvParamGrid(1,1,0.0);
    CvParamGrid coeffGrid=CvParamGrid(1,1,0.0);
    CvParamGrid degreeGrid=CvParamGrid(1,1,0.0);
}
void SVM_Params_ori()//设置SVM参数;
{
    SVM_params.svm_type=CvSVM::C_SVC;
    SVM_params.kernel_type=CvSVM::RBF;
    SVM_params.degree=0;
    SVM_params.gamma=1;
    SVM_params.coef0=0;
    SVM_params.C=1;
    SVM_params.nu=0;
    SVM_params.p=0;
    SVM_params.term_crit=cvTermCriteria(CV_TERMCRIT_ITER,100000,0.01);
}
void Classify_opencv(int mode)
{
    string modelpath;
    if(mode==1)
        modelpath="svm_ori.xml";
    if(mode==2)
        modelpath="svm_opt_15_20170925_16features.xml";
    FileStorage svm_fs(modelpath,FileStorage::READ);
    if(svm_fs.isOpened())
    {
        respones=0;
        svm.load(modelpath.c_str());
        std::cout<<std::endl;
        PSIndex.clear();
        NGIndex.clear();
        std::cout<<"开始分类"<<std::endl;
        //std::cout<<PredictingDatas<<std::endl;
        for(int i=0;i<PredictingDatas.rows;i++)
        {
            Mat classMat=PredictingDatas.rowRange(i,i+1);
            classMat=classMat.reshape(1,1);
            respones=(int)svm.predict(classMat);
            
            if(respones==1)
            {
                PrePoNum++;
                PSIndex.push_back(i);
                
            }
            else
            {
                PreNgNum++;
                NGIndex.push_back(i);
            }
        }
        //std::cout<<"正样本路径: "<<std::endl;
        PrintFileName(PSIndex,1);
    //    std::cout<<"负样本路径: "<<std::endl;
        PrintFileName(NGIndex,2);
        std::cout<<"分类结束!正样本数: "<<PrePoNum<<" 负样本数: "<<PreNgNum<<std::endl;
        
    }
}
代码中用到的数据结构定义(供参考):

#ifndef DATASTRUCT_H
#define DATASTRUCT_H

#include<iostream>
#include<dlib/svm.h>
#include<vector>
#include<dlib/rand.h>
#include<opencv2\highgui\highgui.hpp>
#include<opencv2\imgproc\imgproc.hpp>
#include<opencv2\core\core.hpp>
#include<opencv2\ml\ml.hpp>
#include<opencv\cv.hpp>
#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<fstream>
#include<io.h>
#include<cassert>
#include<iterator>
#include<functional>
#include<algorithm>
#include<opencv2/opencv.hpp>

// Number of features; presumably the per-sample feature count (sample_type is 4x1).
#define FEATURESNUM 4

// Two-dimensional grid of ints, stored as a vector of rows.
typedef std::vector<std::vector<int> > Vec2D;

// Texture features derived from a gray-level co-occurrence matrix (GLCM).
// Every feature is zero after default construction.
struct _GLMCFeatures
{
    double energy;
    double entropy;
    double contrast;
    double idmoment;  // presumably the inverse difference moment — confirm

    _GLMCFeatures()
    {
        energy=0.0;
        entropy=0.0;
        contrast=0.0;
        idmoment=0.0;
    }
};
typedef _GLMCFeatures GLCMFeatures;

// Per-feature statistics of the training data: mean_train_* and sigma_train_*
// (presumably the mean and standard deviation) for each GLCM feature.
// All values are zero after default construction.
struct _StandValue
{
    double mean_train_energy;
    double mean_train_entropy;
    double mean_train_contrast;
    double mean_train_idmoment;
    double sigma_train_energy;
    double sigma_train_entropy;
    double sigma_train_contrast;
    double sigma_train_idmoment;

    _StandValue()
    {
        mean_train_energy=mean_train_entropy=mean_train_contrast=mean_train_idmoment=0.0;
        sigma_train_energy=sigma_train_entropy=sigma_train_contrast=sigma_train_idmoment=0.0;
    }
};
typedef _StandValue StandValue;

// Normalization accumulators for the GLCM features.
// For each of the four direction suffixes (_hor, _ver, _45, _135 — presumably
// the co-occurrence directions) and each feature (energy, entropy, contrast,
// idmoment) it holds mean_*, sum_*, pow_* and spow_* values. All are zero
// after default construction.
// NOTE(review): pow_*/spow_* presumably hold squared terms / sums of squares
// for a variance computation — confirm against the code that fills them.
typedef struct _NormaData
{
    // The initializer list follows the member declaration order below. C++
    // always initializes in declaration order; matching it avoids -Wreorder
    // warnings and keeps the list truthful.
    _NormaData():
        mean_energy_hor(0.0),mean_entropy_hor(0.0),mean_contrast_hor(0.0),mean_idmoment_hor(0.0),
        sum_entropy_hor(0.0),sum_energy_hor(0.0),sum_contrast_hor(0.0),sum_idmoment_hor(0.0),
        pow_entropy_hor(0.0),pow_energy_hor(0.0),pow_contrast_hor(0.0),pow_idmoment_hor(0.0),
        spow_entropy_hor(0.0),spow_energy_hor(0.0),spow_contrast_hor(0.0),spow_idmoment_hor(0.0),
        mean_energy_ver(0.0),mean_entropy_ver(0.0),mean_contrast_ver(0.0),mean_idmoment_ver(0.0),
        sum_entropy_ver(0.0),sum_energy_ver(0.0),sum_contrast_ver(0.0),sum_idmoment_ver(0.0),
        pow_entropy_ver(0.0),pow_energy_ver(0.0),pow_contrast_ver(0.0),pow_idmoment_ver(0.0),
        spow_entropy_ver(0.0),spow_energy_ver(0.0),spow_contrast_ver(0.0),spow_idmoment_ver(0.0),
        mean_energy_45(0.0),mean_entropy_45(0.0),mean_contrast_45(0.0),mean_idmoment_45(0.0),
        sum_entropy_45(0.0),sum_energy_45(0.0),sum_contrast_45(0.0),sum_idmoment_45(0.0),
        pow_entropy_45(0.0),pow_energy_45(0.0),pow_contrast_45(0.0),pow_idmoment_45(0.0),
        spow_entropy_45(0.0),spow_energy_45(0.0),spow_contrast_45(0.0),spow_idmoment_45(0.0),
        mean_energy_135(0.0),mean_entropy_135(0.0),mean_contrast_135(0.0),mean_idmoment_135(0.0),
        sum_entropy_135(0.0),sum_energy_135(0.0),sum_contrast_135(0.0),sum_idmoment_135(0.0),
        pow_entropy_135(0.0),pow_energy_135(0.0),pow_contrast_135(0.0),pow_idmoment_135(0.0),
        spow_entropy_135(0.0),spow_energy_135(0.0),spow_contrast_135(0.0),spow_idmoment_135(0.0)
    {}
    // _hor
    double mean_energy_hor;
    double mean_entropy_hor;
    double mean_contrast_hor;
    double mean_idmoment_hor;
    double sum_entropy_hor;
    double sum_energy_hor;
    double sum_contrast_hor;
    double sum_idmoment_hor;
    double pow_entropy_hor;
    double pow_energy_hor;
    double pow_contrast_hor;
    double pow_idmoment_hor;
    double spow_entropy_hor;
    double spow_energy_hor;
    double spow_contrast_hor;
    double spow_idmoment_hor;

    // _ver
    double mean_energy_ver;
    double mean_entropy_ver;
    double mean_contrast_ver;
    double mean_idmoment_ver;
    double sum_entropy_ver;
    double sum_energy_ver;
    double sum_contrast_ver;
    double sum_idmoment_ver;
    double pow_entropy_ver;
    double pow_energy_ver;
    double pow_contrast_ver;
    double pow_idmoment_ver;
    double spow_entropy_ver;
    double spow_energy_ver;
    double spow_contrast_ver;
    double spow_idmoment_ver;

    // _45
    double mean_energy_45;
    double mean_entropy_45;
    double mean_contrast_45;
    double mean_idmoment_45;
    double sum_entropy_45;
    double sum_energy_45;
    double sum_contrast_45;
    double sum_idmoment_45;
    double pow_entropy_45;
    double pow_energy_45;
    double pow_contrast_45;
    double pow_idmoment_45;
    double spow_entropy_45;
    double spow_energy_45;
    double spow_contrast_45;
    double spow_idmoment_45;

    // _135
    double mean_energy_135;
    double mean_entropy_135;
    double mean_contrast_135;
    double mean_idmoment_135;
    double sum_entropy_135;
    double sum_energy_135;
    double sum_contrast_135;
    double sum_idmoment_135;
    double pow_entropy_135;
    double pow_energy_135;
    double pow_contrast_135;
    double pow_idmoment_135;
    double spow_entropy_135;
    double spow_energy_135;
    double spow_contrast_135;
    double spow_idmoment_135;
}NormaData;

// dlib-related type definitions.
// One sample: a column vector of 4 doubles (one value per feature).
typedef dlib::matrix<double,4,1>sample_type;
// RBF (Gaussian) kernel over samples. (A duplicate of this typedef was removed.)
typedef dlib::radial_basis_kernel<sample_type>kernel_type;
// Trained SVM decision function for that kernel.
typedef dlib::decision_function<kernel_type>dec_funct_type;
// Decision function that normalizes its input first.
typedef dlib::normalized_function<dec_funct_type>funct_type;
// ---- Shared global state (declarations only; defined in .cpp files) ----
// GLCM (gray-level co-occurrence matrix) related definitions.
// Suffixes _hor/_ver/_45/_135: presumably the four co-occurrence directions.
extern Vec2D Vec_hor;
extern Vec2D Vec_ver;
extern Vec2D Vec_45;
extern Vec2D Vec_135;
extern StandValue standValue_hor;
extern StandValue standValue_ver;
extern StandValue standValue_45;
extern StandValue standValue_135;
extern GLCMFeatures features_hor;
extern GLCMFeatures features_ver;
extern GLCMFeatures features_45;
extern GLCMFeatures features_135;

// dlib classifier state: trainer, samples per set, labelled training set and model.
// Suffixes P/N/C presumably mean positive / negative / to-be-classified.
extern dlib::svm_c_trainer<kernel_type>trainer;
extern std::vector<sample_type>PSample;
extern std::vector<sample_type>NSample;
extern std::vector<sample_type>CSample;
extern std::vector<sample_type>AllSamples;
extern std::vector<double>All_labels;
extern funct_type learned_function;
extern dlib::rand rnd;

// File-system roots, file lists and loaded images.
// NOTE(review): the _p/_n/_c suffixes presumably follow the same
// positive/negative/to-classify convention — confirm.
extern std::string RootFileName_P;
extern std::string RootFileName_N;
extern std::string RootFileName_C;
extern std::string RootSavePath_PS;
extern std::string RootSavePath_NG;
extern std::string RootPath_glcm;
extern std::string RootProjectPath;
//extern std::string SavePath_glcm_n;
extern std::vector<std::string>Vec_ImageFiles_p;
extern std::vector<std::string>Vec_ImageFiles_n;
extern std::vector<std::string>Vec_ImageFiles_c;
extern std::vector<std::string>Vec_RoiFiles_p;
extern std::vector<std::string>Vec_RoiFiles_n;
extern std::vector<std::string>Vec_RoiFiles_c;
extern std::vector<std::string>Vec_RoiFiles_A;
extern std::vector<cv::Mat>BMPImages_p;
extern std::vector<cv::Mat>BMPImages_n;
extern std::vector<cv::Mat>BMPImages_c;
extern std::vector<cv::Mat>BMPclass_p;
extern std::vector<cv::Mat>BMPclass_n;
extern std::vector<cv::Mat>ROI_p;
extern std::vector<cv::Mat>ROI_n;
extern std::vector<cv::Mat>ROI_c;
extern std::string tempSave;

// OpenCV classifier state: training/prediction matrices, the CvSVM model,
// its parameters and the classification results.
extern cv::Mat Classes;
extern std::vector<int>trainLabels;
extern cv::Mat trainAlldatas;
extern cv::Mat trainingdatas;
extern cv::Mat PreDatas;
extern cv::Mat PredictingDatas;
extern CvSVMParams SVM_params;
extern CvSVM svm;
extern int respones;
extern int PrePoNum;
extern int PreNgNum;
extern std::vector<int>PSIndex;// indices of samples classified as positive
extern std::vector<int>NGIndex;// indices of samples classified as negative
extern int savenum;

#endif
 

上一篇:extern "C"与C++中的C函数调用(3)—— 如何在C++中调用C函数


下一篇:extern存储类