pytorch修改resnet18 输入通道

方法一:扩张1通道为3通道,利用torch.expand()方法

    model = resnet18(pretrained=False) # 主干提取网络
    model.load_state_dict(torch.load('./resnet18-5c106cde.pth'), strict=False)
    print(model)
    par = summary(model, (3, 224, 224), device='cpu')
    print(par)

    net = RFNet( model, 1, use_bn=True) # 输出类别 num_classes
    # print(model)
    input1 = torch.rand((1,3,256,256)) # 输入通道为1
    input1 = input1.expand(1,3,256,256) # 扩展为3通道

    print(input1.shape)
    input2 = torch.rand((1,1,256, 256))
    output = net(input1,input2)
    print(output.shape)

方法二:修改字典参数

import torchvision.models as models
import torch
import torch.nn as nn
from torchsummary import summary

resnet18 = models.resnet18(pretrained=False)
resnet18.conv1= nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,bias=False)

# print(resnet18)
pretrained_dict = torch.load('./resnet/resnet18-5c106cde.pth')
# for k, v in pretrained_dict.items():
#     print(k)

x = torch.rand(64, 1, 7, 7)
pretrained_dict["conv1.weight"] = x

conv1 = pretrained_dict["conv1.weight"]
print(conv1.shape)
resnet18.load_state_dict(pretrained_dict)

# print(resnet18)
par = summary(resnet18, (1, 224, 224),device='cpu')
print(par)

上一篇:学习Python绕不过的13个小技巧!一般人我不告诉他!非常有用!


下一篇:python-字典使用方法