import torch
import numpy as np
import torchvision
import torch.nn as nn

from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import time
import os
import copy

print("Torchvision Version:", torchvision.__version__)

# Paths and hyperparameters
data_dir = "./hymenoptera_data"
batch_size = 32
input_size = 224
model_name = "resnet"
num_classes = 2
num_epochs = 15
feature_extract = True  # freeze the pretrained backbone and train only the new classifier head
# Data augmentation for training; deterministic resize + center crop for validation
# (random augmentation at validation time would make the accuracy noisy).
# Normalization uses the standard ImageNet statistics.
data_transforms = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    "val": transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ["train", "val"]}
dataloader_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                  shuffle=True) for x in ["train", "val"]}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs, labels = next(iter(dataloader_dict["train"]))
# print(inputs.shape)  # one batch: [batch_size, 3, input_size, input_size]
# print(labels)
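
# A minimal visualization sketch (an addition, not part of the original script):
# un-normalize the batch grabbed above with the same ImageNet statistics and
# display it as a grid. Uncomment the call to inspect the data.
def show_batch(images):
    grid = torchvision.utils.make_grid(images)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    grid = (grid * std + mean).clamp(0, 1)  # undo Normalize
    plt.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    plt.show()
# show_batch(inputs)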


# Load the pretrained ResNet model and replace its fully connected layer.
def set_parameter_requires_grad(model, feature_extract):
    # In feature-extraction mode, freeze every pretrained parameter so that
    # only the newly attached classifier layer is trained.
    if feature_extract:
        for param in model.parameters():
            param.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    if model_name == "resnet":
        # Note: newer torchvision releases replace pretrained= with the weights= argument.
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Swap the final fully connected layer for one sized to our two classes.
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    else:
        print("model not implemented")
        return None, None

    return model_ft, input_size

model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
# print(model_ft)
print("-" * 200)
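
# Sanity check (an addition, not in the original script): list the parameters
# that will actually be learned. With feature_extract=True only the new fc
# layer ("fc.weight" and "fc.bias") should still require gradients.
for name, param in model_ft.named_parameters():
    if param.requires_grad:
        print("will be trained:", name)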


def train_model(model, dataloaders, loss_fn, optimizer, num_epochs=5):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.
    val_acc_history = []
    for epoch in range(num_epochs):
        for phase in ["train", "val"]:
            running_loss = 0.
            running_corrects = 0.
            if phase == "train":
                model.train()
            else:
                model.eval()

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)

                # Only build the autograd graph during the training phase.
                with torch.autograd.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)  # shape: [batch_size, num_classes]
                    loss = loss_fn(outputs, labels)

                preds = outputs.argmax(dim=1)
                if phase == "train":
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                # Accumulate loss and correct predictions over the whole epoch.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds.view(-1) == labels.view(-1)).item()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)

            print("Phase {} loss: {}, acc: {}".format(phase, epoch_loss, epoch_acc))

            # Keep a copy of the weights that perform best on the validation set.
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == "val":
                val_acc_history.append(epoch_acc)
    model.load_state_dict(best_model_wts)
    return model, val_acc_history


model_ft = model_ft.to(device)
# Hand the optimizer only the parameters that still require gradients;
# with feature_extract=True that is just the new fc layer.
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                   model_ft.parameters()), lr=0.001, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()
_, ohist = train_model(model_ft, dataloader_dict, loss_fn, optimizer, num_epochs=num_epochs)
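
# Persist the best weights (an added sketch, not in the original script;
# "resnet18_hymenoptera.pth" is an assumed file name).
torch.save(model_ft.state_dict(), "resnet18_hymenoptera.pth")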


plt.title("Validation Accuracy vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Accuracy")
plt.plot(range(1, num_epochs + 1), ohist, label="Pretrained")
plt.ylim((0, 1.))
plt.xticks(np.arange(1, num_epochs + 1, 1.0))
plt.legend()
plt.show()
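
# A minimal inference sketch (an assumption, not part of the original script):
# classify a single image with the trained model. "some_image.jpg" is a
# hypothetical path; class order follows ImageFolder's alphabetical mapping.
# from PIL import Image
# img = Image.open("some_image.jpg").convert("RGB")
# tensor = data_transforms["val"](img).unsqueeze(0).to(device)
# model_ft.eval()
# with torch.no_grad():
#     pred = model_ft(tensor).argmax(dim=1).item()
# print(image_datasets["train"].classes[pred])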