Python神经网络识别手写数字实例(MNIST图片数据集)

2017年8月14日16:52:54 23 56,075

在上篇文章《Python神经网络识别手写数字实例(digits数据集)》中我们使用了前面实现的BP神经网络算法对sklearn中自带的手写数字进行了预测的实现,但是觉得使用别人提供好的数据集总不是很爽,如果能够自己来提取数据集那才牛逼,好,别着急,下面我们就这个干,从大量的图片中提取出我们需要的特征向量,然后训练模型,最后对图片中的数据进行预测。

在阅读本篇文章之前,建议看一前面的两篇文章《Python提取数字图片特征向量》《Python实现BP神经网络算法》,因为在这篇文章中我们需要用到其中的算法。

0x01 数据集

MNIST图片数据集:下载链接http://pan.baidu.com/s/1hsrHeAS 密码:34pj

在这个数据集中有6万多张0-9的黑白图片,分别分类在0-9这十个文件夹中,图片的大小是28*28的,如下图所示:

Python神经网络识别手写数字实例(MNIST图片数据集)

因为这个数据集特别的大,如果使用全部的图片数据,那么在提取数据的时候可能会花费较长的时间,所以我们取每类图片的前1000张来进行测试即可。

0x02 提取图片特征向量

文件名:getImageDate.py

看过《Python提取数字图片特征向量》这篇文章的同学应该知道这段代码的意思,在那片文章中我没有将它封装成类,这里我把它封装好了,只要知道返回值是一张图片的网格特征值,看不懂的就去看一下那篇文章吧。

0x03 导入所需库

较上篇文章而言,我们就多引入了一个getImageData文件中的ImageData类,这个就是前面的获取图片网格特性向量的类。另外BpNN是我的神经网络算法的文件。后面我们把这两个文件下载链接给出。

0x04 提取图片的网格特征向量

这里我们将数据集和标记分别保存在X和Y列表中,最后再分别转化为数组。

手写数字图片我放在了numImage文件夹下面,一个有0-9个子文件夹分别放着其对应的数组图片。关键的地方在于z.getData(2)这个地方,我们前面说过,图片的尺存是28*28的,我们如果直接颜色数据转化为二维数组那么就是28*28的点阵,如果转化为特征向量那就是28 * 28 = 784维,这个维度太高了,我们需要对其进行缩小,其方法就是划分网格来获取其网格特征数据,我们将其横竖分别隔两个像素划线那么,总共得到的14*14的网格数据,我们只计算每个网格中为1的数据和。最终得到的就是 14 * 14 = 196维的特征向量,较之前的相比已经小了很多了。具体的描述就去看前面的文章吧。

其实到这位置我们基本上就拿到了数据集,剩下的就是训练模型和预测评估了,如果你阅读过之前的文章那么下面就可以不用再看了。

0x05 构建BP神经网络模型

这里我们对前面提取到的数据进行了部分处理,先是对数据集按照8:2的比例进行了切分,然后又对标记进行了二值化处理,什么是二值化前面已经说过了,这里再说一遍,二值化就是将其化为0和1的形式,比如0000000000就代表数字0,0100000000就代表数字1,0010000000就代表数字2.......。

在构造神经网络的就跟的时候我们将其输入层设置为14*14=196个神经元,就跟我们的特征向量中的特征值得个数是一样的,因为我们的分类一个是0-9这个10个类别,所以输出层的神经元个数也是10,中间隐藏层的个数是100个,学习率还是0.1,迭代次数是60。经过我的测试发现,因为数据集的数据特别的多,在训练的时候迭代次数不必太多就能达到我们想要的结果,并且每次的训练所需时间比较长,所以在迭代的时候建议不必太大。

因为每次训练模型都要花费很大的代价,所以这里我们就将模型保存到磁盘在,之后直接读取即可使用。

0x06 数字识别和模型测评

这里其实已经没有什么好说的了,就是直接调用predict将测试集传入得到预测值,唯一要注意的一点就是输出的out是一个10元素的数组,因为其预测结果是每个类的概率值,所以预测结果应该是最大元素的索引。

最后我们打印了预测报告和预测结果的混淆矩阵。这两个模型评估的方式在前面的文章也有提过,可参看《Python机器学习SVM人脸识别实例》

最终输出结果如下:

Python神经网络识别手写数字实例(MNIST图片数据集)

从图中的评分中我们可以看出,最终的模型评测的正确率只有0.91,比之前的sklearn中digits数据集要少5个百分点,这问题我也调了几个参数,不过也没有达到那么高,因为这就是个测试实例,没有牵扯到参数的调优,所以我就没有再进一步提高准确率,感兴趣的同学可以试一下,欢迎随时跟我进行交流。

0x07 结束语

看过前一篇文章的同学应该看出来了,这篇文章跟前一篇的不同之处关键多了一步图片网格特征向量的提取。

折腾了这么久,到此为止我们就把机器学习基础中的分类问题讲的差不多了,后面将开始新的篇章,线性预测的问题。

最后奉上本文全部代码文件:Python神经网络手写数字识别.rar

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

目前评论:23   其中:访客  16   博主  7

    • 胡子大叔 0

      你好,如果是自己手写一个数字,把它转成指定格式,然后去让程序识别出它是几,该怎么做呢?

      • 简单点 0

        你好,要是自己手写一个数字,该怎么使用这个程序识别呢?

        • ccccccfl 1

          不能用呢

            • 马瑞强 Admin

              @ccccccfl 注意Python环境是2.7 使用了Anaconda的框架,如果再Python3.X系列运行时要改一些代码的。

            • ccccccfl 1

              能联系你一下吗?有一些问题请教

              • ccccccfl 1

                我用的2.7 一些包我也下载了 新手 不太会用呢 提示我ValueError: numpy.dtype has the wrong size, try recompiling. Expected 52, got 56

                  • 马瑞强 Admin

                    @ccccccfl 建议你把原来的Python卸载,直接安装Anaconda2.7系列,它里面带了很多常用的库,不需要再自己安装。

                  • fghajskhfs 0

                    你好,代码里没有Getimagedata这个文件,压缩文件里没有,可以发给我一下吗

                    • Ling 1

                      请问一下手写数字的图片在哪下载

                      • Ling 1

                        你好,我想请问一下,您提取图片中的特征向量中的图片类型是什么,我运行有报错说‘float’object cannot be interpreted as an integer,这要怎么修改代码呢

                          • 马瑞强 Admin

                            @Ling 注意一下,Demo中用的Python2.7,目测你的报错是Python3.X,Python3.X中两整数相除会得到浮点数哦,强制类型转换一下就好了。

                          • Ling 1

                            好的,问题已解决,非常感谢❤️

                            • 小k 1

                              你好,我用的py3,想请问一下,导入路径那里的%s是表示图片的名称嘛?例如1,2,3这样

                              • 小k 1

                                你好,我在用py3改写以后运行报错,ndarray has no attribute point,但是我觉得在getImageData类中已经定义了point,不知道原因,想请教一下

                                • 小L 0

                                  temp.append(sum(sum(imArray[0 + offset_y:num + offset_y, 0 + offset_x:num + offset_x])))
                                  请问这行代码具体要怎么理解呢?

                                    • 马瑞强 Admin

                                      @小L 矩阵区域相加,具体请参考https://www.k2zone.cn/?p=977

                                    • 请输入您的QQ号 1

                                      No module named ‘getImageData’怎么解决

                                        • 马瑞强 Admin

                                          @请输入您的QQ号 请仔细阅读文章,“文件名:getImageDate.py”,导入这个文件就是getImageData模块

                                        • gorge 0

                                          你好,我用python3运行程序出现这个错误: D:\Anaconda3\lib\site-packages\sklearn\metrics\classification.py:1135: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.
                                          ‘precision’, ‘predicted’, average, warn_for)

                                          请问怎么改代码啊

                                          • 哆啦潘 0

                                            请问博主 X.append(data*0.1) 这里为什么要乘 0.1

                                            • 李明 1

                                              您好 getImageData库找不到