关于跑图像去雾算法DCPDN的教程及Bug解决

前情提要

  最近刚刚开始图像去雾方面的研究,自然少不了阅读这一领域的经典文献和GitHub源码。DCPDN是其中比较很有价值的一篇,在阅读文献过程中,希望跑通它的代码,结合代码来帮助我理解这一算法的实现原理。但是我在配置环境跑程序的过程中出现了许多问题,花费了许多时间和精力解决了其中部分问题,因此想在此记录下来,同时也希望对同样遇到这些问题的你有所帮助,谢谢~

论文Densely Connected Pyramid Dehazing Network
github源码https://github.com/hezhangsprinter/DCPDN
参考的相关博客1. 一步一步教你跑DCPDN深度学习去雾网络
        2. DCPDN项目

代码运行环境

  1. Ubuntu 18.04.3
  2. python 3.6
  3. torch 0.3.1
  4. torchvision 0.2.1

直接把我配环境的命令行语句贴出来吧

  1. 新建conda环境
conda create -n xxx(你的环境名) python=3.6
  1. 进入新建的环境(激活环境)
conda activate xxx(你的环境名)
  1. 开始安装各种依赖的包
pip install https://download.pytorch.org/whl/cu90/torch-0.3.1-cp36-cp36m-linux_x86_64.whl
pip install torchvision==0.2.1
pip install h5py
pip install scipy

(没记错的话应该就是上面这些,比较重要的是torch版本和torchvision版本,因为作者的代码是基于早期的低版本,由于新版本做了一些改动,如果不按低版本来装的话,会遇到更多麻烦的问题,亲身经历。等自己研究透了,看看能否高版本的复现一下?哈哈哈)


原始代码可能使用的是Python2版本,因此某些语法与现在的Python3不兼容,(Python2使用<>作为“不等于”的符号,而Python3使用!=)需要修改一下,分别位于train.py文件的第312行、329行、353行和366行

问题1:Missing key(s) in state_dict: xxxxxxxxxxxxx; Unexpected key(s) in state_dict: xxxxxxxxxxxxx

报错原因: 预训练模权重的字典关键字与所创建的网络模型的字典关键字不匹配。(简单来说,我们的模型要使用现有的预训练权重来进行参数的初始化,这个过程需要两者各层级的网络名称相对应,否则就会出现上述错误。)

问题出在train.py文件的第124行:

if opt.netG != '':
	netG.load_state_dict(torch.load(opt.netG))

opt.netG是预训练权重文件netG_epoch_8.pth的所在路径,torch.load(opt.netG)是加载这一权重文件。这个权重文件在以前训练的时候可能还是采用旧的字典关键字,如:‘norm.1’, ‘relu.1’, ‘conv.1’, ‘norm.2’, ‘relu.2’, ‘conv.2’,但是现在的网络模型在创建时已经不再允许使用“.”了,所以需要修改预训练权重的关键字,使其与我们的网络匹配。

通过正则修改,将上面的代码修改成以下内容:(参考torchvision.models.densenet中的做法)

if opt.netG != '':
    checkpoint = torch.load(opt.netG)
    pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.'
                         r'(?:weight|bias|running_mean|running_var))$')
    for key in list(checkpoint.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            checkpoint[new_key] = checkpoint[key]
            del checkpoint[key]
    netG.load_state_dict(checkpoint)

经过修改后,这一段代码就能顺利执行了。

问题2:"python3.6/site-packages/torch/utils/data/dataloader.py", line 271, in __next__ raise StopIteration

关于这个问题,上面参考的博客2中作出了解答。因为博客1的作者训练网络时使用以下的命令:

python train.py --dataroot ./facades/train512 --valDataroot ./facades/test512 --exp ./checkpoints_new --netG ./demo_model/netG_epoch_8.pth

其中,--valDataroot传入的是./facades/test512这个路径,但是源代码的作者并没有提供这一文件,只有./facades/val512,因此把命令改成:

python train.py --dataroot ./facades/train512 --valDataroot ./facades/val512 --exp ./checkpoints_new --netG ./demo_model/netG_epoch_8.pth

就能成功解决这个问题。

最后,再次感谢上面两位博主的博客~

上一篇:mt8665芯片怎么样?联发科mt8665芯片参数介绍


下一篇:【Laravel基础】laravel基础之相关概念,自定义服务提供者:Contracts, ServiceContainer, ServiceProvider, Facades关系