Fixing the error raised by model.load_state_dict(state_dict)


First, look at the parameters stored in the trained checkpoint.

    state_dict = torch.load('logs/sanity_3/checkpoint', map_location='cuda' if args['train']['cuda'] else 'cpu')
    state_dict = state_dict['model']
Print the parameter names:
for k,v in state_dict.items():
    print(k)

Output:
module.block.0.layers.0.weight
module.block.0.layers.0.bias
module.block.0.layers.2.weight
module.block.0.layers.2.bias
module.block.0.layers.4.weight
module.block.0.layers.4.bias
module.block.0.layers.6.weight
module.block.0.layers.6.bias
module.block.0.layers.8.weight
module.block.0.layers.8.bias
module.block.2.layers.0.weight

Next, look at the parameter names of the network itself.

model = CascadeNetwork(**args['network'])
params = model.state_dict()  # the state of the freshly constructed model
for k, v in params.items():
    print(k)  # print only the key names, not the parameter values

Output:
block.0.layers.0.weight
block.0.layers.0.bias
block.0.layers.2.weight
block.0.layers.2.bias
block.0.layers.4.weight
block.0.layers.4.bias
block.0.layers.6.weight
block.0.layers.6.bias

Solution

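Build a new dictionary from the loaded checkpoint with the unwanted "module." prefix removed from every key. This prefix typically appears when the model was wrapped in torch.nn.DataParallel (or DistributedDataParallel) during training, so the saved keys no longer match those of the plain model and load_state_dict complains about missing and unexpected keys.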

First, load the trained checkpoint:
    state_dict = torch.load('logs/sanity_3/checkpoint', map_location='cuda' if args['train']['cuda'] else 'cpu')
    state_dict = state_dict['model']
Then build a new dictionary whose keys do not carry the "module." prefix:
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # "module." is 7 characters long, so k[7:] drops exactly that prefix
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v  # keep the same value under the renamed key

Finally, load the parameters into the network:

model.load_state_dict(new_state_dict)  # load the renamed weights into the model
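
The renaming step can also be wrapped in a small reusable function. The sketch below is one possible way to do it, not the only one: strip_prefix is a hypothetical helper name, and the demo dictionary uses integers as stand-ins for tensors so the example runs without a checkpoint. It only renames keys that actually start with the prefix, so it is safe to call on a checkpoint saved without the wrapper.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in checkpoint: integers take the place of the real tensors.
demo = OrderedDict([
    ("module.block.0.layers.0.weight", 1),
    ("block.0.layers.0.bias", 2),  # already clean, left untouched
])
cleaned = strip_prefix(demo)
print(list(cleaned))
```

With a real checkpoint the call would look like model.load_state_dict(strip_prefix(state_dict)).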