一、代码实现
# KNN
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
# data = load_iris()
url = 'https://www.gairuo.com/file/data/dataset/iris.data'
data = pd.read_csv(url)
data ["species"] = data["species"].map({"setosa":1,"virginica":0,"versicolor":2})
# 删除重复的数据
data = data.drop_duplicates()
# 查看各个类别的鸢尾花
data["species"].value_counts()
# KNN
class KNN:
"""
使用Python实现K近邻算法 实现分类
"""
def __init__(self,k):
"""
初始化方法
:parameter
:param k:int 邻居的个数
"""
self.k = k
def fit(self, X, y):
"""
训练方法
:parameter
:param X: 类数组类型 ,形状为:[样本数量,特征数量] 待训练的样本特征
:param y: 类数组类型,形状为:[样本数量] 每个样本的目标值 (标签)
:return: None
"""
# 将X,y转换成ndarray数组
self.X = np.asarray(X)
self.y = np.asarray(y)
def predict(self,X):
"""
根据参数传递的样本,对样本进行预测。
:parameter
:param X: 类数组类型 ,形状为:[样本数量,特征数量] 待训练的样本特征
:return: 数组的类型 (预测的结果)
"""
X = np.asarray(X)
result = []
# 对ndarray数组进行遍历,每次取数组的一行。
for x in X:
# 对于测试集的每个样本依次对训练集中的所有样本求距离
dis = np.sqrt(np.sum((x - self.X) ** 2,axis=1))
# 返回数组排序后,每个元素在原数组中的索引
index = dis.argsort()
# 进行截断,只取k个元素[取距离最近的k个元素的索引]
index = index[:self.k]
# 返回数组中每个元素出现的次数。元素必须是非负的整数
count = np.bincount(self.y[index])
# 返回ndarray数组中最大的元素对应的索引。该索引是我们判定的类别。
# 最大元素就是出现次数最多的元素
result.append(count.argmax())
pass
return np.asarray(result)
pass
# 提取出每个类别的鸢尾花数据
t0 = data[data["species"] == 0]
t1 = data[data["species"] == 1]
t2 = data[data["species"] == 2]
# 打乱顺序 每次打乱顺序都是一样的
t0 = t0.sample(len(t0),random_state=0)
t1 = t1.sample(len(t1),random_state=0)
t2 = t2.sample(len(t2),random_state=0)
# 构建训练集和测试集 concat:拼接数组 axis = 0 纵向拼接
train_X = pd.concat([t0.iloc[:40,:-1],t1.iloc[:40,:-1],t2.iloc[:40,:-1]],axis=0)
train_y = pd.concat([t0.iloc[:40,-1],t1.iloc[:40,-1],t2.iloc[:40,-1]],axis=0)
test_X = pd.concat([t0.iloc[40:,:-1],t1.iloc[40:,:-1],t2.iloc[40:,:-1]],axis=0)
test_y = pd.concat([t0.iloc[40:,-1],t1.iloc[40:,-1],t2.iloc[40:,-1]],axis=0)
# 创建KNN对象,进行训练与测试
knn = KNN(k=3)
knn.fit(train_X,train_y)
# 进行测试
result = knn.predict(test_X)
# 预测正确率
print("预测正确率 =",np.sum(result == test_y)/len(result))
二、代码结果
预测正确率 = 0.9629629629629629