博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
使用kNN算法实现简单的手写文字识别
阅读量:6574 次
发布时间:2019-06-24

本文共 3347 字,大约阅读时间需要 11 分钟。

0. 介绍

kNN,即k-Nearest Neighbor(k近邻算法), 简介可参考. 本文是《机器学习实战》一书第二章的例子, 主要利用kNN实现简单的手写文字识别.

书中使用Python实现, 本文是使用R语言. 数据集中的图片分辨率为32*32, 并且该数据已经预处理成文本文件, 即类似点阵字体, 使用1代表有文字的像素, 0表示空白.

1. kNN算法实现

算法的步骤主要有:

  1. 计算测试数据到所有训练数据的距离

  2. 对1中计算的距离排序, 选出最小k个训练数据

  3. 在2中选出的k个数据中选取出现几率最大的标签, 此即算法对测试数据的分类

排序的时候, 利用的是order方法, 取出降序排序元素的索引, 这在numpy中对应的方法是argsort.

实现代码如下:

classify0 <- function(inX, dataSet, labels, k){    dataSetSize = length(dataSet[,1])    #扩展测试向量inX    oneMat = matrix(1, dataSetSize, 1)    dataMat = oneMat %*% inX    #计算距离    dataMat = dataMat - dataSet    sqDiffMat = dataMat ** 2    sqDistances = rowSums(sqDiffMat)    distances = sqDistances ** 0.5    #选择距离最小的k个点    #按第一列升序排列获取序号    sortedDistIndicies = order(distances)    voteLabelsCount = rep(0, length(labels))    for(i in 1:k){        #获取第k小距离数据的标签        label = labels[sortedDistIndicies[i]]        index = which(labels == label)        voteLabelsCount[index[1]] = voteLabelsCount[index[1]] + 1    }    sortedVoteLabelsCount = order(-voteLabelsCount)    return(labels[sortedVoteLabelsCount[1]])}

2. 准备数据

本次实践准备的数据在两个文件目录中,

  • trainingDigits -- 包含2000个例子, 每个数字大概200个.

  • testDigits -- 包含大约900个例子.

trainingDigits中的数据将用于训练分类器, testDigits中的数据将用于测试分类器的效果.

由于原始数据是32*32的矩阵, 现在需要将其转化为1*1024的向量. 程序如下:

img2vector <- function(filename){    returnVect = matrix(0,1,1024)    con = file(filename, "r")    for(i in 0:31){        line = readLines(con,n=1)        for(j in 1:32){            returnVect[1,(32*i+j)] = as.numeric(substr(line,j,j))        }    }    close(con)    return(returnVect)}

3. 测试算法

主要的任务是从数据文件中提取所有的用例, 然后调用上面所述的classify0img2vector函数实现识别工作, 并计算错误率以供参考.

图像文本文件的命名格式为"a_b.txt", a表示当前文件的数字, b表示这是该数字的第b个例子. R对于文本的处理是比较弱的, 不过对于这点内容还是能应付, 使用了一点正则替换搞定.

处理完数据调用核心的classify0函数即可. 具体代码如下:

hardwritingTest <- function(){    print("the test start.")    print("read trainingDigits.")    trainingFileList = Sys.glob("trainingDigits/*.txt")    m = length(trainingFileList)    hwLabels = rep(0, m)    trainingMat = matrix(0,m,1024)    for(i in 1:m){        fileNameStr = trainingFileList[i]        #提取数字        fileStr = sub("trainingDigits/", "", fileNameStr)        fileStr = sub("_[0-9]+.txt", "", fileStr)        classNumStr = as.numeric(fileStr)        hwLabels[i] = classNumStr        trainingMat[i,] = img2vector(trainingFileList[i])    }    print("read testDigits.")    testFileList = Sys.glob("testDigits/*.txt")    errorCount = 0.0    mTest = length(testFileList)    for(i in 1:mTest){        fileNameStr = testFileList[i]        fileStr = sub("testDigits/", "", fileNameStr)        fileStr = sub("_[0-9]+.txt", "", fileStr)        classNumStr = as.numeric(fileStr)        vectorUnderTest = img2vector(testFileList[i])        print(paste0("classify the ", i, "th testDigit."))        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)        print(paste0("--the classifier came back with: ", classifierResult, ", the real answer is: ", classNumStr))        if(classifierResult != classNumStr){            errorCount = errorCount + 1.0        }    }    print(paste0("the total number of errors is: ", errorCount))    print(paste0("the total error rate is: ", (errorCount / mTest)))}

4. 小结

kNN算法的分类思路是很简单的, 实现起来也很方便. 在对数据集测试的时候, 错误率在1.27%, 这个结果还是比较不错的.

不足之处是这种即时训练消耗了过多的时间和空间, 时间主要消耗在读取文件建立数据集和计算距离的时候. 在实际过程中, 前者可以缓存数据, 达到一次读取多次使用; 后者便很难优化了, 这其中涉及到了高阶矩阵的运算, 开销较大. 因此该算法在大规模数据时不宜采用.

转载地址:http://lggjo.baihongyu.com/

你可能感兴趣的文章
使用phppgadmin 遇到的小问题
查看>>
BFS小结
查看>>
Jquery页面跳转
查看>>
poj 3211 Washing Clothes (01)
查看>>
Ruby小白入门笔记之<Rubymine工具的快捷键>
查看>>
Media Session API 为当前正在播放的视频,音频,提供元数据来自定义媒体通知
查看>>
yum -y install php-mysql 版本冲突
查看>>
【7.17总结】 匈牙利算法(二分图最大匹配)
查看>>
JDBC(转)
查看>>
大端小段详解—转载
查看>>
告别LVS:使用keepalived+nginx实现负载均衡代理多个https
查看>>
征服 Redis + Jedis + Spring (一)—— 配置&常规操作(GET SET DEL)
查看>>
[转载]触摸屏网站制作的小细节
查看>>
[转载]INNO Setup 使用笔记
查看>>
Servlet--HttpSession接口,HttpSessionContext接口,Cookie类
查看>>
Android世界第一个activity启动过程
查看>>
RR调度(Round-robin scheduling)简单介绍
查看>>
重载函数编译后的新名字
查看>>
oracle resetlog与noresetlog的作用(转载)
查看>>
linux服务器内存占用太高-释放内存
查看>>