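While reading the MMClassification source code recently, I noticed that many modules are built by calling the build_from_cfg() function, whose second argument is a Registry object. The Registry constructor on its own is a bit hard to follow, so this post summarizes the relevant material from the documentation and the source code (the docs conveniently devote a whole chapter to Registry).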
What is a Registry
Each Registry object holds a mapping from class-name strings to class types. The classes in one registry usually share a similar API, for example:
'Converter1' -> <class 'Converter1'>
Typical usage:
- Instantiate a Registry object, passing a positional `name` argument that becomes the registry's name:
```python
from mmcv.utils import Registry
# create a registry for converters
CONVERTERS = Registry('converter')
```
- Register a module by putting the registry's decorator above the class definition:
```python
from .builder import CONVERTERS


# use the registry to manage the module
@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
```
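- When you need an instance of a class, call the Registry object's `build` method with a dict containing `type` and the constructor arguments (usually the config dict `cfg`). Internally it calls the registry's `build_func` (by default `build_from_cfg()`) and returns the instantiated object: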
```python
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = CONVERTERS.build(converter_cfg)
```
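To make the register-then-build flow concrete without installing mmcv, here is a minimal sketch. `MiniRegistry` is a hypothetical toy class, not mmcv's implementation: it only imitates the usage pattern above, keeping a dict from class name to class, filling it through a decorator, and instantiating from a `type` key.

```python
# Toy registry (hypothetical, NOT mmcv's code) that mimics the
# register_module() / build() usage pattern.
class MiniRegistry:
    def __init__(self, name):
        self.name = name
        self._module_dict = {}  # class-name string -> class

    def register_module(self):
        # Return a decorator that records the class under its __name__
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict.get(key)

    def build(self, cfg):
        args = dict(cfg)  # copy so the caller's cfg is left untouched
        obj_cls = self.get(args.pop('type'))
        if obj_cls is None:
            raise KeyError(f'{cfg["type"]} is not in the {self.name} registry')
        return obj_cls(**args)


CONVERTERS = MiniRegistry('converter')


@CONVERTERS.register_module()
class Converter1:
    def __init__(self, a, b):
        self.a = a
        self.b = b


converter = CONVERTERS.build(dict(type='Converter1', a=1, b=2))
print(type(converter).__name__, converter.a, converter.b)  # Converter1 1 2
```

The decorator returns the class unchanged, so `Converter1` stays usable as a normal class; registration is purely a side effect on the dict.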
build_from_cfg()
```python
import inspect


# Excerpt from mmcv; Registry is defined in the same module.
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    # cfg must be a dict
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # "type" must come from either cfg or default_args
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    # registry must be a Registry object
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    # default_args, if given, must be a dict
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Work on a shallow copy so the caller's cfg is not modified
    args = cfg.copy()

    # Fill in defaults; keys already present in cfg take precedence
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    # Pop "type": a string is looked up in the registry, a class is used as-is
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        # Instantiate the class with the remaining arguments
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
```
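The easiest part of build_from_cfg() to misread is the argument merge: `setdefault` means `default_args` only fills keys that `cfg` left out, and `type` itself may come from `default_args`. The sketch below isolates just that step; `merge_args` is a hypothetical helper name, not part of mmcv.

```python
def merge_args(cfg, default_args=None):
    """Reproduce only the argument-merging step of build_from_cfg()."""
    args = cfg.copy()                     # shallow copy: cfg itself is not mutated
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)  # only fills keys cfg did not set
    obj_type = args.pop('type')           # 'type' picks the class, it is not a ctor arg
    return obj_type, args


cfg = dict(type='Converter1', a=1)
obj_type, args = merge_args(cfg, default_args=dict(a=100, b=2))
print(obj_type, args)  # Converter1 {'a': 1, 'b': 2}
print(cfg)             # {'type': 'Converter1', 'a': 1}

# "type" can also be supplied only through default_args
print(merge_args(dict(a=1), dict(type='Converter1')))  # ('Converter1', {'a': 1})
```

So a value written in the config file always wins over a default injected by the calling code, and the caller's config dict survives the call unchanged.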
The Registry class
- Constructor
```python
# Registry.__init__
def __init__(self, name, build_func=None, parent=None, scope=None):
    self._name = name
    self._module_dict = dict()
    self._children = dict()
    self._scope = self.infer_scope() if scope is None else scope

    # self.build_func will be set with the following priority:
    # 1. build_func
    # 2. parent.build_func
    # 3. build_from_cfg
    if build_func is None:
        if parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = build_from_cfg
    else:
        self.build_func = build_func
    if parent is not None:
        assert isinstance(parent, Registry)
        parent._add_children(self)
        self.parent = parent
    else:
        self.parent = None
```
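The three-level `build_func` priority can be checked with a small toy. `ToyRegistry`, `default_build` and `custom_build` are hypothetical names; the toy copies only the priority logic of the constructor above, and its `build` simply forwards the cfg to `build_func` together with the registry itself.

```python
# Toy mirror of the build_func priority in Registry.__init__ (hypothetical names).
def default_build(cfg, registry):
    return ('default_build', cfg['type'])


def custom_build(cfg, registry):
    return ('custom_build', cfg['type'])


class ToyRegistry:
    def __init__(self, name, build_func=None, parent=None):
        self.name = name
        # Priority: 1. build_func  2. parent.build_func  3. default_build
        if build_func is not None:
            self.build_func = build_func
        elif parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = default_build
        self.parent = parent

    def build(self, cfg):
        # Forward to whichever build function was resolved at construction
        return self.build_func(cfg, registry=self)


root = ToyRegistry('model', build_func=custom_build)
child = ToyRegistry('model', parent=root)  # no build_func: inherits custom_build
plain = ToyRegistry('model')               # no build_func, no parent: default_build
print(child.build(dict(type='X')))  # ('custom_build', 'X')
print(plain.build(dict(type='X')))  # ('default_build', 'X')
```

Because the parent's `build_func` is copied at construction time, a downstream registry created with `parent=...` builds objects the same way its parent does unless it passes its own `build_func`.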
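The mapping from class names to classes is stored in the `_module_dict` dict, while `_children` holds the child Registry instances, keyed by each child's scope. Methods such as `register_module` and `_add_children` are essentially reads and writes on these two dicts.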
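As a rough illustration of those reads and writes, here is a toy that keeps both dicts. `DictRegistry` and the scope strings `'mmcv'` / `'mmcls'` are hypothetical; keying children by scope and rejecting duplicates is an assumption modelled on mmcv's `_add_children` and `register_module`, not a copy of them.

```python
# Toy sketch of _module_dict and _children (hypothetical, not mmcv's code).
class DictRegistry:
    def __init__(self, name, scope, parent=None):
        self.name = name
        self.scope = scope
        self._module_dict = {}  # class name -> class
        self._children = {}     # child scope -> child registry
        if parent is not None:
            parent._add_children(self)
        self.parent = parent

    def _add_children(self, registry):
        # Write to _children; refuse a scope that is already taken
        assert registry.scope not in self._children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self._children[registry.scope] = registry

    def register_module(self):
        def _register(cls):
            # Write to _module_dict; refuse a name that is already taken
            if cls.__name__ in self._module_dict:
                raise KeyError(
                    f'{cls.__name__} is already registered in {self.name}')
            self._module_dict[cls.__name__] = cls
            return cls
        return _register


root = DictRegistry('model', scope='mmcv')
child = DictRegistry('model', scope='mmcls', parent=root)


@child.register_module()
class ResNet:
    pass


print(list(root._children))        # ['mmcls']
print(list(child._module_dict))    # ['ResNet']
```

Registering a class only touches the registry it was registered in; the parent merely keeps a reference to the child through `_children`.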