迁移学习resnet

import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms, models
 11 print("Torchvision Version:",torchvision.__version__)
 12 
 13 data_dir="./hymenoptera_data"
 14 batch_size=32
 15 input_size=224
 16 model_name="resnet"
 17 num_classes=2
 18 num_epochs=15
 19 feature_extract=True
 20 data_transforms={
 21     "train":transforms.Compose([
 22         transforms.RandomResizedCrop(input_size),
 23         transforms.RandomHorizontalFlip(),
 24         transforms.ToTensor(),
 25         transforms.Normalize([0.482,0.456,0.406],[0.229,0.224,0.225])
 26     ]),
 27     "val":transforms.Compose([
 28 
 29     transforms.RandomResizedCrop(input_size),
 30     transforms.RandomHorizontalFlip(),
 31     transforms.ToTensor(),
 32     transforms.Normalize([0.482, 0.456, 0.406], [0.229, 0.224, 0.225])
 33 ]),
 34 }
 35 image_datasets={x:datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x])
 36                 for x in ["train",val]}
 37 dataloader_dict={x:torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,
 38                 shuffle=True)for x in [train,val]}
 39 device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
 40 inputs,labels=next(iter(dataloader_dict["train"]))
 41 #print(inputs.shape)#一个batch
 42 #print(labels)
 43 
 44 
# Load the ResNet model and replace its fully connected layer
 46 def set_parameter_requires_grad(model,feature_extract):
 47     if feature_extract:
 48         for param in model.parameters():
 49             param.requires_grad=False
 50 
 51 def initialize_model(model_name,num_classes,feature_extract,use_pretrained=True):
 52     if model_name=="resnet":
 53         model_ft=models.resnet18(pretrained=use_pretrained)
 54         set_parameter_requires_grad(model_ft,feature_extract)
 55         num_ftrs=model_ft.fc.in_features
 56         model_ft.fc=nn.Linear(num_ftrs,num_classes)
 57         input_size=224
 58     else:
 59         print("model not implemented")
 60         return None,None
 61 
 62     return model_ft,input_size
 63 model_ft,input_size=initialize_model(model_name,num_classes,feature_extract,use_pretrained=True)
 64 #print(model_ft)
 65 print(-*200)
 66 
 67 
 68 def train_model(model, dataloaders, loss_fn, optimizer, num_epochs=5):
 69     best_model_wts = copy.deepcopy(model.state_dict())
 70     best_acc = 0.
 71     val_acc_history = []
 72     for epoch in range(num_epochs):
 73         for phase in ["train", "val"]:
 74             running_loss = 0.
 75             running_corrects = 0.
 76             if phase == "train":
 77                 model.train()
 78             else:
 79                 model.eval()
 80 
 81             for inputs, labels in dataloaders[phase]:
 82                 inputs, labels = inputs.to(device), labels.to(device)
 83 
 84                 with torch.autograd.set_grad_enabled(phase == "train"):
 85                     outputs = model(inputs)  # bsize * 2
 86                     loss = loss_fn(outputs, labels)
 87 
 88                 preds = outputs.argmax(dim=1)
 89                 if phase == "train":
 90                     optimizer.zero_grad()
 91                     loss.backward()
 92                     optimizer.step()
 93                 running_loss += loss.item() * inputs.size(0)
 94                 running_corrects += torch.sum(preds.view(-1) == labels.view(-1)).item()
 95 
 96             epoch_loss = running_loss / len(dataloaders[phase].dataset)
 97             epoch_acc = running_corrects / len(dataloaders[phase].dataset)
 98 
 99             print("Phase {} loss: {}, acc: {}".format(phase, epoch_loss, epoch_acc))
100 
101             if phase == "val" and epoch_acc > best_acc:
102                 best_acc = epoch_acc
103                 best_model_wts = copy.deepcopy(model.state_dict())
104             if phase == "val":
105                 val_acc_history.append(epoch_acc)
106     model.load_state_dict(best_model_wts)
107     return model, val_acc_history
108 
109 
model_ft = model_ft.to(device)
# Hand the optimizer only the parameters that still require gradients
# (just the new fc layer when the backbone is frozen).
trainable_params = [p for p in model_ft.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.001, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()
_, ohist = train_model(model_ft, dataloader_dict, loss_fn, optimizer, num_epochs=num_epochs)

# Plot validation accuracy per epoch.
plt.title("Validation Accuracy vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Accuracy")
plt.plot(range(1, num_epochs + 1), ohist, label="Pretrained")
plt.ylim((0, 1.))
plt.xticks(np.arange(1, num_epochs + 1, 1.0))
plt.legend()
plt.show()

迁移学习resnet

迁移学习resnet

上一篇:spring依赖注入源码分析和mongodb自带连接本地mongodb服务逻辑分析


下一篇:第44天: Web 开发 Bootstrap