话不多说，直接上代码，不懂的可以加我QQ：1260814407。为了验证 state_dict 的使用方法，我在全连接层的写法上和原版 AlexNet 有点不一样（最后一层单独写成了 gategory，没有放进 classifier 里）。之后我会试试其他模型的迁移学习，看看有没有更好的办法。字典实在是用得太不习惯了，Python 里我唯一能忍受的就是列表，别的都好难用。
import torch
import torch.nn as nn


class alex_net(nn.Module):
    """AlexNet variant laid out so pretrained torchvision AlexNet weights can be copied in by key.

    ``features``, ``avgpool`` and the six modules of ``classifier`` use the same
    layer types, sizes and positions as torchvision's AlexNet, so state_dict keys
    such as ``features.3.weight`` or ``classifier.4.weight`` line up with the
    pretrained model's keys. The task-specific output layer is kept OUTSIDE
    ``classifier`` (as ``self.gategory``), so a key-filtered copy of pretrained
    weights never touches it and it keeps its fresh initialisation.

    Args:
        num_classes: number of logits produced per sample.
    """

    def __init__(self, num_classes):
        super(alex_net, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Fixed 6x6 output so the flattened size is always 256 * 6 * 6,
        # independent of the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Same hidden layers as torchvision's AlexNet classifier, but WITHOUT the
        # final 1000-way Linear, so that pretrained layer has no matching key here.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
        )
        # (sic) "gategory" is kept as-is: renaming the attribute would change the
        # state_dict keys ("gategory.weight" / "gategory.bias") and break
        # checkpoints that were already saved with this model.
        self.gategory = nn.Linear(4096, num_classes)

    def forward(self, input):
        """Return ``num_classes`` logits per sample.

        Assumes ``input`` is a (batch, 3, H, W) image tensor; the channel count
        comes from the first Conv2d, the batch-first layout from ``flatten(out, 1)``.
        """
        out = self.features(input)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        out = self.gategory(out)
        return out


if __name__ == '__main__':
    # Demo: inspect one pretrained tensor and this model's state_dict keys.
    # It lives under the main guard (and imports torchvision locally) so that
    # `from ... import alex_net` in the training script does not download
    # ImageNet weights, build a throwaway model, or print anything.
    from torchvision.models import alexnet

    alex = alexnet(pretrained=True)
    pretrained_dict = alex.state_dict()
    weight_0 = pretrained_dict['features.3.weight']
    bias_0 = pretrained_dict['features.3.bias']
    print(weight_0.shape)
    print(bias_0.shape)

    model = alex_net(num_classes=5)
    print(model.state_dict().keys())
import torch
from torch import optim, nn
import visdom
from torchvision.models import alexnet
from torch.utils.data import DataLoader
from transfer_learning.poke import Pokemonn
from transfer_learning.model import alex_net

batch_size = 16
learning_rate = 1e-3
epoches = 10
data_root = '/Users/wenyu/Desktop/TorchProject/Pokemann/pokeman'
image_size = 227
num_classes = 5


def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every sample in ``loader``.

    The model is put in eval mode for the pass, so Dropout is disabled, and is
    then restored to whatever train/eval mode it was in before. Returns 0.0
    when the loader's dataset is empty instead of dividing by zero.
    """
    total_num = len(loader.dataset)
    if total_num == 0:
        return 0.0
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)
    return correct / total_num


def _make_loaders(root, resize):
    """Build the (train, validation, test) DataLoaders for the Pokemon dataset under ``root``."""
    train_db = Pokemonn(root, resize, mode='train')
    validation_db = Pokemonn(root, resize, mode='validation')
    test_db = Pokemonn(root, resize, mode='test')
    train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True, num_workers=4)
    validation_loader = DataLoader(validation_db, batch_size=batch_size, num_workers=2)
    test_loader = DataLoader(test_db, batch_size=batch_size, num_workers=2)
    return train_loader, validation_loader, test_loader


def _load_pretrained(model):
    """Copy ImageNet AlexNet weights into ``model`` wherever key AND shape match.

    Keys that exist only in ``model`` (e.g. its own output head) keep their
    current initialisation. The shape check also skips same-named layers
    whose size was changed, which would otherwise make load_state_dict raise.
    """
    model_dict = model.state_dict()
    pretrained_dict = alexnet(pretrained=True).state_dict()
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model


def main():
    """Fine-tune ``alex_net`` from ImageNet AlexNet weights, keep the best-validation checkpoint, then test it."""
    # Seed torch's RNG so weight init and shuffling are reproducible.
    torch.manual_seed(1234)

    # Setup lives here rather than at module level: DataLoader workers started
    # with the "spawn" method (the macOS default) re-import this script, and
    # would otherwise reconnect to visdom and rebuild the datasets every time.
    vis = visdom.Visdom()
    train_loader, validation_loader, test_loader = _make_loaders(data_root, image_size)

    model = _load_pretrained(alex_net(num_classes))

    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    fun_loss = nn.CrossEntropyLoss()

    vis.line([0.], [-1], win='train_loss', opts=dict(title='train_loss'))
    vis.line([0.], [-1], win='validation_acc', opts=dict(title='validation_acc'))

    global_step = 0
    # -1 so the first validation always writes best.mdl; with 0 and a strict
    # '>' a run that never beats 0% accuracy would reload a missing/stale file.
    best_epoch, best_acc = 0, -1.0
    for epoch in range(epoches):
        # Dropout must be active while training.
        model.train()
        for x, y in train_loader:
            logits = model(x)
            loss = fun_loss(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            vis.line([loss.item()], [global_step], win='train_loss', update='append')
            global_step += 1

        # Validate after every epoch and keep the best checkpoint on disk.
        val_acc = evaluate(model, validation_loader)
        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            torch.save(model.state_dict(), 'best.mdl')
        vis.line([val_acc], [global_step], win='validation_acc', update='append')

    print('best acc', best_acc, 'best epoch', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('load from ckpt')

    test_acc = evaluate(model, test_loader)
    print(test_acc)


if __name__ == '__main__':
    main()
训练这部分大部分是跟着龙龙老师写的，数据集也是他的。我只是想简单验证一下迁移学习怎么用，之后会做 MobileNet。龙龙老师的 PyTorch 课讲得非常浅显易懂，但迁移学习这块讲得不是很全面，想深入学的话还需要再看看其他资料。