正在阅读:

Python神经网络识别手写数字实例(MNIST图片数据集)

在上篇文章《Python神经网络识别手写数字实例(digits数据集)》中我们使用了前面实现的BP神经网络算法对sklearn中自带的手写数字进行了预测的实现,但是觉得使用别人提供好的数据集总不是很爽,如果能够自己来提取数据集那才牛逼,好,别着急,下面我们就这个干,从大量的图片中提取出我们需要的特征向量,然后训练模型,最后对图片中的数据进行预测。

在阅读本篇文章之前,建议看一前面的两篇文章《Python提取数字图片特征向量》《Python实现BP神经网络算法》,因为在这篇文章中我们需要用到其中的算法。

0x01 数据集

MNIST图片数据集:下载链接http://pan.baidu.com/s/1hsrHeAS 密码:34pj

在这个数据集中有6万多张0-9的黑白图片,分别分类在0-9这十个文件夹中,图片的大小是28*28的,如下图所示:

QQ截图20170814154926.jpg

因为这个数据集特别的大,如果使用全部的图片数据,那么在提取数据的时候可能会花费较长的时间,所以我们取每类图片的前1000张来进行测试即可。

0x02 提取图片特征向量

文件名:getImageDate.py


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#coding:utf8
import numpy as np
from PIL import Image
from sklearn.externals import joblib
import  os

class ImageData:
    def __init__(self, image):
        self.image = image

    # 二值化
    def point(self, z = 80):
        return self.image.point(lambda x:1. if x > z else 0.)

    # 将二值化后的数组转化成网格特征统计图
    def get_features(self, imArray,num):
        # 拿到数组的高度和宽度
        h, w = imArray.shape
        data = []
        for x in range(0, w / num):
            offset_y = x * num
            temp = []
            for y in range(0, h / num):
                offset_x = y * num
                # 统计每个区域的1的值
                temp.append(sum(sum(imArray[0 + offset_y:num + offset_y, 0 + offset_x:num + offset_x])))
            data.append(temp)
        return np.asarray(data)

    def getData(self,num):
        img = self.point()
        # 将图片转换为数组形式,元素为其像素的亮度值
        img_array = np.asarray(img)
        # 得到网格特征统计图
        features_array = self.get_features(img_array,num)
        # print features_array
        return features_array.reshape(features_array.shape[0]*features_array.shape[1])

看过《Python提取数字图片特征向量》这篇文章的同学应该知道这段代码的意思,在那片文章中我没有将它封装成类,这里我把它封装好了,只要知道返回值是一张图片的网格特征值,看不懂的就去看一下那篇文章吧。

0x03 导入所需库


1
2
3
4
5
6
7
8
9
10
#coding:utf8

from PIL import Image
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelBinarizer
import numpy as np
from getImageData import ImageData
from BpNN import NeuralNetwork
import os

较上篇文章而言,我们就多引入了一个getImageData文件中的ImageData类,这个就是前面的获取图片网格特性向量的类。另外BpNN是我的神经网络算法的文件。后面我们把这两个文件下载链接给出。

0x04 提取图片的网格特征向量


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
###############提取图片中的特征向量####################
X = []
Y = []

for i in range(0, 10):
    #遍历文件夹,读取数字图片
    for f in os.listdir("numImage/%s" % i):
        #打开一张文件并灰度化
        im = Image.open("numImage/%s/%s" % (i, f)).convert("L")
        #使用ImageData类
        z = ImageData(im)
        #获取图片网格特征向量,2代表每上下2格和左右两格为一组
        data = z.getData(2)
        X.append(data*0.1)
        Y.append(i)

X = np.array(X)
Y = np.array(Y)

这里我们将数据集和标记分别保存在X和Y列表中,最后再分别转化为数组。

手写数字图片我放在了numImage文件夹下面,一个有0-9个子文件夹分别放着其对应的数组图片。关键的地方在于z.getData(2)这个地方,我们前面说过,图片的尺存是28*28的,我们如果直接颜色数据转化为二维数组那么就是28*28的点阵,如果转化为特征向量那就是28 * 28 = 784维,这个维度太高了,我们需要对其进行缩小,其方法就是划分网格来获取其网格特征数据,我们将其横竖分别隔两个像素划线那么,总共得到的14*14的网格数据,我们只计算每个网格中为1的数据和。最终得到的就是 14 * 14 = 196维的特征向量,较之前的相比已经小了很多了。具体的描述就去看前面的文章吧。

其实到这位置我们基本上就拿到了数据集,剩下的就是训练模型和预测评估了,如果你阅读过之前的文章那么下面就可以不用再看了。

0x05 构建BP神经网络模型


1
2
3
4
5
6
7
8
9
10
11
12
13
14
#切分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=1)
#对标记进行二值化
labels_train = LabelBinarizer().fit_transform(y_train)

###########构造神经网络模型################
#构建神经网络结构
nn = NeuralNetwork([14*14, 100, 10], 'logistic')
#训练模型
nn.fit(X_train, labels_train, learning_rate=0.2, epochs=60)
#保存模型
# joblib.dump(nn, 'model/nnModel.m')
#加载模型
# nn = joblib.load('model/nnModel.m')

这里我们对前面提取到的数据进行了部分处理,先是对数据集按照8:2的比例进行了切分,然后又对标记进行了二值化处理,什么是二值化前面已经说过了,这里再说一遍,二值化就是将其化为0和1的形式,比如0000000000就代表数字0,0100000000就代表数字1,0010000000就代表数字2.......。

在构造神经网络的就跟的时候我们将其输入层设置为14*14=196个神经元,就跟我们的特征向量中的特征值得个数是一样的,因为我们的分类一个是0-9这个10个类别,所以输出层的神经元个数也是10,中间隐藏层的个数是100个,学习率还是0.1,迭代次数是60。经过我的测试发现,因为数据集的数据特别的多,在训练的时候迭代次数不必太多就能达到我们想要的结果,并且每次的训练所需时间比较长,所以在迭代的时候建议不必太大。

因为每次训练模型都要花费很大的代价,所以这里我们就将模型保存到磁盘在,之后直接读取即可使用。

0x06 数字识别和模型测评


1
2
3
4
5
6
7
8
9
10
11
12
13
###############数字识别####################
#存储预测结果
predictions = []
#对测试集进行预测
for i in range(y_test.shape[0]):
    out = nn.predict(X_test[i])
    predictions.append(np.argmax(out))

###############模型评测#####################
#打印预测报告
print confusion_matrix(y_test, predictions)
#打印预测结果混淆矩阵
print classification_report(y_test, predictions)

这里其实已经没有什么好说的了,就是直接调用predict将测试集传入得到预测值,唯一要注意的一点就是输出的out是一个10元素的数组,因为其预测结果是每个类的概率值,所以预测结果应该是最大元素的索引。

最后我们打印了预测报告和预测结果的混淆矩阵。这两个模型评估的方式在前面的文章也有提过,可参看《Python机器学习SVM人脸识别实例》

最终输出结果如下:

QQ截图20170814163414.jpg

从图中的评分中我们可以看出,最终的模型评测的正确率只有0.91,比之前的sklearn中digits数据集要少5个百分点,这问题我也调了几个参数,不过也没有达到那么高,因为这就是个测试实例,没有牵扯到参数的调优,所以我就没有再进一步提高准确率,感兴趣的同学可以试一下,欢迎随时跟我进行交流。

0x07 结束语

看过前一篇文章的同学应该看出来了,这篇文章跟前一篇的不同之处关键多了一步图片网格特征向量的提取。

折腾了这么久,到此为止我们就把机器学习基础中的分类问题讲的差不多了,后面将开始新的篇章,线性预测的问题。

最后奉上本文全部代码文件:Python神经网络手写数字识别.rar

目前有:11条访客评论,博主回复5

  1. 胡子大叔
    2018-02-01 14:12

    你好,如果是自己手写一个数字,把它转成指定格式,然后去让程序识别出它是几,该怎么做呢?

  2. 简单点
    2018-02-01 14:15

    你好,要是自己手写一个数字,该怎么使用这个程序识别呢?

  3. ccccccfl
    2018-03-20 21:09

    不能用呢

    • 马瑞强
      2018-03-20 21:22

      注意Python环境是2.7 使用了Anaconda的框架,如果再Python3.X系列运行时要改一些代码的。

  4. ccccccfl
    2018-03-20 21:19

    能联系你一下吗?有一些问题请教

  5. ccccccfl
    2018-03-20 21:25

    我用的2.7 一些包我也下载了 新手 不太会用呢 提示我ValueError: numpy.dtype has the wrong size, try recompiling. Expected 52, got 56

    • 马瑞强
      2018-03-20 21:27

      建议你把原来的Python卸载,直接安装Anaconda2.7系列,它里面带了很多常用的库,不需要再自己安装。

  6. fghajskhfs
    2018-05-03 10:54

    你好,代码里没有Getimagedata这个文件,压缩文件里没有,可以发给我一下吗

    • 马瑞强
      2018-05-03 19:02

      文章中有代码,第一个代码块就是。

  7. Ling
    2018-06-03 21:24

    请问一下手写数字的图片在哪下载

  8. Ling
    2018-06-03 21:30

    你好,我想请问一下,您提取图片中的特征向量中的图片类型是什么,我运行有报错说‘float’object cannot be interpreted as an integer,这要怎么修改代码呢

    • 马瑞强
      2018-06-04 20:06

      注意一下,Demo中用的Python2.7,目测你的报错是Python3.X,Python3.X中两整数相除会得到浮点数哦,强制类型转换一下就好了。

  9. Ling
    2018-06-04 20:13

    好的,问题已解决,非常感谢❤️

  10. 小k
    2018-06-11 00:18

    你好,我用的py3,想请问一下,导入路径那里的%s是表示图片的名称嘛?例如1,2,3这样

  11. 小k
    2018-06-11 10:26

    你好,我在用py3改写以后运行报错,ndarray has no attribute point,但是我觉得在getImageData类中已经定义了point,不知道原因,想请教一下

留下脚印,证明你来过。

*

*

流汗坏笑撇嘴大兵流泪发呆抠鼻吓到偷笑得意呲牙亲亲疑问调皮可爱白眼难过愤怒惊讶鼓掌