当前位置 博文首页 > DL_fan的博客:Registry注册机制

    DL_fan的博客:Registry注册机制

    作者:DL_fan 时间:2021-07-10 22:23

    前言:不管是Detectron还是mmdetection,都有用到这个register机制,特意去弄明白,记录一下。

    首先看 Registry 的代码(来自 fvcore,即 fvcore.common.registry 模块):

    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
    
    from typing import Dict, Optional, Iterable, Tuple, Iterator
    
    from tabulate import tabulate
    
    
    class Registry(Iterable[Tuple[str, object]]):
        """
        The registry that provides name -> object mapping, to support third-party
        users' custom modules.
    
        To create a registry (e.g. a backbone registry):
    
        .. code-block:: python
    
            BACKBONE_REGISTRY = Registry('BACKBONE')
    
        To register an object:
    
        .. code-block:: python
    
            @BACKBONE_REGISTRY.register()
            class MyBackbone():
                ...
    
        Or:
    
        .. code-block:: python
    
            BACKBONE_REGISTRY.register(MyBackbone)
    
        To look an object up again by name:
    
        .. code-block:: python
    
            backbone_cls = BACKBONE_REGISTRY.get('MyBackbone')
        """
    
        def __init__(self, name: str) -> None:
            """
            Args:
                name (str): the name of this registry; it appears in error
                    messages and in ``repr``.
            """
            self._name: str = name
            # Maps each registered name (the object's ``__name__``) to the object itself.
            self._obj_map: Dict[str, object] = {}
    
        def _do_register(self, name: str, obj: object) -> None:
            """
            Store ``obj`` under ``name``.
    
            Raises:
                AssertionError: if ``name`` is already registered. An explicit
                    ``raise`` is used instead of an ``assert`` statement so the
                    duplicate check is not stripped under ``python -O``; the
                    exception type is kept for backward compatibility.
            """
            if name in self._obj_map:
                raise AssertionError(
                    "An object named '{}' was already registered in '{}' registry!".format(
                        name, self._name
                    )
                )
            self._obj_map[name] = obj
    
        def register(self, obj: object = None) -> Optional[object]:
            """
            Register the given object under the name `obj.__name__`.
            Can be used as either a decorator or not. See docstring of this class for usage.
    
            Args:
                obj: the function or class to register. If omitted, a decorator
                    is returned instead.
    
            Returns:
                Without ``obj``: a decorator that registers its target and
                returns it unchanged. With ``obj``: ``obj`` itself, so
                ``X = REGISTRY.register(X)`` also works.
    
            Raises:
                AssertionError: if an object with the same name is already registered.
            """
            if obj is None:
                # used as a decorator
                def deco(func_or_class: object) -> object:
                    name = func_or_class.__name__  # pyre-ignore
                    self._do_register(name, func_or_class)
                    return func_or_class
    
                return deco
    
            # used as a function call
            name = obj.__name__  # pyre-ignore
            self._do_register(name, obj)
            return obj
    
        def get(self, name: str) -> object:
            """
            Return the object registered under ``name``.
    
            Raises:
                KeyError: if nothing is registered under ``name``.
            """
            try:
                return self._obj_map[name]
            except KeyError:
                # Replace the bare dict KeyError with a descriptive one.
                raise KeyError(
                    "No object named '{}' found in '{}' registry!".format(name, self._name)
                ) from None
    
        def __contains__(self, name: str) -> bool:
            """Return True if an object is registered under ``name``."""
            return name in self._obj_map
    
        def __repr__(self) -> str:
            """Render all (name, object) pairs as a table using ``tabulate``."""
            table_headers = ["Names", "Objects"]
            table = tabulate(
                self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid"
            )
            return "Registry of {}:\n".format(self._name) + table
    
        def __iter__(self) -> Iterator[Tuple[str, object]]:
            """Iterate over (name, object) pairs in registration order."""
            return iter(self._obj_map.items())
    
        # pyre-fixme[4]: Attribute must be annotated.
        __str__ = __repr__
    

    可以看出,register 方法通过调用 _do_register,把函数名或类名(obj.__name__)作为键、函数或类对象本身作为值存入字典 _obj_map;之后再通过 get 方法按名字取回对应的函数或类。

    示例代码调用:

    
    from fvcore.common.registry import Registry
    
    BACKBONE_REGISTRY = Registry("BACKBONE")
    
    @BACKBONE_REGISTRY.register()
    def test_register(cfg):
        """Example builder registered under its own name, 'test_register'.
    
        Prints the config it receives and returns a fixed marker string.
        """
        print('==cfg:', cfg)
        marker = '==test_register is called'
        return marker
    
    
    def debug_register():
        """Fetch 'test_register' from BACKBONE_REGISTRY, print it, then call it."""
        cfg = 'hahahah'
        # get() hands back the registered function/class object itself.
        builder = BACKBONE_REGISTRY.get('test_register')
        print(builder)
        # Calling the looked-up object runs the registered function.
        res = builder(cfg)
        print('==res:', res)
    
    if __name__ == '__main__':
        debug_register()

    而对于 mmcv,注册器的用法如下(build_from_cfg 根据配置从注册器中取出类并实例化,这里加了打印以便观察):

    
    import inspect
    
    import mmcv
    
    def build_from_cfg(cfg, registry, default_args=None):
        """Build an object from a config dict through a registry.
    
        Args:
            cfg (dict): config whose ``'type'`` entry is either the registered
                name (str) of a class or the class itself; every other entry is
                passed to the class as a keyword argument.
            registry: object providing ``get(name)`` (expected to return None for
                unknown names) and a ``name`` attribute, e.g. ``mmcv.Registry``.
            default_args (dict, optional): keyword arguments used only for keys
                that ``cfg`` does not already contain.
    
        Returns:
            The instance created by calling the resolved class with the
            remaining arguments.
    
        Raises:
            KeyError: if neither ``cfg`` nor ``default_args`` provides ``'type'``,
                or the type name is not registered.
            TypeError: if ``'type'`` is neither a str nor a class.
        """
        # Work on a copy so popping 'type' below never mutates the caller's cfg.
        args = cfg.copy()
        print('==cfg:', cfg)
        print('==registry:', registry)
        print('==default_args:', default_args)
        if default_args is not None:
            for name, value in default_args.items():
                args.setdefault(name, value)
    
        if 'type' not in args:
            raise KeyError(
                f'`cfg` or `default_args` must contain the key "type", but got {cfg}')
    
        obj_type = args.pop('type')  # a registered class name (str) or a class
        if isinstance(obj_type, str):
            # Look the name up in the registry (roughly self._module_dict[obj_type] in mmcv).
            obj_cls = registry.get(obj_type)
            print('==obj_cls:', obj_cls)
            if obj_cls is None:
                raise KeyError(
                    f'{obj_type} is not in the {registry.name} registry')
    
        # 'type' is already a class (not an instance): use it directly.
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')
    
        # Instantiate the resolved class with the remaining arguments.
        return obj_cls(**args)
    
    
    ANYNAMES = mmcv.Registry('convert')
    
    # register_module() records the class under its name: the registry can then
    # resolve 'Converter1' to the Converter1 class itself (not an instance).
    @ANYNAMES.register_module()
    class Converter1:
        """Toy class holding two values, used to demonstrate build_from_cfg."""
    
        def __init__(self, a, b):
            self.a, self.b = a, b
    
    
    a_value, b_value = 10, 20
    converter_cfg = {'type': 'Converter1', 'a': a_value, 'b': b_value}
    print('==converter_cfg:', converter_cfg)
    # Resolve 'Converter1' through the registry, then instantiate it with a and b.
    converter = build_from_cfg(converter_cfg, ANYNAMES)
    print('==converter:', converter)
    print('==converter.a:', converter.a)
    print('==converter.b:', converter.b)
    

    上述例子中,register_module 把类名 'Converter1' 与类 Converter1 本身存入注册器的字典;build_from_cfg 再根据配置中的 type 字段,通过 get 方法取出该类,并用配置中其余的键值(a、b)作为参数将其实例化。

    cs