机器学习实战
第二章
例题3
手写识别系统
书中 2.3 Page 28
代码如下:
# _*_ coding: utf-8 _*_
# @Time : 2020/11/20 下午10:29
# @Version:V 0.1
# @File : kNN.py
# @desc : none
# @software: PyCharm
import numpy as np
from numpy.core._multiarray_umath import ndarray
from os import listdir
from sklearn.neighbors import KNeighborsClassifier as kNN
def image_to_vector(filename: str) -> ndarray:
"""
将32x32的二进制图像转换为1x1024向量
:argument
filename: 文件名
:return
返回的二进制图像的1x1024向量
"""
vector = np.zeros((1, 1024)) # 构建图像向量
with open(filename, 'r') as f: # 打开文件
for i in range(0, 32): # 按行读取文件
str_line = f.readline() # 读取一行数据
for j in range(0, 32): # 将一行的数据传入图像向量中
vector[0, i * 32 + j] = int(str_line[j]) # 将单个整数数据传入图像向量
return vector
def hand_writing_training(filepath: str) -> kNN:
"""使用训练数据集训练KD-tree分类器
:argument
filepath: 训练数据集地址
:return
返回训练完成的分类器
"""
labels = [] # 训练数据标签
training_list = listdir(filepath) # 训练数据集的地址
m = len(training_list) # 训练数据个数
training_set = np.zeros((m, 1024)) # 训练数据集
for i in range(m): # 将所有训练数据导入array
filename = training_list[i] # 获取训练数据的文件名
training_data = image_to_vector(filepath + '/' + filename) # 提取单个训练数据
training_set[i, :] = training_data # 将单个数据放入数据集中
label = int(filename.split('_')[0]) # 提取数据标记
labels.append(label) # 记录数据标记
classifier = kNN(n_neighbors=3, algorithm='kd_tree', p=2) # 构建分类器 5近邻;KDTree;欧式距离
classifier.fit(training_set, labels)
return classifier
def hand_writing_testing(filepath: str, classifier: kNN):
""": 使用测试数据测试分类器
:argument
filepath: 测试数据集路径
classifier: 训练完成的分类器
:return
none
"""
test_list = listdir(filepath) # 测试数据集的地址
m = len(test_list) # 测试数据的个数
error_count = 0
for i in range(m):
filename = test_list[i] # 获取测试数据文件
test_data = image_to_vector(filepath + '/' + filename) # 获取测试数据
class_num = int(filename.split('_')[0]) # 获取测试数据的分类
classify_result = classifier.predict(test_data) # 获得预测结果
if class_num != classify_result: # 预测判断
error_count += 1.0 # 记录预测错误数
print('分类结果为%d\t真实结果为%d' % (classify_result, class_num)) # 输出预测有误的测试
error_rate = error_count / m # 预测错误率
print('测试总数等于%d错误总数等于%d\t错误率等于%.3f%%' % (m, error_count, error_rate * 100)) # 输出预测总数及预测错误率
def main():
"""
主函数
"""
train_filepath = 'a_filepath'
classifier = hand_writing_training(train_filepath)
test_filepath = 'b_filepath'
hand_writing_testing(test_filepath, classifier)
if __name__ == '__main__':
main()