龙空技术网

动物分类器

cztAI 94

前言:

眼前朋友们对“net40eval”大概比较讲究,我们都想要学习一些“net40eval”的相关文章。那么小编也在网上汇集了一些有关“net40eval””的相关知识,希望看官们能喜欢,各位老铁们一起来学习一下吧!

0 准备工作

新建文件夹“动物分类”,在“动物分类”新建文件夹“数据”。

1 爬取动物图片

在“动物分类”,右键运行终端:

gedit get_data.pypython get_data.py

get_data.py

import requestsimport urllib.parse as upimport jsonimport timeimport osmajor_url = ';headers = {'User-Agent' : 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36'}def pic_spider(kw, path, page = 10):    path = os.path.join(path, kw)    if not os.path.exists(path):        os.mkdir(path)    if kw != '':        for num in range(page):            data = {                "tn": "resultjson_com",                "logid": "11587207680030063767",                "ipn": "rj",                "ct": "201326592",                "is": "",                "fp": "result",                "queryWord": kw,                "cl": "2",                "lm": "-1",                "ie": "utf-8",                "oe": "utf-8",                "adpicid": "",                "st": "-1",                "z": "",                "ic": "0",                "hd": "",                "latest": "",                "copyright": "",                "word": kw,                "s": "",                "se": "",                "tab": "",                "width": "",                "height": "",                "face": "0",                "istype": "2",                "qc": "",                "nc": "1",                "fr": "",                "expermode": "",                "force": "",                "pn": num*30,                "rn": "30",                "gsm": oct(num*30),                "1602481599433": ""            }            url = major_url + up.urlencode(data)            i = 0            pic_list = []            while i < 5:                try:                    pic_list = requests.get(url=url, headers=headers).json().get('data')                    break                except:                    print('网络不好,正在重试...')                    i += 1                    time.sleep(1.3)            for pic in pic_list:                url = pic.get('thumbURL', '') # 有的没有图片链接,就设置成空                if url == '':                    continue                name = pic.get('fromPageTitleEnc')                for char in ['?', '\\', '/', '*', '"', '|', ':', '<', '>']:                    name = name.replace(char, '') # 将所有不能出现在文件名中的字符去除掉                type = pic.get('type', 'jpg') # 找到图片的类型,若没有找到,默认为 jpg                pic_path = (os.path.join(path, '%s.%s') % (name, type))                print(name, '已完成下载')                if not os.path.exists(pic_path):                    with open(pic_path, 'wb') as f:                        f.write(requests.get(url = url, headers = headers).content)cwd = os.getcwd() # 当前路径        file1 = 'flower_data/flower_photos'file2 = '数据/下载数据'save_path = os.path.join(cwd,file2)#flower_class = [cla for cla in os.listdir(file1) if ".txt" not in cla]lists = ['猫','哈士奇','燕子','恐龙','鹦鹉','老鹰','柴犬','田园犬','咖啡猫','老虎','狮子','哥斯拉','奥特曼']print("lists_len: ",len(lists))for list in lists:    if not os.path.exists(save_path):        os.mkdir(save_path)    pic_spider(list,save_path, page = 10)

2 数据划分

将下载数据划分为训练集(80%)、验证集(10%)和测试集(10%)

gedit spile_data.pypython spile_data.py

spile_data.py

import osfrom shutil import copyimport randomdef mkfile(file):    if not os.path.exists(file):        os.makedirs(file)#file = 'flower_data/flower_photos'file = '数据/下载数据'flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]#mkfile('flower_data/train')mkfile('数据/train')for cla in flower_class:    #mkfile('flower_data/train/'+cla)    mkfile('数据/train/'+cla)#mkfile('flower_data/val')mkfile('数据/val')for cla in flower_class:    #mkfile('flower_data/val/'+cla)    mkfile('数据/val/'+cla)mkfile('数据/predict')for cla in flower_class:    #mkfile('flower_data/predict/'+cla)    mkfile('数据/predict/'+cla)split_rate = 0.1for cla in flower_class:    cla_path = file + '/' + cla + '/'    images1 = [cla1 for cla1 in os.listdir(cla_path) if ".jpg" in cla1]    images = [cla1 for cla1 in os.listdir(cla_path) if ".png" in cla1]+images1    #images = os.listdir(cla_path)     num = len(images)                #eval_index = random.sample(images, k=int(num*split_rate))    for index, image in enumerate(images):        if index<0.1*num:            image_path = cla_path + image            new_path = '数据/val/' + cla            copy(image_path, new_path)        elif 0.1*num<index<0.9*num:            image_path = cla_path + image            new_path = '数据/train/' + cla            copy(image_path, new_path)        else:            image_path = cla_path + image            new_path = '数据/predict/' + cla            copy(image_path, new_path)        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar    print()print("处理完成 !")
3 模型
gedit model.pypython model.py

model.py

import torch.nn as nnimport torchclass AlexNet(nn.Module):    def __init__(self, num_classes=1000, init_weights=False):           super(AlexNet, self).__init__()        self.features = nn.Sequential(  #打包            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] 自动舍去小数点后            nn.ReLU(inplace=True), #inplace 可以载入更大模型            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27] kernel_num为原论文一半            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]            nn.ReLU(inplace=True),            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]            nn.ReLU(inplace=True),            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]            nn.ReLU(inplace=True),            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]            nn.ReLU(inplace=True),            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]        )        self.classifier = nn.Sequential(            nn.Dropout(p=0.5),            #全链接            nn.Linear(128 * 6 * 6, 2048),            nn.ReLU(inplace=True),            nn.Dropout(p=0.5),            nn.Linear(2048, 2048),            nn.ReLU(inplace=True),            nn.Linear(2048, num_classes),        )        if init_weights:            self._initialize_weights()    def forward(self, x):        x = self.features(x)        x = torch.flatten(x, start_dim=1) #展平   或者view()        x = self.classifier(x)        return x    def _initialize_weights(self):        for m in self.modules():            if isinstance(m, nn.Conv2d):                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') #何教授方法                if m.bias is not None:                    nn.init.constant_(m.bias, 0)            elif isinstance(m, nn.Linear):                nn.init.normal_(m.weight, 0, 0.01)  #正态分布赋值                nn.init.constant_(m.bias, 0)
4 训练和验证
gedit train.pypython train.py

train.py

import torchimport torch.nn as nnfrom torchvision import transforms, datasets, utilsimport matplotlib.pyplot as pltimport numpy as npimport torch.optim as optimfrom model import AlexNetimport osimport jsonimport timeimport torchvision#device : GPU 或 CPUdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)#数据预处理data_transform = {    "train": transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪为224x224                                 transforms.RandomHorizontalFlip(), # 水平翻转                                 transforms.ToTensor(), # 转为张量                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),# 均值和方差为0.5    "val": transforms.Compose([transforms.Resize((224, 224)), # 重置大小                               transforms.ToTensor(),                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}batch_size = 32 # 批次大小data_root = os.getcwd() # 获取当前路径image_path = data_root + "/数据"  # 数据路径train_dataset = datasets.ImageFolder(root=image_path + "/train",                                     transform=data_transform["train"]) # 加载训练数据集并预处理train_num = len(train_dataset) # 训练数据集大小train_loader = torch.utils.data.DataLoader(train_dataset,                                           batch_size=batch_size, shuffle=True,                                           num_workers=0) # 训练加载器validate_dataset = datasets.ImageFolder(root=image_path + "/val",                                        transform=data_transform["val"]) # 验证数据集val_num = len(validate_dataset) # 验证数据集大小validate_loader = torch.utils.data.DataLoader(validate_dataset,                                              batch_size=batch_size, shuffle=True,                                              num_workers=0) # 验证加载器print("训练数据集大小: ",train_num,"\n") # 28218print("验证数据集大小: ",val_num,"\n") # 308def imshow(img):    img = img / 2 + 0.5     # unnormalize    npimg = img.numpy()    plt.imshow(np.transpose(npimg, (1, 2, 0)))    plt.show()net = AlexNet(num_classes=13, init_weights=True) # 调用模型net.to(device)loss_function = nn.CrossEntropyLoss() # 损失函数:交叉熵optimizer = optim.Adam(net.parameters(), lr=0.0002) #优化器 Adamsave_path = './AlexNet.pth' # 训练参数保存路径best_acc = 0.0 # 训练过程中最高准确率#开始进行训练和测试,训练一轮,测试一轮for epoch in range(10):    # 训练部分    print(">>开始训练: ",epoch+1)    net.train()    #训练dropout    running_loss = 0.0    t1 = time.perf_counter()    for step, data in enumerate(train_loader, start=0):        images, labels = data        #print("\nlabels: ",labels)        #imshow(torchvision.utils.make_grid(images))        optimizer.zero_grad() # 梯度置0        outputs = net(images.to(device))         loss = loss_function(outputs, labels.to(device))        loss.backward() # 反向传播        optimizer.step()                running_loss += loss.item() # 累加损失        rate = (step + 1) / len(train_loader) # 训练进度        a = "*" * int(rate * 50) # *数        b = "." * int((1 - rate) * 50) # .数        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")    print()    print(time.perf_counter()-t1) # 一个epoch花费的时间    # 验证部分    print(">>开始验证: ",epoch+1)    net.eval()    #验证不需要dropout    acc = 0.0  # 一个批次中分类正确个数    with torch.no_grad():        for val_data in validate_loader:            val_images, val_labels = val_data            outputs = net(val_images.to(device))            #print("outputs: \n",outputs,"\n")            predict_y = torch.max(outputs, dim=1)[1]            #print("predict_y: \n",predict_y,"\n")            acc += (predict_y == val_labels.to(device)).sum().item() # 预测和标签一致,累加        val_accurate = acc / val_num # 一个批次的准确率        if val_accurate > best_acc:            best_acc = val_accurate            torch.save(net.state_dict(), save_path) # 更新准确率最高的网络参数        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %              (epoch + 1, running_loss / step, val_accurate))print('Finished Training')# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idx #  {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}cla_dict = dict((val, key) for key, val in flower_list.items())# 将字典写入 json 文件json_str = json.dumps(cla_dict, indent=4) # 字典转jsonwith open('class_indices.json', 'w') as json_file: # 对class_indices.json写入操作    json_file.write(json_str) # 写入class_indices.json
5 测试
gedit predict.pypython predict.py

predict.py

import torchfrom model import AlexNetfrom PIL import Imagefrom torchvision import transformsimport matplotlib.pyplot as pltimport jsonimport osdata_transform = transforms.Compose(    [transforms.Resize((224, 224)),     transforms.ToTensor(),     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])cwd = os.getcwd() # 当前路径predict = '数据/predict'predict_path = os.path.join(cwd,predict)#flowers = ['雏菊','蒲公英','玫瑰花','太阳花','郁金香']#flowers = [flower for flower in os.listdir(predict_path)]try:    json_file = open('./class_indices.json', 'r')    class_indict = json.load(json_file)except Exception as e:    print(e)    exit(-1)for j,flower in class_indict.items():    print(">>测试: ",flower)    #print("花\t","概率")     path = os.path.join(predict_path,flower)    images = [f1 for f1 in os.listdir(path) if ".gif" not in f1] # 过滤gif动图        acc_ = [0,0,0,0,0,0,0,0,0,0,0,0,0]    for image in images:        # 加载图片        img = Image.open(path+'/'+image).convert('RGB')        # RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0        # .convert('RGB')        plt.imshow(img)        # [N, C, H, W]        img = data_transform(img)        # expand batch dimension        img = torch.unsqueeze(img, dim=0)        # read class_indict        try:            json_file = open('./class_indices.json', 'r')            class_indict = json.load(json_file)        except Exception as e:            print(e)            exit(-1)        # create model        model = AlexNet(num_classes=13)        # load model weights        model_weight_path = "./AlexNet.pth"        model.load_state_dict(torch.load(model_weight_path))        model.eval()        with torch.no_grad():            # predict class            output = torch.squeeze(model(img))            predict = torch.softmax(output, dim=0)            predict_flower = torch.argmax(predict).numpy()        #print(class_indict[str(predict_flower)],'\t', predict[predict_flower].item())        #print(str(predict_flower))        acc_[predict_flower]+=1    #print("acc_: ",acc_)    print("{}总共有{}张图片 \n".format(flower,len(images)))    #print(class_indict.values(),'\n',str(acc_))    print("{}准确率为:{}%".format(flower,100*acc_[int(j)]/len(images)))    print("\n")print(">>测试完毕!")

测试结果:

>>测试:  咖啡猫咖啡猫总共有21张图片咖啡猫准确率为:14.285714285714286%>>测试:  哈士奇哈士奇总共有24张图片 哈士奇准确率为:58.333333333333336%>>测试:  哥斯拉哥斯拉总共有21张图片 哥斯拉准确率为:57.142857142857146%>>测试:  奥特曼奥特曼总共有23张图片奥特曼准确率为:30.434782608695652%>>测试:  恐龙恐龙总共有23张图片 恐龙准确率为:34.78260869565217%>>测试:  柴犬柴犬总共有17张图片 柴犬准确率为:70.58823529411765%>>测试:  燕子燕子总共有18张图片 燕子准确率为:61.111111111111114%>>测试:  狮子狮子总共有24张图片狮子准确率为:58.333333333333336%>>测试:  猫猫总共有22张图片 猫准确率为:18.181818181818183%>>测试:  田园犬田园犬总共有22张图片 田园犬准确率为:4.545454545454546%>>测试:  老虎老虎总共有22张图片老虎准确率为:22.727272727272727%>>测试:  老鹰老鹰总共有22张图片 老鹰准确率为:0.0%>>测试:  鹦鹉鹦鹉总共有22张图片 鹦鹉准确率为:72.72727272727273%

提示:爬虫有风险,使用需谨慎!

标签: #net40eval