0 准备工作


1 爬取动物图片


gedit get_data.py
python get_data.py


import requests
import urllib.parse as up
import json
import time
import os

# NOTE(review): the endpoint literal was truncated to `';` in this listing.
# The query parameters built below (tn=resultjson_com, ipn=rj, pn/rn paging)
# match Baidu image search's JSON endpoint, so that URL is restored here --
# confirm against the author's original script.
major_url = 'https://image.baidu.com/search/acjson?'
headers = {'User-Agent' : 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36'}

# Seconds to wait on a single HTTP request; requests never times out by default.
REQUEST_TIMEOUT = 10
# Characters that must not appear in a file name.
_FORBIDDEN_CHARS = ['?', '\\', '/', '*', '"', '|', ':', '<', '>']


def _fetch_pic_list(url, retries=5, delay=1.3):
    """Return the 'data' list of one result page, or [] if every attempt fails."""
    for _ in range(retries):
        try:
            response = requests.get(url=url, headers=headers, timeout=REQUEST_TIMEOUT)
            # A page without a 'data' key yields None; normalise it to [].
            return response.json().get('data') or []
        except (requests.RequestException, ValueError, AttributeError):
            print('网络不好,正在重试...')
            time.sleep(delay)
    return []


def pic_spider(kw, path, page = 10):
    """Download thumbnail images for the search keyword *kw*.

    Args:
        kw: search keyword; also used as the name of the sub-directory the
            images are stored in. Nothing is downloaded when it is empty.
        path: parent directory; images are written to ``path/kw``.
        page: number of result pages to request (30 results per page).

    Images whose target file already exists are not downloaded again. An
    image that fails to download is reported and skipped so that one bad
    link does not abort the whole crawl.
    """
    path = os.path.join(path, kw)
    # makedirs also creates missing parent directories (os.mkdir would fail).
    os.makedirs(path, exist_ok=True)
    if kw == '':
        return
    for num in range(page):
        data = {
            "tn": "resultjson_com",
            "logid": "11587207680030063767",
            "ipn": "rj",
            "ct": "201326592",
            "is": "",
            "fp": "result",
            "queryWord": kw,
            "cl": "2",
            "lm": "-1",
            "ie": "utf-8",
            "oe": "utf-8",
            "adpicid": "",
            "st": "-1",
            "z": "",
            "ic": "0",
            "hd": "",
            "latest": "",
            "copyright": "",
            "word": kw,
            "s": "",
            "se": "",
            "tab": "",
            "width": "",
            "height": "",
            "face": "0",
            "istype": "2",
            "qc": "",
            "nc": "1",
            "fr": "",
            "expermode": "",
            "force": "",
            "pn": num*30,
            "rn": "30",
            "gsm": oct(num*30),
            "1602481599433": ""
        }
        pic_list = _fetch_pic_list(major_url + up.urlencode(data))
        for pic in pic_list:
            pic_url = pic.get('thumbURL', '')  # some entries carry no image link
            if not pic_url:
                continue
            # Fall back to the URL's file name when the entry has no page title.
            name = pic.get('fromPageTitleEnc') or os.path.splitext(
                os.path.basename(up.urlparse(pic_url).path))[0]
            for char in _FORBIDDEN_CHARS:
                name = name.replace(char, '')
            if not name:
                continue
            ext = pic.get('type') or 'jpg'  # image type, 'jpg' when missing
            # Format the file name before joining so a '%' in path is harmless.
            pic_path = os.path.join(path, '%s.%s' % (name, ext))
            if not os.path.exists(pic_path):
                try:
                    content = requests.get(url = pic_url, headers = headers,
                                           timeout = REQUEST_TIMEOUT).content
                except requests.RequestException as err:
                    print(name, err)
                    continue
                # Write only after the download succeeded, so a failed request
                # never leaves an empty file behind.
                with open(pic_path, 'wb') as f:
                    f.write(content)
            print(name, '已完成下载')


def main():
    """Crawl every keyword in the list into ``<cwd>/数据/下载数据/<keyword>``."""
    cwd = os.getcwd()  # current working directory
    file2 = '数据/下载数据'
    save_path = os.path.join(cwd, file2)
    lists = ['猫','哈士奇','燕子','恐龙','鹦鹉','老鹰','柴犬','田园犬','咖啡猫','老虎','狮子','哥斯拉','奥特曼']
    print("lists_len: ",len(lists))
    os.makedirs(save_path, exist_ok=True)
    for keyword in lists:
        pic_spider(keyword, save_path, page = 10)


if __name__ == '__main__':
    main()

2 数据划分


gedit spile_data.py
python spile_data.py


import os
from shutil import copy
import random

# Directory holding one sub-directory of downloaded images per class.
SOURCE_DIR = '数据/下载数据'
# Directory under which the train/val/predict trees are created.
OUTPUT_ROOT = '数据'
# Fraction of each class that goes to 'val' and, separately, to 'predict'.
SPLIT_RATE = 0.1
# Output sub-directories, in the order they are created.
SPLITS = ('train', 'val', 'predict')
# Only files with these suffixes are copied.
IMAGE_SUFFIXES = ('.jpg', '.png')


def mkfile(file):
    """Create directory *file*, including parents, if it does not exist."""
    os.makedirs(file, exist_ok=True)


def _pick_split(index, num, split_rate):
    """Return the split name for image number *index* out of *num* images."""
    if index < split_rate * num:
        return 'val'
    # `<` with no lower bound: the image sitting exactly on the val boundary
    # belongs to 'train' rather than falling through to 'predict'.
    if index < (1 - split_rate) * num:
        return 'train'
    return 'predict'


def split_dataset(src_root, dst_root, split_rate=SPLIT_RATE):
    """Copy every class folder of *src_root* into train/val/predict trees.

    Args:
        src_root: directory containing one sub-directory per class. Entries
            whose name contains ".txt" are ignored.
        dst_root: directory that receives ``train/<class>``, ``val/<class>``
            and ``predict/<class>``.
        split_rate: share of each class copied to ``val``; the same share is
            copied to ``predict`` and the remainder to ``train``.

    Images are assigned by their position in the directory listing: the first
    ``split_rate`` share goes to ``val``, the last to ``predict``.
    """
    flower_class = [cla for cla in os.listdir(src_root) if ".txt" not in cla]

    for split in SPLITS:
        mkfile(os.path.join(dst_root, split))
        for cla in flower_class:
            mkfile(os.path.join(dst_root, split, cla))

    for cla in flower_class:
        cla_path = os.path.join(src_root, cla)
        # Suffix match, single pass: each file is listed at most once.
        images = [name for name in os.listdir(cla_path)
                  if name.endswith(IMAGE_SUFFIXES)]
        num = len(images)
        for index, image in enumerate(images):
            split = _pick_split(index, num, split_rate)
            copy(os.path.join(cla_path, image),
                 os.path.join(dst_root, split, cla))
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("处理完成 !")


if __name__ == '__main__':
    split_dataset(SOURCE_DIR, OUTPUT_ROOT)
3 模型
gedit model.py
python model.py


import torch.nn as nn
import torch


class AlexNet(nn.Module):
    """AlexNet-style classifier with half the channel counts of the paper.

    Args:
        num_classes: size of the final linear layer's output.
        init_weights: when True, re-initialise all conv and linear layers
            with ``_initialize_weights``.

    The shape comments below assume a [3, 224, 224] input; the classifier's
    first linear layer requires the feature map to be [128, 6, 6].
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] (size is floored)
            nn.ReLU(inplace=True),                                  # in-place to save memory
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            # fully connected layers
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class scores of shape [N, num_classes] for a batch *x*."""
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # flatten everything but the batch dim
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Kaiming-normal init for conv layers, N(0, 0.01) for linear layers.

        Biases are set to zero; layers built with ``bias=False`` are skipped.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Same guard as the conv branch: a bias-free layer has bias=None.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
4 训练和验证
gedit train.py
python train.py


import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time
import torchvision

# Training script: trains AlexNet on 数据/train, validates on 数据/val after
# every epoch, keeps the best-scoring weights in ./AlexNet.pth and finally
# writes the index -> class-name mapping to class_indices.json.

# Use the first GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Pre-processing pipelines.
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),  # random crop resized to 224x224
                                 transforms.RandomHorizontalFlip(),  # random horizontal flip
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),  # mean 0.5, std 0.5 per channel
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

batch_size = 32
data_root = os.getcwd()  # current working directory
image_path = data_root + "/数据"  # dataset root

train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0)

print("训练数据集大小: ",train_num,"\n")
print("验证数据集大小: ",val_num,"\n")


def imshow(img):
    """Display a tensor image that was normalised with mean 0.5 / std 0.5."""
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# Size the output layer from the dataset instead of hard-coding the count.
net = AlexNet(num_classes=len(train_dataset.class_to_idx), init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0002)
save_path = './AlexNet.pth'  # where the best weights are stored
best_acc = 0.0  # best validation accuracy seen so far

# One training pass followed by one validation pass per epoch.
for epoch in range(10):
    # ---- training ----
    print(">>开始训练: ",epoch+1)
    net.train()  # enables dropout
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        rate = (step + 1) / len(train_loader)  # fraction of the epoch done
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss.item()), end="")
    print()
    print(time.perf_counter()-t1)  # seconds spent on this epoch

    # ---- validation ----
    print(">>开始验证: ",epoch+1)
    net.eval()  # disables dropout
    acc = 0.0  # number of correctly classified validation images this epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num  # accuracy over the whole validation set
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)  # keep the best weights
        # Average over the number of batches; dividing by the last batch index
        # is off by one and fails when the loader yields a single batch.
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / len(train_loader), val_accurate))

print('Finished Training')

# class_to_idx maps class name -> index; invert it to index -> class name.
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# Persist the mapping so the prediction script can translate indices.
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)
5 测试
gedit predict.py
python predict.py


import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import os

# Evaluation script: classifies every image under 数据/predict/<class> with
# the weights in ./AlexNet.pth and prints the per-class accuracy.

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cwd = os.getcwd()  # current working directory
predict = '数据/predict'
predict_path = os.path.join(cwd, predict)

# Index -> class-name mapping; keys are strings because they come from JSON.
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except (OSError, ValueError) as e:
    print(e)
    raise SystemExit(-1)

# Build the model and load its weights once, not once per image.
model = AlexNet(num_classes=len(class_indict))
model_weight_path = "./AlexNet.pth"
# map_location lets weights saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(model_weight_path, map_location="cpu"))
model.eval()

for j, flower in class_indict.items():
    print(">>测试: ",flower)
    path = os.path.join(predict_path, flower)
    images = [f1 for f1 in os.listdir(path) if ".gif" not in f1]  # skip animated gifs
    acc_ = [0] * len(class_indict)  # predictions per class index for this folder
    for image in images:
        # convert('RGB') drops alpha channels, which the 3-channel Normalize
        # transform cannot handle.
        with Image.open(os.path.join(path, image)) as raw:
            img = raw.convert('RGB')
        # [C, H, W] -> [N, C, H, W]
        img = torch.unsqueeze(data_transform(img), dim=0)
        with torch.no_grad():
            output = torch.squeeze(model(img), dim=0)
            probs = torch.softmax(output, dim=0)
            predict_flower = torch.argmax(probs).item()
        acc_[predict_flower] += 1
    print("{}总共有{}张图片 \n".format(flower,len(images)))
    if images:  # an empty folder has no accuracy to report
        print("{}准确率为:{}%".format(flower,100*acc_[int(j)]/len(images)))
    print("\n")
print(">>测试完毕!")


>>测试:  咖啡猫
咖啡猫总共有21张图片
咖啡猫准确率为:14.285714285714286%
>>测试:  哈士奇
哈士奇总共有24张图片
哈士奇准确率为:58.333333333333336%
>>测试:  哥斯拉
哥斯拉总共有21张图片
哥斯拉准确率为:57.142857142857146%
>>测试:  奥特曼
奥特曼总共有23张图片
奥特曼准确率为:30.434782608695652%
>>测试:  恐龙
恐龙总共有23张图片
恐龙准确率为:34.78260869565217%
>>测试:  柴犬
柴犬总共有17张图片
柴犬准确率为:70.58823529411765%
>>测试:  燕子
燕子总共有18张图片
燕子准确率为:61.111111111111114%
>>测试:  狮子
狮子总共有24张图片
狮子准确率为:58.333333333333336%
>>测试:  猫
猫总共有22张图片
猫准确率为:18.181818181818183%
>>测试:  田园犬
田园犬总共有22张图片
田园犬准确率为:4.545454545454546%
>>测试:  老虎
老虎总共有22张图片
老虎准确率为:22.727272727272727%
>>测试:  老鹰
老鹰总共有22张图片
老鹰准确率为:0.0%
>>测试:  鹦鹉
鹦鹉总共有22张图片
鹦鹉准确率为:72.72727272727273%


标签: #net40eval