From 6ba7ac3a8e15c121dcf07892e30c0130c0c53b5e Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Tue, 15 Feb 2022 15:37:53 +0800 Subject: [PATCH] Add the unittest of registry (#5) * Add the unittest of registry * improve the format * fix the test_repr * refactor the test_init * add the copyright for tests directory * add more unit tests * fix error * add unit tests for _search_child and _get_root_registry * build_from_cfg supports dict, ConfictDict and Config * improve docstring --- .pre-commit-config.yaml | 2 +- tests/test_registry/test_registry.py | 478 +++++++++++++++++++++++++++ 2 files changed, 479 insertions(+), 1 deletion(-) create mode 100644 tests/test_registry/test_registry.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46bdb7d8..fb7df7ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: rev: v0.2.0 hooks: - id: check-copyright - args: ["mmengine"] + args: ["mmengine", "tests"] - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.812 hooks: diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py new file mode 100644 index 00000000..0f058cf6 --- /dev/null +++ b/tests/test_registry/test_registry.py @@ -0,0 +1,478 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from mmengine import Config, ConfigDict, Registry, build_from_cfg + + +class TestRegistry: + + def test_init(self): + CATS = Registry('cat') + assert CATS.name == 'cat' + assert CATS.module_dict == {} + assert CATS.build_func is build_from_cfg + assert len(CATS) == 0 + + # test `build_func` parameter + def build_func(cfg, registry, default_args): + pass + + CATS = Registry('cat', build_func=build_func) + assert CATS.build_func is build_func + + # test `parent` parameter + # `parent` is either None or a `Registry` instance + with pytest.raises(AssertionError): + CATS = Registry('little_cat', parent='cat', scope='little_cat') + + LITTLECATS = Registry('little_cat', parent=CATS, scope='little_cat') + assert LITTLECATS.parent is CATS + assert CATS._children.get('little_cat') is LITTLECATS + + # test `scope` parameter + # `scope` is either None or a string + with pytest.raises(AssertionError): + CATS = Registry('cat', scope=1) + + CATS = Registry('cat') + assert CATS.scope == 'test_registry' + + CATS = Registry('cat', scope='cat') + assert CATS.scope == 'cat' + + def test_split_scope_key(self): + DOGS = Registry('dogs') + + scope, key = DOGS.split_scope_key('BloodHound') + assert scope is None and key == 'BloodHound' + scope, key = DOGS.split_scope_key('hound.BloodHound') + assert scope == 'hound' and key == 'BloodHound' + scope, key = DOGS.split_scope_key('hound.little_hound.Dachshund') + assert scope == 'hound' and key == 'little_hound.Dachshund' + + def test_register_module(self): + CATS = Registry('cat') + + # can only decorate a class + with pytest.raises(TypeError): + + @CATS.register_module() + def some_method(): + pass + + # test `name` parameter which must be either of None, a string or a + # sequence of string + # `name` is None + @CATS.register_module() + class BritishShorthair: + pass + + assert len(CATS) == 1 + assert CATS.get('BritishShorthair') is BritishShorthair + + # `name` is a string + @CATS.register_module(name='Munchkin') + class Munchkin: + pass + + assert len(CATS) == 2 + assert CATS.get('Munchkin') is Munchkin + assert 'Munchkin' in CATS + + # `name` is a sequence of string + @CATS.register_module(name=['Siamese', 'Siamese2']) + class SiameseCat: + pass + + assert CATS.get('Siamese') is SiameseCat + assert CATS.get('Siamese2') is SiameseCat + assert len(CATS) == 4 + + # `name` is an invalid type + with pytest.raises( + TypeError, + match=('name must be either of None, an instance of str or a ' + "sequence of str, but got <class 'int'>")): + + @CATS.register_module(name=7474741) + class SiameseCat: + pass + + # test `force` parameter, which must be a boolean + # force is not a boolean + with pytest.raises( + TypeError, + match="force must be a boolean, but got <class 'int'>"): + + @CATS.register_module(force=1) + class BritishShorthair: + pass + + # force=False + with pytest.raises( + KeyError, + match='BritishShorthair is already registered in cat'): + + @CATS.register_module() + class BritishShorthair: + pass + + # force=True + @CATS.register_module(force=True) + class BritishShorthair: + pass + + assert len(CATS) == 4 + + # test `module` parameter, which is either None or a class + # when the `register_module`` is called as a method rather than a + # decorator, which must be a class + with pytest.raises( + TypeError, + match="module must be a class, but got <class 'str'>"): + CATS.register_module(module='string') + + class SphynxCat: + pass + + CATS.register_module(module=SphynxCat) + assert CATS.get('SphynxCat') is SphynxCat + assert len(CATS) == 5 + + CATS.register_module(name='Sphynx1', module=SphynxCat) + assert CATS.get('Sphynx1') is SphynxCat + assert len(CATS) == 6 + + CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat) + assert CATS.get('Sphynx2') is SphynxCat + assert CATS.get('Sphynx3') is SphynxCat + assert len(CATS) == 8 + + def _build_registry(self): + """A helper function to build a hierarchy registry.""" + # Hierarchy Registry + # DOGS + # _______|_______ + # | | + # HOUNDS (hound) SAMOYEDS (samoyed) + # _______|_______ | + # | | | + # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS + # (little_hound) (mid_hound) (little_samoyed) + registries = [] + DOGS = Registry('dogs') + registries.append(DOGS) + HOUNDS = Registry('dogs', parent=DOGS, scope='hound') + registries.append(HOUNDS) + LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound') + registries.append(LITTLE_HOUNDS) + MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound') + registries.append(MID_HOUNDS) + SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed') + registries.append(SAMOYEDS) + LITTLE_SAMOYEDS = Registry( + 'dogs', parent=SAMOYEDS, scope='little_samoyed') + registries.append(LITTLE_SAMOYEDS) + + return registries + + def test_get_root_registry(self): + # Hierarchy Registry + # DOGS + # _______|_______ + # | | + # HOUNDS (hound) SAMOYEDS (samoyed) + # _______|_______ | + # | | | + # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS + # (little_hound) (mid_hound) (little_samoyed) + DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry() + + assert DOGS._get_root_registry() is DOGS + assert HOUNDS._get_root_registry() is DOGS + assert LITTLE_HOUNDS._get_root_registry() is DOGS + assert MID_HOUNDS._get_root_registry() is DOGS + + def test_get(self): + # Hierarchy Registry + # DOGS + # _______|_______ + # | | + # HOUNDS (hound) SAMOYEDS (samoyed) + # _______|_______ | + # | | | + # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS + # (little_hound) (mid_hound) (little_samoyed) + registries = self._build_registry() + DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3] + MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:] + + @DOGS.register_module() + class GoldenRetriever: + pass + + assert len(DOGS) == 1 + assert DOGS.get('GoldenRetriever') is GoldenRetriever + + @HOUNDS.register_module() + class BloodHound: + pass + + assert len(HOUNDS) == 1 + # get key from current registry + assert HOUNDS.get('BloodHound') is BloodHound + # get key from its children + assert DOGS.get('hound.BloodHound') is BloodHound + # get key from current registry + assert HOUNDS.get('hound.BloodHound') is BloodHound + + # If the key is not found in the current registry, then look for its + # parent + assert HOUNDS.get('GoldenRetriever') is GoldenRetriever + + @LITTLE_HOUNDS.register_module() + class Dachshund: + pass + + assert len(LITTLE_HOUNDS) == 1 + # get key from current registry + assert LITTLE_HOUNDS.get('Dachshund') is Dachshund + # get key from its parent + assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound + # get key from its children + assert HOUNDS.get('little_hound.Dachshund') is Dachshund + # get key from its descendants + assert DOGS.get('hound.little_hound.Dachshund') is Dachshund + + # If the key is not found in the current registry, then look for its + # parent + assert LITTLE_HOUNDS.get('BloodHound') is BloodHound + assert LITTLE_HOUNDS.get('GoldenRetriever') is GoldenRetriever + + @MID_HOUNDS.register_module() + class Beagle: + pass + + # get key from its sibling registries + assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle + + @SAMOYEDS.register_module() + class PedigreeSamoyed: + pass + + assert len(SAMOYEDS) == 1 + # get key from its uncle + assert LITTLE_HOUNDS.get('samoyed.PedigreeSamoyed') is PedigreeSamoyed + + @LITTLE_SAMOYEDS.register_module() + class LittlePedigreeSamoyed: + pass + + # get key from its cousin + assert LITTLE_HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed' + ) is LittlePedigreeSamoyed + + # get key from its nephews + assert HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed' + ) is LittlePedigreeSamoyed + + def test_search_child(self): + # Hierarchy Registry + # DOGS + # _______|_______ + # | | + # HOUNDS (hound) SAMOYEDS (samoyed) + # _______|_______ | + # | | | + # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS + # (little_hound) (mid_hound) (little_samoyed) + DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry() + + assert DOGS._search_child('hound') is HOUNDS + assert DOGS._search_child('not a child') is None + assert DOGS._search_child('little_hound') is LITTLE_HOUNDS + assert LITTLE_HOUNDS._search_child('hound') is None + assert LITTLE_HOUNDS._search_child('mid_hound') is None + + def test_build(self): + # Hierarchy Registry + # DOGS + # _______|_______ + # | | + # HOUNDS (hound) SAMOYEDS (samoyed) + # _______|_______ | + # | | | + # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS + # (little_hound) (mid_hound) (little_samoyed) + registries = self._build_registry() + DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4] + + @DOGS.register_module() + class GoldenRetriever: + pass + + gr_cfg = dict(type='GoldenRetriever') + assert isinstance(DOGS.build(gr_cfg), GoldenRetriever) + + @HOUNDS.register_module() + class BloodHound: + pass + + bh_cfg = dict(type='BloodHound') + assert isinstance(HOUNDS.build(bh_cfg), BloodHound) + assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever) + + @LITTLE_HOUNDS.register_module() + class Dachshund: + pass + + d_cfg = dict(type='Dachshund') + assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund) + + @MID_HOUNDS.register_module() + class Beagle: + pass + + b_cfg = dict(type='Beagle') + assert isinstance(MID_HOUNDS.build(b_cfg), Beagle) + + # test `default_scope` + # `default_scope` is an invalid scope + with pytest.raises(KeyError): + LITTLE_HOUNDS.build(b_cfg, default_scope='invalid_mid_hound') + + # switch the current registry to another registry + dog = LITTLE_HOUNDS.build(b_cfg, default_scope='mid_hound') + assert isinstance(dog, Beagle) + + def test_repr(self): + CATS = Registry('cat') + + @CATS.register_module() + class BritishShorthair: + pass + + @CATS.register_module() + class Munchkin: + pass + + repr_str = 'Registry(name=cat, items={' + repr_str += ( + "'BritishShorthair': <class 'test_registry.TestRegistry.test_repr." + "<locals>.BritishShorthair'>, ") + repr_str += ( + "'Munchkin': <class 'test_registry.TestRegistry.test_repr." + "<locals>.Munchkin'>") + repr_str += '})' + assert repr(CATS) == repr_str + + +def test_build_from_cfg(): + BACKBONES = Registry('backbone') + + @BACKBONES.register_module() + class ResNet: + + def __init__(self, depth, stages=4): + self.depth = depth + self.stages = stages + + @BACKBONES.register_module() + class ResNeXt: + + def __init__(self, depth, stages=4): + self.depth = depth + self.stages = stages + + # test `cfg` parameter + # `cfg` should be a dict, ConfigDict or Config object + with pytest.raises( + TypeError, match="cfg must be a dict, but got <class 'str'>"): + cfg = 'ResNet' + model = build_from_cfg(cfg, BACKBONES) + + # `cfg` is a dict + cfg = dict(type='ResNet', depth=50) + model = build_from_cfg(cfg, BACKBONES) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + # `cfg` is a ConfigDict object + cfg = ConfigDict(dict(type='ResNet', depth=50)) + model = build_from_cfg(cfg, BACKBONES) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + # `cfg` is a Config object + cfg = Config(dict(type='ResNet', depth=50)) + model = build_from_cfg(cfg, BACKBONES) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + # `cfg` is a dict but it does not contain the key "type" + with pytest.raises(KeyError, match='must contain the key "type"'): + cfg = dict(depth=50, stages=4) + model = build_from_cfg(cfg, BACKBONES) + + # cfg['type'] should be a str or class + with pytest.raises( + TypeError, + match="type must be a str or valid type, but got <class 'int'>"): + cfg = dict(type=1000) + model = build_from_cfg(cfg, BACKBONES) + + cfg = dict(type='ResNeXt', depth=50, stages=3) + model = build_from_cfg(cfg, BACKBONES) + assert isinstance(model, ResNeXt) + assert model.depth == 50 and model.stages == 3 + + cfg = dict(type=ResNet, depth=50) + model = build_from_cfg(cfg, BACKBONES) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + # non-registered class + with pytest.raises(KeyError, match='VGG is not in the backbone registry'): + cfg = dict(type='VGG') + model = build_from_cfg(cfg, BACKBONES) + + # `cfg` contains unexpected arguments + with pytest.raises(TypeError): + cfg = dict(type='ResNet', non_existing_arg=50) + model = build_from_cfg(cfg, BACKBONES) + + # test `default_args` parameter + cfg = dict(type='ResNet', depth=50) + model = build_from_cfg(cfg, BACKBONES, default_args={'stages': 3}) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 3 + + # default_args must be a dict or None + with pytest.raises(TypeError): + cfg = dict(type='ResNet', depth=50) + model = build_from_cfg(cfg, BACKBONES, default_args=1) + + # cfg or default_args should contain the key "type" + with pytest.raises(KeyError, match='must contain the key "type"'): + cfg = dict(depth=50) + model = build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4)) + + # "type" defined using default_args + cfg = dict(depth=50) + model = build_from_cfg(cfg, BACKBONES, default_args=dict(type='ResNet')) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + cfg = dict(depth=50) + model = build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet)) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + # test `registry` parameter + # incorrect registry type + with pytest.raises( + TypeError, + match=('registry must be a mmengine.Registry object, but got ' + "<class 'str'>")): + cfg = dict(type='ResNet', depth=50) + model = build_from_cfg(cfg, 'BACKBONES') -- GitLab