caffe笔记:测试自己的手写数字图片2

2017-02-09 Lu Huang 更多博文 » 博客 » GitHub »

原文链接 https://hlthu.github.io/caffe/2017/02/09/caffe-mnist-my-picture2.html
注:以下为加速网络访问所做的原文缓存,经过重新格式化,可能存在格式方面的问题,或偶有遗漏信息,请以原文为准。


在前面的文章caffe笔记:测试自己的手写数字图片中,不管输入的是什么图片,预测的结果都是在0~4这5个中。我今天又重新google了一下,发现了这个博客:深度学习初探——使用Caffe识别数字。发现该博客后面提到的方法很好的解决了上文中提到的问题。

1. 准备更多数据

我用画图软件又画了几个手写数字出来,如下图所示,读者也可以从我的github上下载

my-mnist

2. 准备相关设置

需要分别提供模型描述文件、模型权值文件、输入待分类图像。其中输入待分类图像在上一步已经准备好,模型权值文件也已经训练好了,模型描述文件和文章caffe笔记:测试自己的手写数字图片中的一样。

即我们需要提供的文件包括:

  • 模型描述文件:classificat_net.prototxt
  • 模型权值文件:lenet_iter_10000.caffemodel
  • 输入待分类图像:*.png

3. 编写测试代码

整个测试代码如下,读者也可以到我的github上阅读。文件保存为predict.py

import sys
import argparse
import cv2
import caffe

if __name__ == '__main__':
    # 命令行参数设置,文件名*.png通过--png参数传入
    parse = argparse.ArgumentParser()
    parse.add_argument('--png')
    args = parse.parse_args()
    # 模型描述文件和权值描述文件
    model = './classificat_net.prototxt'
    weights = './lenet_iter_10000.caffemodel'
    # 构建一个网络
    net = caffe.Net(model, weights, caffe.TEST)
    # 读入图像
    img = cv2.imread(args.png, cv2.IMREAD_GRAYSCALE)
    # 将图像输入网络
    net.blobs['data'].data[...] = img
    # 网络前向计算
    out = net.forward()
    # 输出
    prob = out['prob'][0]
    # 检测输出中为1的即为识别结果
    for index, item in enumerate(prob):
        if item == 1:
            print index

4. 开始测试

依次执行下面这条命令,*可以用0~9代替,输出的结果显示大约有8个识别正确,而且不会出现前文提到的错误,还是比较理想的。

$ python predict.py --png *.png

参考

  1. ubuntu 16.04上配置cuda+caffe环境
  2. caffe笔记:测试自己的手写数字图片
  3. 深度学习初探——使用Caffe识别数字