Python神经网络识别手写数字实例(digits数据集)

2017年8月13日16:22:38 1 25,528

Python神经网络识别手写数字实例(digits数据集)

前言

在上一篇文章《Python实现BP神经网络算法》中,我们使用Python语言实现了一个简单的BP神经网络算法,并且使用异或数据集进行了简单的预测。建议在阅读本文之前先看一下算法的实现,因为我们这篇文章要用到神经网络算法。因为这篇文章中要用封装好的代码,所以我就在这上传一份(BpNN.py)

这篇文章我们就使用之前我们实现的神经网络算法对手写数字进行识别预测。在本篇文章中我们使用sklearn中自带的手写数字数据集(digits),这个数据集中并没有图片,而是经过提取得到的手写数字特征和标记,就免去了我们的提取数据的麻烦,但是在实际的应用中是需要我们对图片中的数据进行提取的,在下一篇文章中我们将介绍如何通过图片得到数据集并完成数字识别。

下面我就分模块介绍一下代码。

导入所需库

其中load_digits,就是我们这篇文章要用到的手写数字数据集。想必除了最后一个,其他个库应该不陌生了,最后一个就是我们上上篇文章实现的神经网络算法类,通过这种引入方式,我们可以直接在本文件中使用NeuralNetwork类。

加载数据集并进行数据预处理

数据集的加载非常简单,直接调用load_digits即可,具体里面包含了什么东西,大家可以print一下看看,在这里说一下主要用到的几个:

digits.data:手写数字特征向量数据集,每一个元素都是一个64维的特征向量。

digits.target:特征向量对应的标记,每一个元素都是自然是0-9的数字。

digits.images:对应着data中的数据,每一个元素都是8*8的二维数组,其元素代表的是灰度值,转化为以为是便是特征向量。如下图所示:

Python神经网络识别手写数字实例(digits数据集)

对于数据的预处理我们进数据集的特征值进行了简单的处理,让其每一个特征值都处在0-1之间,便于下面构造神经网络。

然后我们将数据集进行了切分,得到了8:2比例的训练集和测试集。

最后我们对标记进行了二值化处理,因为神经网络只认识0和1,所以我们将其化为了0和1的形式,比如0000000000代表数字0, 0100000000代表数组1, 0010000000代表数字2,依次化为该形式。

构造神经网络模型

在构造神经网络模型的时候,我们将输入神经元设置为64是因为我们的输入层有64个神经元;因为我们的输出结果一共分为了0-9这10类,所以我们的输出层就有10个神经元。当然,这样预测值肯定就会有10个结果,但我们只要预测输出中数值最大的那一个的索引,就是我们的最终预测结果;我们所用的激活函数是logistic,在上篇文章文已经介绍过了。

在构造模型的时候,我们设置了学习率是0.2,迭代次数是100,这两个参数和之前隐藏层的构造都不是最优的接过,如果想得到最优结构还需后期进行调优处理。

在这里我们对数据模型进行了保存,为什么要保存呢?因为在每次训练模型的时候我们都需要花费很长的时间,sklearn中带的joblib库可以帮助我们将模型保存为文件的形式存储到磁盘上,以后只要使用load方法直接加载该文件就可以直接使用该模型进行预测了。

数字识别和模型测评

因为前面我们已经切分得到了测试集,所以这里我们就直接拿测试集进行数字识别预测,代码中可以看出,每次调用predict预测之后的输出out都要经过np.argmax()处理,再保存到列表中,这是为什么呢?前面已经说过了,每次的预测输出都是一个含有是10个元素的数组,数组中最大的值得索引便是我们预测的结果,np.argmax()便是返回数组中最大值得索引。

有了预测结果之后我们就可以对模型进行评估,可打印出了预测报告和混淆矩阵。结果如下:

Python神经网络识别手写数字实例(digits数据集)

从结果中我们可以看到,预测的成功率已经高到0.96,通过后期的调优,以及适当增加训练次数,相信预测的成功率还会增加。

结束语

本文的手写数字识别只是一个简单的预测实例,如果想真正的做到预测我们还需进一步进行参数的调整等各调优工作。除此之外,有些同学可能不知道数据集中的特征向量是怎么来的,这一点不用方,可以参考我前面的文章《Python提取数字图片特征向量》

因为有些时候我们数据集需要自己寻找和处理,所以我们下一篇文章将从提取手写数字特征向量开始,一步一步完成手写数字的是识别算法。

最后奉上本文全部代码文件:

文件下载

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

目前评论:1   其中:访客  1   博主  0

    • 博sir 0

      你好,博主。我想咨询一下,我怎么输出我的预测值呢,看了好多博主的都是直接输出准确率和混淆矩阵。我是一个初学者,一名研究生,还望回复。。谢谢您