torchvision中提供了很多训练好的模型,这些模型是在1000类,224*224的imagenet中训练得到的,很多时候不适合我们自己的数据,可以根据需要进行修改。
1、类别不同
# coding=UTF-8
import torchvision.models as models #调用模型
model = models.resnet50(pretrained=True)
#提取fc层中固定的参数
fc_features = model.fc.in_features
#修改类别为9
model.fc = nn.Linear(fc_features, 9)
2、添加层后,加载部分参数
model = ...
model_dict = model.state_dict() # 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
参考:https://blog.csdn.net/u012494820/article/details/79068625
https://blog.csdn.net/whut_ldz/article/details/78845947