From 6ba7ac3a8e15c121dcf07892e30c0130c0c53b5e Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Tue, 15 Feb 2022 15:37:53 +0800
Subject: [PATCH] Add the unittest of registry (#5)

* Add the unittest of registry

* improve the format

* fix the test_repr

* refactor the test_init

* add the copyright for tests directory

* add more unit tests

* fix error

* add unit tests for _search_child and _get_root_registry

* build_from_cfg supports dict, ConfictDict and Config

* improve docstring
---
 .pre-commit-config.yaml              |   2 +-
 tests/test_registry/test_registry.py | 478 +++++++++++++++++++++++++++
 2 files changed, 479 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_registry/test_registry.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 46bdb7d8..fb7df7ad 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -48,7 +48,7 @@ repos:
     rev: v0.2.0
     hooks:
       - id: check-copyright
-        args: ["mmengine"]
+        args: ["mmengine", "tests"]
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v0.812
     hooks:
diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py
new file mode 100644
index 00000000..0f058cf6
--- /dev/null
+++ b/tests/test_registry/test_registry.py
@@ -0,0 +1,478 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+
+from mmengine import Config, ConfigDict, Registry, build_from_cfg
+
+
+class TestRegistry:
+
+    def test_init(self):
+        CATS = Registry('cat')
+        assert CATS.name == 'cat'
+        assert CATS.module_dict == {}
+        assert CATS.build_func is build_from_cfg
+        assert len(CATS) == 0
+
+        # test `build_func` parameter
+        def build_func(cfg, registry, default_args):
+            pass
+
+        CATS = Registry('cat', build_func=build_func)
+        assert CATS.build_func is build_func
+
+        # test `parent` parameter
+        # `parent` is either None or a `Registry` instance
+        with pytest.raises(AssertionError):
+            CATS = Registry('little_cat', parent='cat', scope='little_cat')
+
+        LITTLECATS = Registry('little_cat', parent=CATS, scope='little_cat')
+        assert LITTLECATS.parent is CATS
+        assert CATS._children.get('little_cat') is LITTLECATS
+
+        # test `scope` parameter
+        # `scope` is either None or a string
+        with pytest.raises(AssertionError):
+            CATS = Registry('cat', scope=1)
+
+        CATS = Registry('cat')
+        assert CATS.scope == 'test_registry'
+
+        CATS = Registry('cat', scope='cat')
+        assert CATS.scope == 'cat'
+
+    def test_split_scope_key(self):
+        DOGS = Registry('dogs')
+
+        scope, key = DOGS.split_scope_key('BloodHound')
+        assert scope is None and key == 'BloodHound'
+        scope, key = DOGS.split_scope_key('hound.BloodHound')
+        assert scope == 'hound' and key == 'BloodHound'
+        scope, key = DOGS.split_scope_key('hound.little_hound.Dachshund')
+        assert scope == 'hound' and key == 'little_hound.Dachshund'
+
+    def test_register_module(self):
+        CATS = Registry('cat')
+
+        # can only decorate a class
+        with pytest.raises(TypeError):
+
+            @CATS.register_module()
+            def some_method():
+                pass
+
+        # test `name` parameter which must be either of None, a string or a
+        # sequence of string
+        # `name` is None
+        @CATS.register_module()
+        class BritishShorthair:
+            pass
+
+        assert len(CATS) == 1
+        assert CATS.get('BritishShorthair') is BritishShorthair
+
+        # `name` is a string
+        @CATS.register_module(name='Munchkin')
+        class Munchkin:
+            pass
+
+        assert len(CATS) == 2
+        assert CATS.get('Munchkin') is Munchkin
+        assert 'Munchkin' in CATS
+
+        # `name` is a sequence of string
+        @CATS.register_module(name=['Siamese', 'Siamese2'])
+        class SiameseCat:
+            pass
+
+        assert CATS.get('Siamese') is SiameseCat
+        assert CATS.get('Siamese2') is SiameseCat
+        assert len(CATS) == 4
+
+        # `name` is an invalid type
+        with pytest.raises(
+                TypeError,
+                match=('name must be either of None, an instance of str or a '
+                       "sequence of str, but got <class 'int'>")):
+
+            @CATS.register_module(name=7474741)
+            class SiameseCat:
+                pass
+
+        # test `force` parameter, which must be a boolean
+        # force is not a boolean
+        with pytest.raises(
+                TypeError,
+                match="force must be a boolean, but got <class 'int'>"):
+
+            @CATS.register_module(force=1)
+            class BritishShorthair:
+                pass
+
+        # force=False
+        with pytest.raises(
+                KeyError,
+                match='BritishShorthair is already registered in cat'):
+
+            @CATS.register_module()
+            class BritishShorthair:
+                pass
+
+        # force=True
+        @CATS.register_module(force=True)
+        class BritishShorthair:
+            pass
+
+        assert len(CATS) == 4
+
+        # test `module` parameter, which is either None or a class
+        # when the `register_module`` is called as a method rather than a
+        # decorator, which must be a class
+        with pytest.raises(
+                TypeError,
+                match="module must be a class, but got <class 'str'>"):
+            CATS.register_module(module='string')
+
+        class SphynxCat:
+            pass
+
+        CATS.register_module(module=SphynxCat)
+        assert CATS.get('SphynxCat') is SphynxCat
+        assert len(CATS) == 5
+
+        CATS.register_module(name='Sphynx1', module=SphynxCat)
+        assert CATS.get('Sphynx1') is SphynxCat
+        assert len(CATS) == 6
+
+        CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
+        assert CATS.get('Sphynx2') is SphynxCat
+        assert CATS.get('Sphynx3') is SphynxCat
+        assert len(CATS) == 8
+
+    def _build_registry(self):
+        """A helper function to build a hierarchy registry."""
+        #        Hierarchy Registry
+        #                           DOGS
+        #                      _______|_______
+        #                     |               |
+        #            HOUNDS (hound)          SAMOYEDS (samoyed)
+        #           _______|_______                |
+        #          |               |               |
+        #     LITTLE_HOUNDS    MID_HOUNDS   LITTLE_SAMOYEDS
+        #     (little_hound)   (mid_hound)  (little_samoyed)
+        registries = []
+        DOGS = Registry('dogs')
+        registries.append(DOGS)
+        HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
+        registries.append(HOUNDS)
+        LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound')
+        registries.append(LITTLE_HOUNDS)
+        MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
+        registries.append(MID_HOUNDS)
+        SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
+        registries.append(SAMOYEDS)
+        LITTLE_SAMOYEDS = Registry(
+            'dogs', parent=SAMOYEDS, scope='little_samoyed')
+        registries.append(LITTLE_SAMOYEDS)
+
+        return registries
+
+    def test_get_root_registry(self):
+        #        Hierarchy Registry
+        #                           DOGS
+        #                      _______|_______
+        #                     |               |
+        #            HOUNDS (hound)          SAMOYEDS (samoyed)
+        #           _______|_______                |
+        #          |               |               |
+        #     LITTLE_HOUNDS    MID_HOUNDS   LITTLE_SAMOYEDS
+        #     (little_hound)   (mid_hound)  (little_samoyed)
+        DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
+
+        assert DOGS._get_root_registry() is DOGS
+        assert HOUNDS._get_root_registry() is DOGS
+        assert LITTLE_HOUNDS._get_root_registry() is DOGS
+        assert MID_HOUNDS._get_root_registry() is DOGS
+
+    def test_get(self):
+        #        Hierarchy Registry
+        #                           DOGS
+        #                      _______|_______
+        #                     |               |
+        #            HOUNDS (hound)          SAMOYEDS (samoyed)
+        #           _______|_______                |
+        #          |               |               |
+        #     LITTLE_HOUNDS    MID_HOUNDS   LITTLE_SAMOYEDS
+        #     (little_hound)   (mid_hound)  (little_samoyed)
+        registries = self._build_registry()
+        DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
+        MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]
+
+        @DOGS.register_module()
+        class GoldenRetriever:
+            pass
+
+        assert len(DOGS) == 1
+        assert DOGS.get('GoldenRetriever') is GoldenRetriever
+
+        @HOUNDS.register_module()
+        class BloodHound:
+            pass
+
+        assert len(HOUNDS) == 1
+        # get key from current registry
+        assert HOUNDS.get('BloodHound') is BloodHound
+        # get key from its children
+        assert DOGS.get('hound.BloodHound') is BloodHound
+        # get key from current registry
+        assert HOUNDS.get('hound.BloodHound') is BloodHound
+
+        # If the key is not found in the current registry, then look for its
+        # parent
+        assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
+
+        @LITTLE_HOUNDS.register_module()
+        class Dachshund:
+            pass
+
+        assert len(LITTLE_HOUNDS) == 1
+        # get key from current registry
+        assert LITTLE_HOUNDS.get('Dachshund') is Dachshund
+        # get key from its parent
+        assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound
+        # get key from its children
+        assert HOUNDS.get('little_hound.Dachshund') is Dachshund
+        # get key from its descendants
+        assert DOGS.get('hound.little_hound.Dachshund') is Dachshund
+
+        # If the key is not found in the current registry, then look for its
+        # parent
+        assert LITTLE_HOUNDS.get('BloodHound') is BloodHound
+        assert LITTLE_HOUNDS.get('GoldenRetriever') is GoldenRetriever
+
+        @MID_HOUNDS.register_module()
+        class Beagle:
+            pass
+
+        # get key from its sibling registries
+        assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle
+
+        @SAMOYEDS.register_module()
+        class PedigreeSamoyed:
+            pass
+
+        assert len(SAMOYEDS) == 1
+        # get key from its uncle
+        assert LITTLE_HOUNDS.get('samoyed.PedigreeSamoyed') is PedigreeSamoyed
+
+        @LITTLE_SAMOYEDS.register_module()
+        class LittlePedigreeSamoyed:
+            pass
+
+        # get key from its cousin
+        assert LITTLE_HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
+                                 ) is LittlePedigreeSamoyed
+
+        # get key from its nephews
+        assert HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
+                          ) is LittlePedigreeSamoyed
+
+    def test_search_child(self):
+        #        Hierarchy Registry
+        #                           DOGS
+        #                      _______|_______
+        #                     |               |
+        #            HOUNDS (hound)          SAMOYEDS (samoyed)
+        #           _______|_______                |
+        #          |               |               |
+        #     LITTLE_HOUNDS    MID_HOUNDS   LITTLE_SAMOYEDS
+        #     (little_hound)   (mid_hound)  (little_samoyed)
+        DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
+
+        assert DOGS._search_child('hound') is HOUNDS
+        assert DOGS._search_child('not a child') is None
+        assert DOGS._search_child('little_hound') is LITTLE_HOUNDS
+        assert LITTLE_HOUNDS._search_child('hound') is None
+        assert LITTLE_HOUNDS._search_child('mid_hound') is None
+
+    def test_build(self):
+        #        Hierarchy Registry
+        #                           DOGS
+        #                      _______|_______
+        #                     |               |
+        #            HOUNDS (hound)          SAMOYEDS (samoyed)
+        #           _______|_______                |
+        #          |               |               |
+        #     LITTLE_HOUNDS    MID_HOUNDS   LITTLE_SAMOYEDS
+        #     (little_hound)   (mid_hound)  (little_samoyed)
+        registries = self._build_registry()
+        DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]
+
+        @DOGS.register_module()
+        class GoldenRetriever:
+            pass
+
+        gr_cfg = dict(type='GoldenRetriever')
+        assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
+
+        @HOUNDS.register_module()
+        class BloodHound:
+            pass
+
+        bh_cfg = dict(type='BloodHound')
+        assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
+        assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
+
+        @LITTLE_HOUNDS.register_module()
+        class Dachshund:
+            pass
+
+        d_cfg = dict(type='Dachshund')
+        assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund)
+
+        @MID_HOUNDS.register_module()
+        class Beagle:
+            pass
+
+        b_cfg = dict(type='Beagle')
+        assert isinstance(MID_HOUNDS.build(b_cfg), Beagle)
+
+        # test `default_scope`
+        # `default_scope` is an invalid scope
+        with pytest.raises(KeyError):
+            LITTLE_HOUNDS.build(b_cfg, default_scope='invalid_mid_hound')
+
+        # switch the current registry to another registry
+        dog = LITTLE_HOUNDS.build(b_cfg, default_scope='mid_hound')
+        assert isinstance(dog, Beagle)
+
+    def test_repr(self):
+        CATS = Registry('cat')
+
+        @CATS.register_module()
+        class BritishShorthair:
+            pass
+
+        @CATS.register_module()
+        class Munchkin:
+            pass
+
+        repr_str = 'Registry(name=cat, items={'
+        repr_str += (
+            "'BritishShorthair': <class 'test_registry.TestRegistry.test_repr."
+            "<locals>.BritishShorthair'>, ")
+        repr_str += (
+            "'Munchkin': <class 'test_registry.TestRegistry.test_repr."
+            "<locals>.Munchkin'>")
+        repr_str += '})'
+        assert repr(CATS) == repr_str
+
+
+def test_build_from_cfg():
+    BACKBONES = Registry('backbone')
+
+    @BACKBONES.register_module()
+    class ResNet:
+
+        def __init__(self, depth, stages=4):
+            self.depth = depth
+            self.stages = stages
+
+    @BACKBONES.register_module()
+    class ResNeXt:
+
+        def __init__(self, depth, stages=4):
+            self.depth = depth
+            self.stages = stages
+
+    # test `cfg` parameter
+    # `cfg` should be a dict, ConfigDict or Config object
+    with pytest.raises(
+            TypeError, match="cfg must be a dict, but got <class 'str'>"):
+        cfg = 'ResNet'
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # `cfg` is a dict
+    cfg = dict(type='ResNet', depth=50)
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # `cfg` is a ConfigDict object
+    cfg = ConfigDict(dict(type='ResNet', depth=50))
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # `cfg` is a Config object
+    cfg = Config(dict(type='ResNet', depth=50))
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # `cfg` is a dict but it does not contain the key "type"
+    with pytest.raises(KeyError, match='must contain the key "type"'):
+        cfg = dict(depth=50, stages=4)
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # cfg['type'] should be a str or class
+    with pytest.raises(
+            TypeError,
+            match="type must be a str or valid type, but got <class 'int'>"):
+        cfg = dict(type=1000)
+        model = build_from_cfg(cfg, BACKBONES)
+
+    cfg = dict(type='ResNeXt', depth=50, stages=3)
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNeXt)
+    assert model.depth == 50 and model.stages == 3
+
+    cfg = dict(type=ResNet, depth=50)
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # non-registered class
+    with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
+        cfg = dict(type='VGG')
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # `cfg` contains unexpected arguments
+    with pytest.raises(TypeError):
+        cfg = dict(type='ResNet', non_existing_arg=50)
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # test `default_args` parameter
+    cfg = dict(type='ResNet', depth=50)
+    model = build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 3
+
+    # default_args must be a dict or None
+    with pytest.raises(TypeError):
+        cfg = dict(type='ResNet', depth=50)
+        model = build_from_cfg(cfg, BACKBONES, default_args=1)
+
+    # cfg or default_args should contain the key "type"
+    with pytest.raises(KeyError, match='must contain the key "type"'):
+        cfg = dict(depth=50)
+        model = build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4))
+
+    # "type" defined using default_args
+    cfg = dict(depth=50)
+    model = build_from_cfg(cfg, BACKBONES, default_args=dict(type='ResNet'))
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    cfg = dict(depth=50)
+    model = build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # test `registry` parameter
+    # incorrect registry type
+    with pytest.raises(
+            TypeError,
+            match=('registry must be a mmengine.Registry object, but got '
+                   "<class 'str'>")):
+        cfg = dict(type='ResNet', depth=50)
+        model = build_from_cfg(cfg, 'BACKBONES')
-- 
GitLab