在做神经网络的搭建过程,经常使用pytorch中的resnet作为backbone,特别是resnet50,比如下面的这个网络设定
import torch import torch.nn as nn from torchvision import datasets, transforms from torchvision import models class base_resnet(nn.Module): def __init__(self): super(base_resnet, self).__init__() self.model = models.resnet50(pretrained=True) #self.model.load_state_dict(torch.load(‘./model/resnet50-19c8e357.pth‘)) self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1)) def forward(self, x): x = self.model.conv1(x) x = self.model.bn1(x) x = self.model.relu(x) x = self.model.maxpool(x) x = self.model.layer1(x) x = self.model.layer2(x) x = self.model.layer3(x) x = self.model.layer4(x) x = self.model.avgpool(x) # x = x.view(x.size(0), x.size(1)) return x
该网络相当于继承了resnet50的所有参数结构,只不过是在forward中,改变了数据的传输过程,没有经过最后的特征展开以及线性分类。在下面的这行代码中,是相当于调用了pytoch中定义的resnet50网络,并且会自动下载并且加载训练好的网络参数,如果调为 pretrained=False,则不会加载训练好的参数,而是随机进行参数的赋值。但是我在服务器上跑这一类代码的时候发现,每当我重新跑一次程序,如果设置为True都会重新下载resnet50训练好的参数,但是由于有时候网络特别不好,导致我下载个基础的resnet50就要耗费我好长时间,那么我就想能不能将这个resnet50的参数提前下载好,使用的时候直接加载呢。当然是能了。
self.model = models.resnet50(pretrained=True)
我们可以根据我们使用的结构,到对应的地址下载对应的模型到本地,常用的resnet的地址如下:
‘resnet18‘: ‘https://download.pytorch.org/models/resnet18-5c106cde.pth‘,
‘resnet34‘: ‘https://download.pytorch.org/models/resnet34-333f7ec4.pth‘,
‘resnet50‘: ‘https://download.pytorch.org/models/resnet50-19c8e357.pth‘,
‘resnet101‘: ‘https://download.pytorch.org/models/resnet101-5d3b4d8f.pth‘,
‘resnet152‘: ‘https://download.pytorch.org/models/resnet152-b121ed2d.pth‘,
将其下载下来,然后将模型放入到和net.py同目录的model文件夹下面,然后使用下面的代码就可以避免每次都重新下载模型的问题了。
self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load(‘./model/resnet50-19c8e357.pth‘))
抄袭:https://blog.csdn.net/Leo_whj/article/details/105247188