MMSegmentation文档-Tutorial 4: Customize Models
源文档 https://mmsegmentation.readthedocs.io/en/latest/tutorials/customize_models.html
1 Customize optimizer
假设你想要添加一个名为MyOptimizer
的优化器,该优化器有参数a
、b
和c
。首先需要在文件中实现这个优化器,即mmseg/core/optimizer/my_optimizer.py
:
from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer
@OPTIMIZERS.register_module
class MyOptimizer(Optimizer):
def __init__(self, a, b, c)
然后将这个模块添加到mmseg/core/optimizer/__init__.py
中,这样注册表就会找到新模块并添加它,
from .my_optimizer import MyOptimizer
然后你可以在配置文件的optimizer
字段中使用MyOptimizer
。在配置中,优化器是由如下字段optimizer
定义的:
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
要使用自己的优化器,可以将字段更改为:
optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)
我们已经支持使用PyTorch实现的所有优化器,唯一的修改是更改配置文件的优化器optimizer
字段。例如,如果您想使用ADAM
,尽管性能会下降很多,但修改可以如下所示:
optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)
可以根据PyTorch的API文档直接设置参数
2 Customize optimizer constructor
一些模型可能有一些特定参数的优化设置,例如,BatchNoarm层的weight decay。用户可以通过自定义优化器构造函数来进行这些细粒度参数调优。
from mmcv.utils import build_from_cfg
from mmcv.runner import OPTIMIZER_BUILDERS
from .cocktail_optimizer import CocktailOptimizer
@OPTIMIZER_BUILDERS.register_module
class CocktailOptimizerConstructor(object):
def __init__(self, optimizer_cfg, paramwise_cfg=None):
def __call__(self, model):
return my_optimizer
3 Develop new components
MMSegmentation主要有两种类型的组件。
- backbone:通常是堆叠的卷积网络来提取特征地图,如ResNet, HRNet。
- head:用于语义分割地特征图解码的组件
3.1 Add new backbones
这里通过一个MobileNet的例子来展示如何开发新的组件。
- 创建一个新文件
mmseg/models/backbones/mobilenet.py
import torch.nn as nn
from ..registry import BACKBONES
@BACKBONES.register_module
class MobileNet(nn.Module):
def __init__(self, arg1, arg2):
pass
def forward(self, x): # should return a tuple
pass
def init_weights(self, pretrained=None):
pass
- 在
mmseg/models/backbones/__init__.py
中import模块
from .mobilenet import MobileNet
- 在配置文件中使用
model = dict(
...
backbone=dict(
type='MobileNet',
arg1=xxx,
arg2=xxx),
...
3.2 Add new heads
在MMSegmentation中,为所有的分割头提供了一个基本的BaseDecodeHead。所有新实现的解码头都应该从它派生出来。下面将使用PSPNet的示例演示如何开发一个新的头。
首先,在mmseg/models/decode_heads/psp_head.py
中添加一个新的解码头。PSPNet实现了一个用于分割解码的解码头。为了实现一个解码头,我们基本上需要实现新模块的三个功能,如下所示:
@HEADS.register_module()
class PSPHead(BaseDecodeHead):
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(PSPHead, self).__init__(**kwargs)
def init_weights(self):
def forward(self, inputs):
接下来,需要在mmseg/models/decode_heads/__init__.py
中添加模块,这样对应的注册表就可以找到并加载它们。
PSPNet的配置文件如下所示:
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='pretrain_model/resnet50_v1c_trick-2cccc1ad.pth',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='PSPHead',
in_channels=2048,
in_index=3,
channels=512,
pool_scales=(1, 2, 3, 6),
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
3.3 Add new loss
假设想添加一个新的损失作为MyLoss
用于分割解码。要添加一个新的损失函数,需要在mmseg/models/losses/my_loss.py
中实现它。weighted_loss
可以对每个元素的损失进行加权。
import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weighted_loss
@weighted_loss
def my_loss(pred, target):
assert pred.size() == target.size() and target.numel() > 0
loss = torch.abs(pred - target)
return loss
@LOSSES.register_module
class MyLoss(nn.Module):
def __init__(self, reduction='mean', loss_weight=1.0):
super(MyLoss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss = self.loss_weight * my_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss
然后需要将其添加到mmseg/models/losses/__init__.py
中
from .my_loss import MyLoss, my_loss
若要使用,需修改loss_xxx
字段。然后需要修改head中的loss_decode
字段。Loss_weight
可以用来平衡多个损失
loss_decode=dict(type='MyLoss', loss_weight=1.0))