Fork me on GitHub

Cifar10模型学习

数据集


概述

数据集包括60000张RGB图像,每张图像尺寸为32*32,分为10个类,每类6000张。10个类分别是:飞机,汽车,鸟,猫,鹿,狗,青蛙,马,船,卡车。60000张图像中,50000张用于训练,10000张用于测试;测试数据分为5批,每批10000张,其中每类图像数量不一定相等。

Python版本数据在这里下载,MATLAB版本数据在这里下载,二进制版本数据在这里下载

数据布局

60000张图像分为6批。数据通过cPickle库处理,产生序列化对象,在实际编码中可以通过Python语言读取这些对象,得到一个字典结构的数据;如此一来,每个batch就都包含一个字典,字典里面包含两个元素:data,labels。

data
//一个10000*3072的numpy数组,数据类型为无符号整型
//10000表示图像数量,每一行就是一张图像
//3072表示的是一张独立图像的数据,每张图像尺寸为32*32,RGB三通道(32*32*3)
//图像按照行顺序存储,就是说原始图像R通道像素矩阵的第一行32个值就存储在该数组某一行的前32位

labels
//一维数组,包含10000个元素,每个元素取值范围0~9,代表图像所属类别

以上是Python/MATLAB版本的数据,还有一种是二进制,也就是TensorFlow所采用的。同样将数据分成6个batch。除此之外还有一个batches.meta.txt文件,这是一个ASCII文件,同样是把0~9类数字类标同每一类类名对应起来。

<1 x label><3072 x pixel>
...
<1 x label><3072 x pixel>

//第一个字节是类别,在0~9之间
//接下来是3072个字节,内容就是一幅图像的数据(同Python版本类似)
//每个文件都是10000行这样的数据,没有分隔行,所以文件大小就是30730000字节
//按行,height*weight*deepth的方式,前1024字节为R通道,然后是G通道,然后是B通道

代码解析


下图是CIFAR10文件内容介绍:

cifar10_input.py

该文件定义了模型如何数据文件(cifar10_data中的文件),在读取完成后还要对图像进行随机剪裁,翻转,标准化,等等。经过这一系列的预处理之后,数据才可以作为模型的标准输入。

cifar10.py

接下来是模型定义文件,该文件内容包括:数据输入方法,模型的建立,训练方法等。下面这张图在文件的头部,相当于是文件内容的一个总述。


它给出了整个模型的大致框架:

* 首先通过distorted_inputs()方法读取数据,并存入变量inputs(图像数据)和labels(分类标签)中
* 然后通过inference()方法来完成预测,其实就是神经网络模型的搭建过程,包括:卷积,池化,局部响应归一化,全连接,softmax归一化等;预测结果存入变量predictions
* 最后是训练,主要是计算损失loss、计算梯度、进行变量更新以及呈现最终结果等;


上表给出了该脚本定义的所有方法功能,其中有三个比较重要的函数方法:inference(),loss(),train()。根据inference()可以得到CNN的模型如下:

cifar10_train.py

模型的训练方法。

cifar10_eval.py

测试数据时用到的脚本。里面定义了评估模型的方法,还有评估时的输出内容。后面如果需要实现自定义一张图片作为输入让模型来进行预测,主要是需要修改该文件中的代码。

代码运行


首先需要用训练数据作为输入,让模型不断学习,优化模型参数。训练完成后,模型确定,可以用测试数据来测试模型的准确性。依次执行如下命令:

python cifar10_train.py
python cifar10_eval.py

实例——单张图片作为输入


在成功运行脚本的基础上,我们想让模型对我们自定义的输入图片进行预测,然后输出该图片属于哪一类。这就需要修改脚本,主要的修改对象是cifar10_eval.py。

1、定义一个读取操作
源码中的读取是基于.bin文件的,这是事先处理过的,但实际上输入模型的数据并不是.bin文件,而是一个结构为【128,24,24,3】的张量,所以我们需要把待预测的单张图片处理成模型所需要的结构形式,同时要对图像进行标准化等处理(CIFAR就这么干了,模型的需要,不处理的话结果不正确)。这相当于是一个输入接口,运行时输入python-cifar10_eval(),程序会提示输入文件名称,注意在输入文件名时要加入单引号(例如’cat.jpg’)。


2、修改cifar-10eval.py中的evaluate()方法
这个方法就是模型数据的入口,源码中有两个输入:images,labels,但实际上我们只需要image,因为我们不希望去和label对比,而是直接输出检测结果。这里iamge值获取就用到上面的read_image()方法。


3、修改cifar-10eval.py中的eval_once()方法
这里是程序的输出位置,源码会输出测试后的准确率,这个准确率是通过模型预测值和labels正确值之间进行比对来确定的。模型预测完后,得到logits,这是一个结构为【128,10】的张量,每一行代表一张图片的结果,行向量中最大值所在位置就代表这张图属于哪一个类。


4、结果
并不是每张图片都能准确预测,有的会出错;
在我的实例中,cat,ship,horse,deer,frog,airplane,一般是可以较准确分类;
bird预测为airplane,automobile预测为ship:可能是在图片的中央,二者平面拓扑结构比较相似;
truck预测为airplane,dog预测为airplane。


上图输出了模型预测的输入图片最有可能的类:cat,dog,bird。

-------------本文结束感谢您的阅读-------------
ChengQian wechat
有问题可以通过微信一起讨论!