龙空技术网

利用pytorch CNN手写字母识别神经网络模型识别手写字母

人工智能研究所 963

前言:

眼前大家对“基于cnn的手写数字识别”大约比较着重,大家都需要剖析一些“基于cnn的手写数字识别”的相关文章。那么小编在网上搜集了一些有关“基于cnn的手写数字识别””的相关资讯,希望姐妹们能喜欢,看官们快快来了解一下吧!

上期文章我们分享了使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

哪里使用pytorch训练了第一个手写字母的神经网络,并保存了预训练模型,本期我们使用上期的模型进行手写字母的识别

搭建神经网络

根据上期文章的分享,我们搭建一个手写字母识别的神经网络

import torchimport torch.nn as nnfrom PIL import Image  # 导入图片处理工具import PIL.ImageOpsimport numpy as npfrom torchvision import transformsimport cv2import matplotlib.pyplot as plt# 定义神经网络class CNN(nn.Module):    def __init__(self):        super(CNN, self).__init__()        self.conv1 = nn.Sequential(  # input shape (1, 28, 28)            nn.Conv2d(                in_channels=1,  # 输入通道数                out_channels=16,  # 输出通道数                kernel_size=5,   # 卷积核大小                stride=1,  #卷积步数                padding=2,  # 如果想要 con2d 出来的图片长宽没有变化,                             # padding=(kernel_size-1)/2 当 stride=1            ),  # output shape (16, 28, 28)            nn.ReLU(),  # activation            nn.MaxPool2d(kernel_size=2),  # 在 2x2 空间里向下采样, output shape (16, 14, 14)        )        self.conv2 = nn.Sequential(  # input shape (16, 14, 14)            nn.Conv2d(16, 32, 5, 1, 2),  # output shape (32, 14, 14)            nn.ReLU(),  # activation            nn.MaxPool2d(2),  # output shape (32, 7, 7)        )        self.out = nn.Linear(32 * 7 * 7, 37)  # 全连接层,A/Z,a/z一共37个类    def forward(self, x):        x = self.conv1(x)        x = self.conv2(x)        x = x.view(x.size(0), -1)  # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)        output = self.out(x)        return output

我们手写字母的识别神经网络主要使用了EMNIST数据库

EMNIST 主要分为以下 6 类:

By_Class : 共 814255 张,62 类,与 NIST 相比重新划分类训练集与测试机的图片数

By_Merge: 共 814255 张,47 类, 与 NIST 相比重新划分类训练集与测试机的图片数

Balanced : 共 131600 张,47 类, 每一类都包含了相同的数据,每一类训练集 2400 张,测试集 400 张

Digits :共 28000 张,10 类,每一类包含相同数量数据,每一类训练集 24000 张,测试集 4000 张

Letters : 共 103600 张,37 类,每一类包含相同数据,每一类训练集 2400 张,测试集 400 张

MNIST : 共 70000 张,10 类,每一类包含相同数量数据(注:这里虽然数目和分类都一样,但是图片的处理方式不一样,EMNIST 的 MNIST 子集数字占的比重更大)

这里为什么后面的分类不是26+26?其主要原因是一些大小写字母比较类似的字母就合并了,比如C等等

这就是为什么神经网络的全连接层self.out = nn.Linear(32 * 7 * 7, 37)一共有37类,我们上期代码中的训练是按照letters类进行训练的,神经网络的搭建过程,这里不再一一介绍了,可以参考往期文章进行学习。

使用MNIST数据集训练第一个pytorch CNN手写数字识别神经网络

pytorch利用CNN卷积神经网络来识别手写数字

上期代码# EMNIST 手写字母 训练集train_data = torchvision.datasets.EMNIST(    root='./data',    train=True,    transform=torchvision.transforms.ToTensor(),    download = DOWNLOAD_MNIST,    split = 'letters' )# EMNIST 手写字母 测试集test_data = torchvision.datasets.EMNIST(    root='./data',    train=False,    transform=torchvision.transforms.ToTensor(),    download=False,    split = 'letters'     )
加载图片,预处理

神经网络搭建完成后,我们需要加载图片,并进行图片的一些预处理操作

file_name = '55.png'  # 导入自己的图片img = Image.open(file_name)img = img.convert('L')img = PIL.ImageOps.invert(img)img = img.transpose(Image.FLIP_LEFT_RIGHT)img = img.rotate(90)plt.imshow(img)plt.show()train_transform = transforms.Compose([       transforms.Grayscale(),         transforms.Resize((28, 28)),         transforms.ToTensor(), ])img = train_transform(img)img = torch.unsqueeze(img, dim=0)#torch.unsqueeze()这个函数主要是对数据维度进行扩充。需要通过dim指定位置,给指定位置加上维数为1的维度。
通过往期文章对数据库的可视化,可以得知EMNIST数据库是黑底白字,但是平时我们自己的照片一般是白底黑字,这里,我们使用img = img.convert('L')img = PIL.ImageOps.invert(img)对图片进行颜色的翻转且我们知道EMNIST数据库左右翻转图片后,又进行了图片的逆时针旋转90度这里我们使用PIL库提供的函数进行图片的处理操作img = img.transpose(Image.FLIP_LEFT_RIGHT)img = img.rotate(90)Grayscale:将图像转换为灰度ToTensor:转换为张量

通过train_transform(img)函数我们对输入的图片进行预处理操作,转换为pytorch可以识别的神经网络数据

加载模型

#加载模型model = CNN()model.load_state_dict(torch.load('./model/Eminist.pth',map_location='cpu'))model.eval()def get_mapping(num, with_type='letters'):    """    根据 mapping,由传入的 num 计算 UTF8 字符。    """    if with_type == 'byclass':        if num <= 9:            return chr(num + 48)  # 数字        elif num <= 35:            return chr(num + 55)  # 大写字母        else:            return chr(num + 61)  # 小写字母    elif with_type == 'letters':        return chr(num + 64) + " / " + chr(num + 96)  # 大写/小写字母    elif with_type == 'digits':        return chr(num + 96)    else:        return num

model.load_state_dict(torch.load)函数加载上期神经网络训练完成的模型

get_mapping函数:

由于神经网络识别完成后,反馈给程序的是字母的UTF-8编码,我们通过查表来找到对应的字母

字符编码表(UTF-8)

神经网络识别

with torch.no_grad():    y = model(img)    print(y)    output = torch.squeeze(y)    print(output)    predict = torch.softmax(output, dim=0)    print(predict)    predict_cla = torch.argmax(predict).numpy()    print(predict_cla)print(get_mapping(predict_cla), predict[predict_cla].numpy())

运行以上代码我们可以看到神经网络输出每个字母的识别精度,我们使用torch.argmax(predict).numpy()函数选择其中精度最大字母,并利用get_mapping函数查表选择出神经网络识别出来的字母

tensor([8.1084e-10, 6.5350e-04, 8.5815e-01, 1.2294e-05, 4.6187e-03, 9.3248e-04,        9.5208e-06, 2.7102e-02, 3.7893e-04, 3.2245e-05, 3.4475e-05, 8.7205e-06,        3.5875e-07, 2.1584e-06, 1.9850e-06, 5.4030e-02, 4.6604e-03, 4.5797e-02,        3.4711e-04, 3.0402e-03, 3.9365e-05, 1.6103e-06, 5.6906e-06, 1.3600e-07,        6.6020e-07, 2.8310e-06, 1.3650e-04, 1.1401e-09, 7.8941e-10, 7.8516e-10,        8.4411e-10, 1.2178e-09, 1.3640e-09, 8.4337e-10, 1.1152e-09, 1.0610e-09,        1.0475e-09])
2B / b 0.8581509

我们可以看到神经网络可以成功地识别出我们手写的字母B

标签: #基于cnn的手写数字识别