# Copyright (c) OpenMMLab. All rights reserved.
import time

import pytest

from mmengine.config import Config, ConfigDict  # type: ignore
from mmengine.registry import DefaultScope, Registry, build_from_cfg
from mmengine.utils import ManagerMixin


class TestRegistry:
    """Tests for `Registry` construction, registration, lookup and build."""

    def test_init(self):
        CATS = Registry('cat')
        assert CATS.name == 'cat'
        assert CATS.module_dict == {}
        assert CATS.build_func is build_from_cfg
        assert len(CATS) == 0

        # test `build_func` parameter
        def build_func(cfg, registry, default_args):
            pass

        CATS = Registry('cat', build_func=build_func)
        assert CATS.build_func is build_func

        # test `parent` parameter
        # `parent` is either None or a `Registry` instance
        with pytest.raises(AssertionError):
            Registry('little_cat', parent='cat', scope='little_cat')

        LITTLECATS = Registry('little_cat', parent=CATS, scope='little_cat')
        assert LITTLECATS.parent is CATS
        assert CATS._children.get('little_cat') is LITTLECATS

        # test `scope` parameter
        # `scope` is either None or a string
        with pytest.raises(AssertionError):
            Registry('cat', scope=1)

        # Without an explicit scope, the expected value is this module's name
        # (presumably inferred from the calling module), so this assertion
        # relies on the file being named `test_registry.py`.
        CATS = Registry('cat')
        assert CATS.scope == 'test_registry'

        CATS = Registry('cat', scope='cat')
        assert CATS.scope == 'cat'

    def test_split_scope_key(self):
        DOGS = Registry('dogs')

        # no scope prefix
        scope, key = DOGS.split_scope_key('BloodHound')
        assert scope is None and key == 'BloodHound'
        # only the first dotted component is treated as the scope
        scope, key = DOGS.split_scope_key('hound.BloodHound')
        assert scope == 'hound' and key == 'BloodHound'
        scope, key = DOGS.split_scope_key('hound.little_hound.Dachshund')
        assert scope == 'hound' and key == 'little_hound.Dachshund'

    def test_register_module(self):
        CATS = Registry('cat')

        # can only decorate a class
        with pytest.raises(TypeError):

            @CATS.register_module()
            def some_method():
                pass

        # test `name` parameter which must be either of None, a string or a
        # sequence of strings
        # `name` is None
        @CATS.register_module()
        class BritishShorthair:
            pass

        assert len(CATS) == 1
        assert CATS.get('BritishShorthair') is BritishShorthair

        # `name` is a string
        @CATS.register_module(name='Munchkin')
        class Munchkin:
            pass

        assert len(CATS) == 2
        assert CATS.get('Munchkin') is Munchkin
        assert 'Munchkin' in CATS

        # `name` is a sequence of strings: every name maps to the same class
        @CATS.register_module(name=['Siamese', 'Siamese2'])
        class SiameseCat:
            pass

        assert CATS.get('Siamese') is SiameseCat
        assert CATS.get('Siamese2') is SiameseCat
        assert len(CATS) == 4

        # `name` is an invalid type
        with pytest.raises(
                TypeError,
                match=('name must be None, an instance of str, or a sequence '
                       "of str, but got <class 'int'>")):

            @CATS.register_module(name=7474741)
            class SiameseCat:  # noqa: F811
                pass

        # test `force` parameter, which must be a boolean
        # force is not a boolean
        with pytest.raises(
                TypeError,
                match="force must be a boolean, but got <class 'int'>"):

            @CATS.register_module(force=1)
            class BritishShorthair:  # noqa: F811
                pass

        # force=False: re-registering an existing name is rejected
        with pytest.raises(
                KeyError,
                match='BritishShorthair is already registered in cat'):

            @CATS.register_module()
            class BritishShorthair:  # noqa: F811
                pass

        # force=True: the existing entry is overwritten, so the size is
        # unchanged
        @CATS.register_module(force=True)
        class BritishShorthair:  # noqa: F811
            pass

        assert len(CATS) == 4

        # test `module` parameter, which is either None or a class
        # when `register_module` is called as a method rather than a
        # decorator, `module` must be a class
        with pytest.raises(
                TypeError,
                match="module must be a class, but got <class 'str'>"):
            CATS.register_module(module='string')

        class SphynxCat:
            pass

        CATS.register_module(module=SphynxCat)
        assert CATS.get('SphynxCat') is SphynxCat
        assert len(CATS) == 5

        CATS.register_module(name='Sphynx1', module=SphynxCat)
        assert CATS.get('Sphynx1') is SphynxCat
        assert len(CATS) == 6

        CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
        assert CATS.get('Sphynx2') is SphynxCat
        assert CATS.get('Sphynx3') is SphynxCat
        assert len(CATS) == 8

    def _build_registry(self):
        """Build a hierarchical registry tree and return its nodes.

        Hierarchy (registry variable, with its scope in parentheses)::

                                   DOGS
                              _______|_______
                             |               |
                    HOUNDS (hound)          SAMOYEDS (samoyed)
                   _______|_______                |
                  |               |               |
             LITTLE_HOUNDS    MID_HOUNDS   LITTLE_SAMOYEDS
             (little_hound)   (mid_hound)  (little_samoyed)

        Returns:
            list: ``[DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS,
            LITTLE_SAMOYEDS]`` in that order.
        """
        registries = []
        DOGS = Registry('dogs')
        registries.append(DOGS)
        HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
        registries.append(HOUNDS)
        LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound')
        registries.append(LITTLE_HOUNDS)
        MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
        registries.append(MID_HOUNDS)
        SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
        registries.append(SAMOYEDS)
        LITTLE_SAMOYEDS = Registry(
            'dogs', parent=SAMOYEDS, scope='little_samoyed')
        registries.append(LITTLE_SAMOYEDS)

        return registries

    def test_get_root_registry(self):
        # Registry hierarchy is documented in `_build_registry`.
        registries = self._build_registry()
        DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]

        assert DOGS._get_root_registry() is DOGS
        assert HOUNDS._get_root_registry() is DOGS
        assert LITTLE_HOUNDS._get_root_registry() is DOGS
        assert MID_HOUNDS._get_root_registry() is DOGS

    def test_get(self):
        # Registry hierarchy is documented in `_build_registry`.
        registries = self._build_registry()
        DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
        MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]

        @DOGS.register_module()
        class GoldenRetriever:
            pass

        assert len(DOGS) == 1
        assert DOGS.get('GoldenRetriever') is GoldenRetriever

        @HOUNDS.register_module()
        class BloodHound:
            pass

        assert len(HOUNDS) == 1
        # get key from current registry
        assert HOUNDS.get('BloodHound') is BloodHound
        # get key from its children
        assert DOGS.get('hound.BloodHound') is BloodHound
        # get key from current registry
        assert HOUNDS.get('hound.BloodHound') is BloodHound

        # If the key is not found in the current registry, then look for its
        # parent
        assert HOUNDS.get('GoldenRetriever') is GoldenRetriever

        @LITTLE_HOUNDS.register_module()
        class Dachshund:
            pass

        assert len(LITTLE_HOUNDS) == 1
        # get key from current registry
        assert LITTLE_HOUNDS.get('Dachshund') is Dachshund
        # get key from its parent
        assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound
        # get key from its children
        assert HOUNDS.get('little_hound.Dachshund') is Dachshund
        # get key from its descendants
        assert DOGS.get('hound.little_hound.Dachshund') is Dachshund

        # If the key is not found in the current registry, then look for its
        # parent
        assert LITTLE_HOUNDS.get('BloodHound') is BloodHound
        assert LITTLE_HOUNDS.get('GoldenRetriever') is GoldenRetriever

        @MID_HOUNDS.register_module()
        class Beagle:
            pass

        # get key from its sibling registries
        assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle

        @SAMOYEDS.register_module()
        class PedigreeSamoyed:
            pass

        assert len(SAMOYEDS) == 1
        # get key from its uncle
        assert LITTLE_HOUNDS.get('samoyed.PedigreeSamoyed') is PedigreeSamoyed

        @LITTLE_SAMOYEDS.register_module()
        class LittlePedigreeSamoyed:
            pass

        # get key from its cousin
        assert LITTLE_HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
                                 ) is LittlePedigreeSamoyed

        # get key from its nephews
        assert HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
                          ) is LittlePedigreeSamoyed

        # invalid keys
        # GoldenRetrieverererer can not be found at LITTLE_HOUNDS modules
        assert LITTLE_HOUNDS.get('GoldenRetrieverererer') is None
        # samoyedddd is not a child of DOGS
        assert DOGS.get('samoyedddd.PedigreeSamoyed') is None
        # samoyed is a child of DOGS but LittlePedigreeSamoyed can not be found
        # at SAMOYEDS modules
        assert DOGS.get('samoyed.LittlePedigreeSamoyed') is None
        assert LITTLE_HOUNDS.get('mid_hound.PedigreeSamoyedddddd') is None

    def test_search_child(self):
        # Registry hierarchy is documented in `_build_registry`.
        registries = self._build_registry()
        DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]

        assert DOGS._search_child('hound') is HOUNDS
        assert DOGS._search_child('not a child') is None
        # descendants deeper than one level are found too
        assert DOGS._search_child('little_hound') is LITTLE_HOUNDS
        # the search only goes downwards: ancestors and siblings are not found
        assert LITTLE_HOUNDS._search_child('hound') is None
        assert LITTLE_HOUNDS._search_child('mid_hound') is None

    @pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
    def test_build(self, cfg_type):
        # Registry hierarchy is documented in `_build_registry`.
        registries = self._build_registry()
        DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]

        @DOGS.register_module()
        class GoldenRetriever:
            pass

        gr_cfg = cfg_type(dict(type='GoldenRetriever'))
        assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)

        @HOUNDS.register_module()
        class BloodHound:
            pass

        bh_cfg = cfg_type(dict(type='BloodHound'))
        assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
        # a child registry can build a class registered in its parent
        assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)

        @LITTLE_HOUNDS.register_module()
        class Dachshund:
            pass

        d_cfg = cfg_type(dict(type='Dachshund'))
        assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund)

        @MID_HOUNDS.register_module()
        class Beagle:
            pass

        b_cfg = cfg_type(dict(type='Beagle'))
        assert isinstance(MID_HOUNDS.build(b_cfg), Beagle)

        # test `default_scope`
        # NOTE(review): `DefaultScope.get_instance` creates process-wide
        # state that this test never tears down, so later tests that call
        # `Registry.build` may observe the last scope created here. Consider
        # a fixture that restores the previous default scope.
        # The parametrized `cfg_type` name is part of each instance name.
        # If a name were reused (possible on coarse `time.time()` clocks
        # across parametrized runs), `get_instance` would presumably return
        # the existing instance instead of switching the current scope.

        # switch the current registry to another registry
        DefaultScope.get_instance(
            f'test-{cfg_type.__name__}-{time.time()}', scope_name='mid_hound')
        dog = LITTLE_HOUNDS.build(b_cfg)
        assert isinstance(dog, Beagle)

        # `default_scope` can not be found: building falls back to the
        # registry that `build` was called on
        DefaultScope.get_instance(
            f'test2-{cfg_type.__name__}-{time.time()}',
            scope_name='scope-not-found')
        dog = MID_HOUNDS.build(b_cfg)
        assert isinstance(dog, Beagle)

    def test_repr(self):
        CATS = Registry('cat')

        @CATS.register_module()
        class BritishShorthair:
            pass

        @CATS.register_module()
        class Munchkin:
            pass

        # The expected class paths embed this module's name
        # (`test_registry`).
        repr_str = 'Registry(name=cat, items={'
        repr_str += (
            "'BritishShorthair': <class 'test_registry.TestRegistry.test_repr."
            "<locals>.BritishShorthair'>, ")
        repr_str += (
            "'Munchkin': <class 'test_registry.TestRegistry.test_repr."
            "<locals>.Munchkin'>")
        repr_str += '})'
        assert repr(CATS) == repr_str


@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
def test_build_from_cfg(cfg_type):
    """Exercise `build_from_cfg` argument validation and construction.

    Each `pytest.raises` block contains only the call expected to raise.
    Config construction happens outside the block, so a failure while
    building the config cannot be mistaken for the error under test.
    """
    BACKBONES = Registry('backbone')

    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # test `cfg` parameter
    # `cfg` should be a dict, ConfigDict or Config object
    with pytest.raises(
            TypeError,
            match=('cfg should be a dict, ConfigDict or Config, but got '
                   "<class 'str'>")):
        build_from_cfg('ResNet', BACKBONES)

    # `cfg` is a dict, ConfigDict or Config object
    cfg = cfg_type(dict(type='ResNet', depth=50))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # `cfg` is a dict but it does not contain the key "type"
    cfg = cfg_type(dict(depth=50, stages=4))
    with pytest.raises(KeyError, match='must contain the key "type"'):
        build_from_cfg(cfg, BACKBONES)

    # cfg['type'] should be a str or class
    cfg = cfg_type(dict(type=1000))
    with pytest.raises(
            TypeError,
            match="type must be a str or valid type, but got <class 'int'>"):
        build_from_cfg(cfg, BACKBONES)

    cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # `type` may also be the class object itself
    cfg = cfg_type(dict(type=ResNet, depth=50))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # non-registered class
    cfg = cfg_type(dict(type='VGG'))
    with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
        build_from_cfg(cfg, BACKBONES)

    # `cfg` contains unexpected arguments (rejected by the class __init__)
    cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES)

    # test `default_args` parameter
    cfg = cfg_type(dict(type='ResNet', depth=50))
    model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    # default_args must be a dict or None
    cfg = cfg_type(dict(type='ResNet', depth=50))
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES, default_args=1)

    # cfg or default_args should contain the key "type"
    cfg = cfg_type(dict(depth=50))
    default_args = cfg_type(dict(stages=4))
    with pytest.raises(KeyError, match='must contain the key "type"'):
        build_from_cfg(cfg, BACKBONES, default_args=default_args)

    # "type" defined using default_args
    cfg = cfg_type(dict(depth=50))
    model = build_from_cfg(
        cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    cfg = cfg_type(dict(depth=50))
    model = build_from_cfg(
        cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # test `registry` parameter
    # incorrect registry type
    cfg = cfg_type(dict(type='ResNet', depth=50))
    with pytest.raises(
            TypeError,
            match=('registry must be a mmengine.Registry object, but got '
                   "<class 'str'>")):
        build_from_cfg(cfg, 'BACKBONES')

    # Classes inheriting `ManagerMixin` must become globally accessible once
    # built through the registry.
    VISUALIZER = Registry('visualizer')

    @VISUALIZER.register_module()
    class Visualizer(ManagerMixin):

        def __init__(self, name):
            super().__init__(name)

    # nothing has been built yet, so there is no current instance
    with pytest.raises(RuntimeError):
        Visualizer.get_current_instance()
    cfg = dict(type='Visualizer', name='visualizer')
    visualizer = build_from_cfg(cfg, VISUALIZER)
    # the freshly built object must be the one exposed as current instance
    assert Visualizer.get_current_instance() is visualizer