前言:
眼前大家对“基于cnn的手写数字识别”大约比较着重,大家都需要剖析一些“基于cnn的手写数字识别”的相关文章。那么小编在网上搜集了一些有关“基于cnn的手写数字识别””的相关资讯,希望姐妹们能喜欢,看官们快快来了解一下吧!上期文章我们分享了使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络
哪里使用pytorch训练了第一个手写字母的神经网络,并保存了预训练模型,本期我们使用上期的模型进行手写字母的识别
搭建神经网络
根据上期文章的分享,我们搭建一个手写字母识别的神经网络
import torchimport torch.nn as nnfrom PIL import Image # 导入图片处理工具import PIL.ImageOpsimport numpy as npfrom torchvision import transformsimport cv2import matplotlib.pyplot as plt# 定义神经网络class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Sequential( # input shape (1, 28, 28) nn.Conv2d( in_channels=1, # 输入通道数 out_channels=16, # 输出通道数 kernel_size=5, # 卷积核大小 stride=1, #卷积步数 padding=2, # 如果想要 con2d 出来的图片长宽没有变化, # padding=(kernel_size-1)/2 当 stride=1 ), # output shape (16, 28, 28) nn.ReLU(), # activation nn.MaxPool2d(kernel_size=2), # 在 2x2 空间里向下采样, output shape (16, 14, 14) ) self.conv2 = nn.Sequential( # input shape (16, 14, 14) nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14) nn.ReLU(), # activation nn.MaxPool2d(2), # output shape (32, 7, 7) ) self.out = nn.Linear(32 * 7 * 7, 37) # 全连接层,A/Z,a/z一共37个类 def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) # 展平多维的卷积图成 (batch_size, 32 * 7 * 7) output = self.out(x) return output
我们手写字母的识别神经网络主要使用了EMNIST数据库
EMNIST 主要分为以下 6 类:
By_Class : 共 814255 张,62 类,与 NIST 相比重新划分类训练集与测试机的图片数
By_Merge: 共 814255 张,47 类, 与 NIST 相比重新划分类训练集与测试机的图片数
Balanced : 共 131600 张,47 类, 每一类都包含了相同的数据,每一类训练集 2400 张,测试集 400 张
Digits :共 28000 张,10 类,每一类包含相同数量数据,每一类训练集 24000 张,测试集 4000 张
Letters : 共 103600 张,37 类,每一类包含相同数据,每一类训练集 2400 张,测试集 400 张
MNIST : 共 70000 张,10 类,每一类包含相同数量数据(注:这里虽然数目和分类都一样,但是图片的处理方式不一样,EMNIST 的 MNIST 子集数字占的比重更大)
这里为什么后面的分类不是26+26?其主要原因是一些大小写字母比较类似的字母就合并了,比如C等等
这就是为什么神经网络的全连接层self.out = nn.Linear(32 * 7 * 7, 37)一共有37类,我们上期代码中的训练是按照letters类进行训练的,神经网络的搭建过程,这里不再一一介绍了,可以参考往期文章进行学习。
使用MNIST数据集训练第一个pytorch CNN手写数字识别神经网络
pytorch利用CNN卷积神经网络来识别手写数字
上期代码# EMNIST 手写字母 训练集train_data = torchvision.datasets.EMNIST( root='./data', train=True, transform=torchvision.transforms.ToTensor(), download = DOWNLOAD_MNIST, split = 'letters' )# EMNIST 手写字母 测试集test_data = torchvision.datasets.EMNIST( root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=False, split = 'letters' )加载图片,预处理
神经网络搭建完成后,我们需要加载图片,并进行图片的一些预处理操作
file_name = '55.png' # 导入自己的图片img = Image.open(file_name)img = img.convert('L')img = PIL.ImageOps.invert(img)img = img.transpose(Image.FLIP_LEFT_RIGHT)img = img.rotate(90)plt.imshow(img)plt.show()train_transform = transforms.Compose([ transforms.Grayscale(), transforms.Resize((28, 28)), transforms.ToTensor(), ])img = train_transform(img)img = torch.unsqueeze(img, dim=0)#torch.unsqueeze()这个函数主要是对数据维度进行扩充。需要通过dim指定位置,给指定位置加上维数为1的维度。
通过往期文章对数据库的可视化,可以得知EMNIST数据库是黑底白字,但是平时我们自己的照片一般是白底黑字,这里,我们使用img = img.convert('L')img = PIL.ImageOps.invert(img)对图片进行颜色的翻转且我们知道EMNIST数据库左右翻转图片后,又进行了图片的逆时针旋转90度这里我们使用PIL库提供的函数进行图片的处理操作img = img.transpose(Image.FLIP_LEFT_RIGHT)img = img.rotate(90)Grayscale:将图像转换为灰度ToTensor:转换为张量
通过train_transform(img)函数我们对输入的图片进行预处理操作,转换为pytorch可以识别的神经网络数据
加载模型
#加载模型model = CNN()model.load_state_dict(torch.load('./model/Eminist.pth',map_location='cpu'))model.eval()def get_mapping(num, with_type='letters'): """ 根据 mapping,由传入的 num 计算 UTF8 字符。 """ if with_type == 'byclass': if num <= 9: return chr(num + 48) # 数字 elif num <= 35: return chr(num + 55) # 大写字母 else: return chr(num + 61) # 小写字母 elif with_type == 'letters': return chr(num + 64) + " / " + chr(num + 96) # 大写/小写字母 elif with_type == 'digits': return chr(num + 96) else: return num
model.load_state_dict(torch.load)函数加载上期神经网络训练完成的模型
get_mapping函数:
由于神经网络识别完成后,反馈给程序的是字母的UTF-8编码,我们通过查表来找到对应的字母
字符编码表(UTF-8)
神经网络识别
with torch.no_grad(): y = model(img) print(y) output = torch.squeeze(y) print(output) predict = torch.softmax(output, dim=0) print(predict) predict_cla = torch.argmax(predict).numpy() print(predict_cla)print(get_mapping(predict_cla), predict[predict_cla].numpy())
运行以上代码我们可以看到神经网络输出每个字母的识别精度,我们使用torch.argmax(predict).numpy()函数选择其中精度最大字母,并利用get_mapping函数查表选择出神经网络识别出来的字母
tensor([8.1084e-10, 6.5350e-04, 8.5815e-01, 1.2294e-05, 4.6187e-03, 9.3248e-04, 9.5208e-06, 2.7102e-02, 3.7893e-04, 3.2245e-05, 3.4475e-05, 8.7205e-06, 3.5875e-07, 2.1584e-06, 1.9850e-06, 5.4030e-02, 4.6604e-03, 4.5797e-02, 3.4711e-04, 3.0402e-03, 3.9365e-05, 1.6103e-06, 5.6906e-06, 1.3600e-07, 6.6020e-07, 2.8310e-06, 1.3650e-04, 1.1401e-09, 7.8941e-10, 7.8516e-10, 8.4411e-10, 1.2178e-09, 1.3640e-09, 8.4337e-10, 1.1152e-09, 1.0610e-09, 1.0475e-09])
2B / b 0.8581509
我们可以看到神经网络可以成功地识别出我们手写的字母B
标签: #基于cnn的手写数字识别