使用torch加载模型时出现字典键对应不起来的问题

问题描述:

RuntimeError: Error(s) in loading state_dict for Model:
Missing key(s) in state_dict: “MLP.3.weight”, “MLP.3.bias”, “MLP.6.weight”, “MLP.6.bias”.
Unexpected key(s) in state_dict: “MLP.2.weight”, “MLP.2.bias”, “MLP.4.weight”, “MLP.4.bias”.

这几天天天改bug,改的我实在想吐,下面的解决方案是参考了一个国外大佬的回答

## MLP Model modules
       MLP = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(100, 2)
            )

上面是我在模型使用了nn.Sequential,关键就是每一层都没有名字,这个时候torch就会随机按照规则起名字。程序再次运行时由于环境不一样,就会导致这样的问题,还有大佬说是使用了多cpu或者GPU训练导致的。
将nn.Sequential里面的代码改为下面的:为每一层指定name

        MLP = nn.Sequential(OrderedDict([
            ('fc1',nn.Linear(100, 100)),
            ('relu1', nn.ReLU()),
            ('dropout1', nn.Dropout(0.2)),
            ('fc2',nn.Linear(100, 100)),
            ('relu2', nn.ReLU()),
            ('dropout2',nn.Dropout(0.2)),
            ('fc3',nn.Linear(100, 2))])
            )

有人也将torch.load()的参数strict设置为False,但是对我的来说不行,还可以考虑按照顺序匹配,但是太麻烦了。

上一篇:pytorch(二十六):自动编码器


下一篇:css画一条渐变色的直线