机器学习实战：k 近邻算法识别手写数字

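  代码如下（运行前需在当前目录下准备 trainingDigits 和 testDigits 两个数据目录；每个文件是 32 行 × 32 列的数字字符文本，文件名以“<数字>_”开头，该数字即样本的真实标签）：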

import numpy as np
import operator
from os import listdir

def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Distances are Euclidean, measured from ``inX`` to every row of
    ``dataSet``.

    Args:
        inX: Query sample; anything that broadcasts against one row of
            ``dataSet`` (for example shape ``(n,)`` or ``(1, n)``).
        dataSet: 2-D NumPy array of training samples, one sample per row.
        labels: Sequence of class labels; ``labels[i]`` belongs to
            ``dataSet[i]``.
        k: Number of neighbours that vote. Must satisfy
            ``1 <= k <= dataSet.shape[0]``.

    Returns:
        The label with the most votes. On a tie, the label whose first vote
        came from the closer neighbour wins.

    Raises:
        ValueError: If ``k`` is outside ``[1, dataSet.shape[0]]``.
    """
    dataSetSize = dataSet.shape[0]
    if not 1 <= k <= dataSetSize:
        raise ValueError(
            "k must be between 1 and %d, got %r" % (dataSetSize, k))

    # Broadcasting subtracts inX from every row without building the
    # (dataSetSize x n) copy that np.tile would allocate.
    diffMat = dataSet - np.asarray(inX)
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDistIndices = distances.argsort()

    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndices[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # dicts preserve insertion order and max() returns the first maximal key,
    # so ties go to the label voted for first (by the nearest neighbour) --
    # the same result the stable reverse sort produced.
    return max(classCount, key=classCount.get)

def img2vector(filename):
    """Read a 32x32 text image and flatten it into a ``(1, 1024)`` row vector.

    The first 32 characters of each of the first 32 lines are parsed with
    ``int()`` and stored row-major, so character ``j`` of line ``i`` lands
    at column ``32 * i + j``.

    Args:
        filename: Path of the text image file.

    Returns:
        A float NumPy array of shape ``(1, 1024)``.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If a line has fewer than 32 characters.
        ValueError: If a character is not a decimal digit.
    """
    returnVect = np.zeros((1, 1024))
    # `with` guarantees the handle is closed even if parsing raises; the
    # previous version never closed it.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect

def _digitFiles(directory):
    """Return the non-hidden entries of ``directory``, sorted for reproducible output."""
    # Hidden files such as '.DS_Store' have no '<label>_' prefix and would
    # crash the int() label parse.
    return sorted(name for name in listdir(directory) if not name.startswith('.'))


def _labelFromFileName(fileName):
    """Return the integer label encoded before the first '_' in ``fileName``."""
    return int(fileName.split('_')[0])


def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits', k=3):
    """Evaluate the k-NN digit classifier on a test directory.

    Every training file is loaded with ``img2vector`` into one row of the
    training matrix. Its label is the integer before the first ``_`` in the
    file name. Each test file is then classified with ``classify0``. Every
    prediction is printed, followed by the error count and error rate.

    Args:
        trainingDir: Directory of training images (default ``'trainingDigits'``).
        testDir: Directory of test images (default ``'testDigits'``).
        k: Number of neighbours passed to ``classify0`` (default 3).

    Returns:
        The error rate as a fraction in ``[0, 1]``, or ``None`` when
        ``testDir`` contains no samples.
    """
    import os

    trainingFileList = _digitFiles(trainingDir)
    m = len(trainingFileList)
    hwLabels = []
    trainingMat = np.zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_labelFromFileName(fileNameStr))
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    testFileList = _digitFiles(testDir)
    mTest = len(testFileList)
    if mTest == 0:
        # Avoid the ZeroDivisionError the error-rate computation would hit.
        print("no test samples found in %s" % testDir)
        return None

    errorCount = 0
    for fileNameStr in testFileList:
        classNumber = _labelFromFileName(fileNameStr)
        vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with: %d\t the real answer is: %d"
              % (classifierResult, classNumber))
        if classifierResult != classNumber:
            errorCount += 1

    errorRate = errorCount / mTest
    # errorRate is a fraction; scale it by 100 so the printed number matches
    # the '%' sign (previously a 1.06% rate was printed as 0.010571%).
    print("the total number of errors is: %d\nthe total error rate is %f%%"
          % (errorCount, errorRate * 100))
    return errorRate

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    handwritingClassTest()
上一篇:aarch64交叉编译dfu-programmer


下一篇:Python:numpy + sympy 求解 Ax = 0