caffe笔记:测试自己的手写数字图片2
原文链接 https://hlthu.github.io/caffe/2017/02/09/caffe-mnist-my-picture2.html
注:以下为加速网络访问所做的原文缓存,经过重新格式化,可能存在格式方面的问题,或偶有遗漏信息,请以原文为准。
在前面的文章caffe笔记:测试自己的手写数字图片中,不管输入的是什么图片,预测的结果都是在0~4这5个中。我今天又重新google了一下,发现了这个博客:深度学习初探——使用Caffe识别数字。发现该博客后面提到的方法很好的解决了上文中提到的问题。
1. 准备更多数据
我用画图软件又画了几个手写数字出来,如下图所示,读者也可以从我的github上下载。
2. 准备相关设置
需要分别提供模型描述文件、模型权值文件、输入待分类图像。其中输入待分类图像在上一步已经准备好,模型权值文件也已经训练好了,模型描述文件和文章caffe笔记:测试自己的手写数字图片中的一样。
即我们需要提供的文件包括:
- 模型描述文件:
classificat_net.prototxt
。 - 模型权值文件:
lenet_iter_10000.caffemodel
。 - 输入待分类图像:
*.png
。
3. 编写测试代码
整个测试代码如下,读者也可以到我的github上阅读。文件保存为predict.py
。
import sys
import argparse
import cv2
import caffe
if __name__ == '__main__':
# 命令行参数设置,文件名*.png通过--png参数传入
parse = argparse.ArgumentParser()
parse.add_argument('--png')
args = parse.parse_args()
# 模型描述文件和权值描述文件
model = './classificat_net.prototxt'
weights = './lenet_iter_10000.caffemodel'
# 构建一个网络
net = caffe.Net(model, weights, caffe.TEST)
# 读入图像
img = cv2.imread(args.png, cv2.IMREAD_GRAYSCALE)
# 将图像输入网络
net.blobs['data'].data[...] = img
# 网络前向计算
out = net.forward()
# 输出
prob = out['prob'][0]
# 检测输出中为1的即为识别结果
for index, item in enumerate(prob):
if item == 1:
print index
4. 开始测试
依次执行下面这条命令,*可以用0~9代替,输出的结果显示大约有8个识别正确,而且不会出现前文提到的错误,还是比较理想的。
$ python predict.py --png *.png