kNN(K-Nearest Neighbor)最邻近规则分类

KNN最邻近规则,主要应用领域是对未知事物的识别,即推断未知事物属于哪一类,推断思想是,基于欧几里得定理,推断未知事物的特征和哪一类已知事物的的特征最接近;

K近期邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比較成熟的方法,也是最简单的机器学习算法之中的一个。该方法的思路是:假设一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上仅仅根据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 KNN方法尽管从原理上也依赖于极限定理,但在类别决策时,仅仅与极少量的相邻样本有关。因为KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其它方法更为适合。

  KNN算法不仅能够用于分类,还能够用于回归。通过找出一个样本的k个近期邻居,将这些邻居的属性的平均值赋给该样本,就能够得到该样本的属性。更实用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成正比(组合函数)。

  该算法在分类时有个基本的不足是,当样本不平衡时,如一个类的样本容量非常大,而其它类样本容量非常小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。 该算法仅仅计算“近期的”邻居样本,某一类的样本数量非常大,那么或者这类样本并不接近目标样本,或者这类样本非常靠近目标样本。不管如何,数量并不能影响执行结果。能够採用权值的方法(和该样本距离小的邻居权值大)来改进。该方法的还有一个不足之处是计算量较大,由于对每个待分类的文本都要计算它到全体已知样本的距离,才干求得它的K个近期邻点。眼下经常使用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。该算法比較适用于样本容量比較大的类域的自己主动分类,而那些样本容量较小的类域採用这样的算法比較easy产生误分。

K-NN能够说是一种最直接的用来分类未知数据的方法。基本通过以下这张图跟文字说明就能够明确K-NN是干什么的

kNN(K-Nearest Neighbor)最邻近规则分类

简单来说,K-NN能够看成:有那么一堆你已经知道分类的数据,然后当一个新数据进入的时候,就開始跟训练数据里的每一个点求距离,然后挑离这个训练数据近期的K个点看看这几个点属于什么类型,然后用少数服从多数的原则,给新数据归类。

算法步骤:

step.1---初始化距离为最大值

step.2---计算未知样本和每一个训练样本的距离dist

step.3---得到眼下K个最临近样本中的最大距离maxdist

step.4---假设dist小于maxdist,则将该训练样本作为K-近期邻样本

step.5---反复步骤2、3、4,直到未知样本和全部训练样本的距离都算完

step.6---统计K-近期邻样本中每一个类标号出现的次数

step.7---选择出现频率最大的类标号作为未知样本的类标号

KNN的matlab简单实现代码

function target=KNN(in,out,test,k)

% in:       training samples data,n*d matrix

% out: training samples' class label,n*1

% test:     testing data

% target:   class label given by knn

% k:        the number of neighbors

ClassLabel=unique(out);

c=length(ClassLabel);

n=size(in,1);

% target=zeros(size(test,1),1);

dist=zeros(size(in,1),1);

for j=1:size(test,1)

    cnt=zeros(c,1);

    for i=1:n

        dist(i)=norm(in(i,:)-test(j,:));

    end

    [d,index]=sort(dist);

    for i=1:k

        ind=find(ClassLabel==out(index(i)));

        cnt(ind)=cnt(ind)+1;

    end

    [m,ind]=max(cnt);

    target(j)=ClassLabel(ind);

end

R语言的实现代码例如以下

library(class)

data(iris)

names(iris)

m1<-knn.cv(iris[,1:4],iris[,5],k=3,prob=TRUE)

attributes(.Last.value)

library(MASS)

m2<-lda(iris[,1:4],iris[,5])  与判别分析进行比較

b<-data.frame(Sepal.Length=6,Sepal.Width=4,Petal.Length=5,Petal.Width=6)

p1<-predict(m2,b,type="class")

C++ 实现 :

//    KNN.cpp     K-近期邻分类算法

//

////////////////////////////////////////////////////////////////////////////////////////////////////////

#include <stdlib.h>

#include <stdio.h>

#include <memory.h>

#include <string.h>

#include <iostream>

#include <math.h>

#include <fstream>

using namespace std;

////////////////////////////////////////////////////////////////////////////////////////////////////////

//

//    宏定义

//

////////////////////////////////////////////////////////////////////////////////////////////////////////

#define  ATTR_NUM  4                        //属性数目

#define  MAX_SIZE_OF_TRAINING_SET  1000      //训练数据集的最大大小

#define  MAX_SIZE_OF_TEST_SET      100       //測试数据集的最大大小

#define  MAX_VALUE  10000.0                  //属性最大值

#define  K  7

//结构体

struct dataVector {

 int ID;                      //ID号

 char classLabel[15];             //分类标号

 double attributes[ATTR_NUM]; //属性 

};

struct distanceStruct {

 int ID;                      //ID号

 double distance;             //距离

 char classLabel[15];             //分类标号

};

////////////////////////////////////////////////////////////////////////////////////////////////////////

//

//    全局变量

//

////////////////////////////////////////////////////////////////////////////////////////////////////////

struct dataVector gTrainingSet[MAX_SIZE_OF_TRAINING_SET]; //训练数据集

struct dataVector gTestSet[MAX_SIZE_OF_TEST_SET];         //測试数据集

struct distanceStruct gNearestDistance[K];                //K个近期邻距离

int curTrainingSetSize=0;                                 //训练数据集的大小

int curTestSetSize=0;                                     //測试数据集的大小

////////////////////////////////////////////////////////////////////////////////////////////////////////

//

//    求 vector1=(x1,x2,...,xn)和vector2=(y1,y2,...,yn)的欧几里德距离

//

////////////////////////////////////////////////////////////////////////////////////////////////////////

double Distance(struct dataVector vector1,struct dataVector vector2)

{

 double dist,sum=0.0;

 for(int i=0;i<ATTR_NUM;i++)

 {

  sum+=(vector1.attributes[i]-vector2.attributes[i])*(vector1.attributes[i]-vector2.attributes[i]);

 }

 dist=sqrt(sum);

 return dist;

}

////////////////////////////////////////////////////////////////////////////////////////////////////////

//

//    得到gNearestDistance中的最大距离,返回下标

//

////////////////////////////////////////////////////////////////////////////////////////////////////////

int GetMaxDistance()

{

 int maxNo=0;

 for(int i=1;i<K;i++)

 {

  if(gNearestDistance[i].distance>gNearestDistance[maxNo].distance) maxNo = i;

 }

    return maxNo;

}

////////////////////////////////////////////////////////////////////////////////////////////////////////

//

//    对未知样本Sample分类

//

////////////////////////////////////////////////////////////////////////////////////////////////////////

char* Classify(struct dataVector Sample)

{

 double dist=0;

 int maxid=0,freq[K],i,tmpfreq=1;;

 char *curClassLable=gNearestDistance[0].classLabel;

 memset(freq,1,sizeof(freq));

 //step.1---初始化距离为最大值

 for(i=0;i<K;i++)

 {

  gNearestDistance[i].distance=MAX_VALUE;

 }

 //step.2---计算K-近期邻距离

 for(i=0;i<curTrainingSetSize;i++)

 {

  //step.2.1---计算未知样本和每一个训练样本的距离

  dist=Distance(gTrainingSet[i],Sample);

  //step.2.2---得到gNearestDistance中的最大距离

  maxid=GetMaxDistance();

  //step.2.3---假设距离小于gNearestDistance中的最大距离,则将该样本作为K-近期邻样本

  if(dist<gNearestDistance[maxid].distance)

  {

   gNearestDistance[maxid].ID=gTrainingSet[i].ID;

   gNearestDistance[maxid].distance=dist;

   strcpy(gNearestDistance[maxid].classLabel,gTrainingSet[i].classLabel);

  }

 }

 //step.3---统计每一个类出现的次数

 for(i=0;i<K;i++) 

 {

  for(int j=0;j<K;j++)

  {

   if((i!=j)&&(strcmp(gNearestDistance[i].classLabel,gNearestDistance[j].classLabel)==0))

   {

    freq[i]+=1;

   }

  }

 }

 //step.4---选择出现频率最大的类标号

 for(i=0;i<K;i++)

 {

  if(freq[i]>tmpfreq) 

  {

   tmpfreq=freq[i];

    curClassLable=gNearestDistance[i].classLabel;

  }

 }

 return curClassLable;

}

////////////////////////////////////////////////////////////////////////////////////////////////////////

//

//    主函数

//

////////////////////////////////////////////////////////////////////////////////////////////////////////

void main()

{  

 char c;

    char *classLabel="";

 int i,j, rowNo=0,TruePositive=0,FalsePositive=0;

 ifstream filein("iris.data");

 FILE *fp;

 if(filein.fail()){cout<<"Can't open data.txt"<<endl; return;}

 //step.1---读文件 

 while(!filein.eof())

 {

  rowNo++;//第一组数据rowNo=1

  if(curTrainingSetSize>=MAX_SIZE_OF_TRAINING_SET)

  {

   cout<<"The training set has "<<MAX_SIZE_OF_TRAINING_SET<<" examples!"<<endl<<endl;

   break ;

  }  

  //rowNo%3!=0的100组数据作为训练数据集

  if(rowNo%3!=0)

  {   

   gTrainingSet[curTrainingSetSize].ID=rowNo;

   for(int i = 0;i < ATTR_NUM;i++)

   {     

    filein>>gTrainingSet[curTrainingSetSize].attributes[i];

    filein>>c;

   }   

   filein>>gTrainingSet[curTrainingSetSize].classLabel;

   curTrainingSetSize++;

   

  }

  //剩下rowNo%3==0的50组做測试数据集

  else if(rowNo%3==0)

  {

   gTestSet[curTestSetSize].ID=rowNo;

   for(int i = 0;i < ATTR_NUM;i++)

   {    

    filein>>gTestSet[curTestSetSize].attributes[i];

    filein>>c;

   }  

   filein>>gTestSet[curTestSetSize].classLabel;

   curTestSetSize++;

  }

 }

 filein.close();

 //step.2---KNN算法进行分类,并将结果写到文件iris_OutPut.txt

 fp=fopen("iris_OutPut.txt","w+t");

 //用KNN算法进行分类

 fprintf(fp,"************************************程序说明***************************************\n");

 fprintf(fp,"** 採用KNN算法对iris.data分类。为了操作方便,对各组数据加入rowNo属性,第一组rowNo=1!\n");

 fprintf(fp,"** 共同拥有150组数据,选择rowNo模3不等于0的100组作为训练数据集,剩下的50组做測试数据集\n");

 fprintf(fp,"***********************************************************************************\n\n");

 fprintf(fp,"************************************实验结果***************************************\n\n");

 for(i=0;i<curTestSetSize;i++)

 {

        fprintf(fp,"************************************第%d组数据**************************************\n",i+1);

  classLabel =Classify(gTestSet[i]);

     if(strcmp(classLabel,gTestSet[i].classLabel)==0)//相等时,分类正确

  {

   TruePositive++;

  }

  cout<<"rowNo: ";

  cout<<gTestSet[i].ID<<"    \t";

  cout<<"KNN分类结果:      ";

cout<<classLabel<<"(正确类标号: ";

  cout<<gTestSet[i].classLabel<<")\n";

  fprintf(fp,"rowNo:  %3d   \t  KNN分类结果:  %s ( 正确类标号:  %s )\n",gTestSet[i].ID,classLabel,gTestSet[i].classLabel);

  if(strcmp(classLabel,gTestSet[i].classLabel)!=0)//不等时,分类错误

  {

  // cout<<"   ***分类错误***\n";

   fprintf(fp,"                                                                      ***分类错误***\n");

  }

  fprintf(fp,"%d-最临近数据:\n",K);

  for(j=0;j<K;j++)

  {

  // cout<<gNearestDistance[j].ID<<"\t"<<gNearestDistance[j].distance<<"\t"<<gNearestDistance[j].classLabel[15]<<endl;

   fprintf(fp,"rowNo:  %3d   \t   Distance:  %f   \tClassLable:    %s\n",gNearestDistance[j].ID,gNearestDistance[j].distance,gNearestDistance[j].classLabel);

  }

  fprintf(fp,"\n");

 }

    FalsePositive=curTestSetSize-TruePositive;

 fprintf(fp,"***********************************结果分析**************************************\n",i);

 fprintf(fp,"TP(True positive): %d\nFP(False positive): %d\naccuracy: %f\n",TruePositive,FalsePositive,double(TruePositive)/(curTestSetSize-1));

 fclose(fp);

    return;

}

以上内容为參考网上有关资料;加以总结;

上一篇:golang json


下一篇:java基础知识点整理