Skip to content
Snippets Groups Projects
Unverified Commit 6ba7ac3a authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Add the unittest of registry (#5)

* Add the unittest of registry

* improve the format

* fix the test_repr

* refactor the test_init

* add the copyright for tests directory

* add more unit tests

* fix error

* add unit tests for _search_child and _get_root_registry

* build_from_cfg supports dict, ConfictDict and Config

* improve docstring
parent 27dd6175
No related branches found
No related tags found
No related merge requests found
......@@ -48,7 +48,7 @@ repos:
rev: v0.2.0
hooks:
- id: check-copyright
args: ["mmengine"]
args: ["mmengine", "tests"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
hooks:
......
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
from mmengine import Config, ConfigDict, Registry, build_from_cfg
class TestRegistry:
def test_init(self):
CATS = Registry('cat')
assert CATS.name == 'cat'
assert CATS.module_dict == {}
assert CATS.build_func is build_from_cfg
assert len(CATS) == 0
# test `build_func` parameter
def build_func(cfg, registry, default_args):
pass
CATS = Registry('cat', build_func=build_func)
assert CATS.build_func is build_func
# test `parent` parameter
# `parent` is either None or a `Registry` instance
with pytest.raises(AssertionError):
CATS = Registry('little_cat', parent='cat', scope='little_cat')
LITTLECATS = Registry('little_cat', parent=CATS, scope='little_cat')
assert LITTLECATS.parent is CATS
assert CATS._children.get('little_cat') is LITTLECATS
# test `scope` parameter
# `scope` is either None or a string
with pytest.raises(AssertionError):
CATS = Registry('cat', scope=1)
CATS = Registry('cat')
assert CATS.scope == 'test_registry'
CATS = Registry('cat', scope='cat')
assert CATS.scope == 'cat'
def test_split_scope_key(self):
DOGS = Registry('dogs')
scope, key = DOGS.split_scope_key('BloodHound')
assert scope is None and key == 'BloodHound'
scope, key = DOGS.split_scope_key('hound.BloodHound')
assert scope == 'hound' and key == 'BloodHound'
scope, key = DOGS.split_scope_key('hound.little_hound.Dachshund')
assert scope == 'hound' and key == 'little_hound.Dachshund'
def test_register_module(self):
CATS = Registry('cat')
# can only decorate a class
with pytest.raises(TypeError):
@CATS.register_module()
def some_method():
pass
# test `name` parameter which must be either of None, a string or a
# sequence of string
# `name` is None
@CATS.register_module()
class BritishShorthair:
pass
assert len(CATS) == 1
assert CATS.get('BritishShorthair') is BritishShorthair
# `name` is a string
@CATS.register_module(name='Munchkin')
class Munchkin:
pass
assert len(CATS) == 2
assert CATS.get('Munchkin') is Munchkin
assert 'Munchkin' in CATS
# `name` is a sequence of string
@CATS.register_module(name=['Siamese', 'Siamese2'])
class SiameseCat:
pass
assert CATS.get('Siamese') is SiameseCat
assert CATS.get('Siamese2') is SiameseCat
assert len(CATS) == 4
# `name` is an invalid type
with pytest.raises(
TypeError,
match=('name must be either of None, an instance of str or a '
"sequence of str, but got <class 'int'>")):
@CATS.register_module(name=7474741)
class SiameseCat:
pass
# test `force` parameter, which must be a boolean
# force is not a boolean
with pytest.raises(
TypeError,
match="force must be a boolean, but got <class 'int'>"):
@CATS.register_module(force=1)
class BritishShorthair:
pass
# force=False
with pytest.raises(
KeyError,
match='BritishShorthair is already registered in cat'):
@CATS.register_module()
class BritishShorthair:
pass
# force=True
@CATS.register_module(force=True)
class BritishShorthair:
pass
assert len(CATS) == 4
# test `module` parameter, which is either None or a class
# when the `register_module`` is called as a method rather than a
# decorator, which must be a class
with pytest.raises(
TypeError,
match="module must be a class, but got <class 'str'>"):
CATS.register_module(module='string')
class SphynxCat:
pass
CATS.register_module(module=SphynxCat)
assert CATS.get('SphynxCat') is SphynxCat
assert len(CATS) == 5
CATS.register_module(name='Sphynx1', module=SphynxCat)
assert CATS.get('Sphynx1') is SphynxCat
assert len(CATS) == 6
CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
assert CATS.get('Sphynx2') is SphynxCat
assert CATS.get('Sphynx3') is SphynxCat
assert len(CATS) == 8
def _build_registry(self):
"""A helper function to build a hierarchy registry."""
# Hierarchy Registry
# DOGS
# _______|_______
# | |
# HOUNDS (hound) SAMOYEDS (samoyed)
# _______|_______ |
# | | |
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
# (little_hound) (mid_hound) (little_samoyed)
registries = []
DOGS = Registry('dogs')
registries.append(DOGS)
HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
registries.append(HOUNDS)
LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound')
registries.append(LITTLE_HOUNDS)
MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
registries.append(MID_HOUNDS)
SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
registries.append(SAMOYEDS)
LITTLE_SAMOYEDS = Registry(
'dogs', parent=SAMOYEDS, scope='little_samoyed')
registries.append(LITTLE_SAMOYEDS)
return registries
def test_get_root_registry(self):
# Hierarchy Registry
# DOGS
# _______|_______
# | |
# HOUNDS (hound) SAMOYEDS (samoyed)
# _______|_______ |
# | | |
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
# (little_hound) (mid_hound) (little_samoyed)
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
assert DOGS._get_root_registry() is DOGS
assert HOUNDS._get_root_registry() is DOGS
assert LITTLE_HOUNDS._get_root_registry() is DOGS
assert MID_HOUNDS._get_root_registry() is DOGS
def test_get(self):
# Hierarchy Registry
# DOGS
# _______|_______
# | |
# HOUNDS (hound) SAMOYEDS (samoyed)
# _______|_______ |
# | | |
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
# (little_hound) (mid_hound) (little_samoyed)
registries = self._build_registry()
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]
@DOGS.register_module()
class GoldenRetriever:
pass
assert len(DOGS) == 1
assert DOGS.get('GoldenRetriever') is GoldenRetriever
@HOUNDS.register_module()
class BloodHound:
pass
assert len(HOUNDS) == 1
# get key from current registry
assert HOUNDS.get('BloodHound') is BloodHound
# get key from its children
assert DOGS.get('hound.BloodHound') is BloodHound
# get key from current registry
assert HOUNDS.get('hound.BloodHound') is BloodHound
# If the key is not found in the current registry, then look for its
# parent
assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
@LITTLE_HOUNDS.register_module()
class Dachshund:
pass
assert len(LITTLE_HOUNDS) == 1
# get key from current registry
assert LITTLE_HOUNDS.get('Dachshund') is Dachshund
# get key from its parent
assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound
# get key from its children
assert HOUNDS.get('little_hound.Dachshund') is Dachshund
# get key from its descendants
assert DOGS.get('hound.little_hound.Dachshund') is Dachshund
# If the key is not found in the current registry, then look for its
# parent
assert LITTLE_HOUNDS.get('BloodHound') is BloodHound
assert LITTLE_HOUNDS.get('GoldenRetriever') is GoldenRetriever
@MID_HOUNDS.register_module()
class Beagle:
pass
# get key from its sibling registries
assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle
@SAMOYEDS.register_module()
class PedigreeSamoyed:
pass
assert len(SAMOYEDS) == 1
# get key from its uncle
assert LITTLE_HOUNDS.get('samoyed.PedigreeSamoyed') is PedigreeSamoyed
@LITTLE_SAMOYEDS.register_module()
class LittlePedigreeSamoyed:
pass
# get key from its cousin
assert LITTLE_HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
) is LittlePedigreeSamoyed
# get key from its nephews
assert HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
) is LittlePedigreeSamoyed
def test_search_child(self):
# Hierarchy Registry
# DOGS
# _______|_______
# | |
# HOUNDS (hound) SAMOYEDS (samoyed)
# _______|_______ |
# | | |
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
# (little_hound) (mid_hound) (little_samoyed)
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
assert DOGS._search_child('hound') is HOUNDS
assert DOGS._search_child('not a child') is None
assert DOGS._search_child('little_hound') is LITTLE_HOUNDS
assert LITTLE_HOUNDS._search_child('hound') is None
assert LITTLE_HOUNDS._search_child('mid_hound') is None
def test_build(self):
# Hierarchy Registry
# DOGS
# _______|_______
# | |
# HOUNDS (hound) SAMOYEDS (samoyed)
# _______|_______ |
# | | |
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
# (little_hound) (mid_hound) (little_samoyed)
registries = self._build_registry()
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]
@DOGS.register_module()
class GoldenRetriever:
pass
gr_cfg = dict(type='GoldenRetriever')
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
@HOUNDS.register_module()
class BloodHound:
pass
bh_cfg = dict(type='BloodHound')
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
@LITTLE_HOUNDS.register_module()
class Dachshund:
pass
d_cfg = dict(type='Dachshund')
assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund)
@MID_HOUNDS.register_module()
class Beagle:
pass
b_cfg = dict(type='Beagle')
assert isinstance(MID_HOUNDS.build(b_cfg), Beagle)
# test `default_scope`
# `default_scope` is an invalid scope
with pytest.raises(KeyError):
LITTLE_HOUNDS.build(b_cfg, default_scope='invalid_mid_hound')
# switch the current registry to another registry
dog = LITTLE_HOUNDS.build(b_cfg, default_scope='mid_hound')
assert isinstance(dog, Beagle)
def test_repr(self):
CATS = Registry('cat')
@CATS.register_module()
class BritishShorthair:
pass
@CATS.register_module()
class Munchkin:
pass
repr_str = 'Registry(name=cat, items={'
repr_str += (
"'BritishShorthair': <class 'test_registry.TestRegistry.test_repr."
"<locals>.BritishShorthair'>, ")
repr_str += (
"'Munchkin': <class 'test_registry.TestRegistry.test_repr."
"<locals>.Munchkin'>")
repr_str += '})'
assert repr(CATS) == repr_str
def test_build_from_cfg():
BACKBONES = Registry('backbone')
@BACKBONES.register_module()
class ResNet:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
@BACKBONES.register_module()
class ResNeXt:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
# test `cfg` parameter
# `cfg` should be a dict, ConfigDict or Config object
with pytest.raises(
TypeError, match="cfg must be a dict, but got <class 'str'>"):
cfg = 'ResNet'
model = build_from_cfg(cfg, BACKBONES)
# `cfg` is a dict
cfg = dict(type='ResNet', depth=50)
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# `cfg` is a ConfigDict object
cfg = ConfigDict(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# `cfg` is a Config object
cfg = Config(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# `cfg` is a dict but it does not contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50, stages=4)
model = build_from_cfg(cfg, BACKBONES)
# cfg['type'] should be a str or class
with pytest.raises(
TypeError,
match="type must be a str or valid type, but got <class 'int'>"):
cfg = dict(type=1000)
model = build_from_cfg(cfg, BACKBONES)
cfg = dict(type='ResNeXt', depth=50, stages=3)
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = dict(type=ResNet, depth=50)
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# non-registered class
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
cfg = dict(type='VGG')
model = build_from_cfg(cfg, BACKBONES)
# `cfg` contains unexpected arguments
with pytest.raises(TypeError):
cfg = dict(type='ResNet', non_existing_arg=50)
model = build_from_cfg(cfg, BACKBONES)
# test `default_args` parameter
cfg = dict(type='ResNet', depth=50)
model = build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 3
# default_args must be a dict or None
with pytest.raises(TypeError):
cfg = dict(type='ResNet', depth=50)
model = build_from_cfg(cfg, BACKBONES, default_args=1)
# cfg or default_args should contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50)
model = build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4))
# "type" defined using default_args
cfg = dict(depth=50)
model = build_from_cfg(cfg, BACKBONES, default_args=dict(type='ResNet'))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(depth=50)
model = build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# test `registry` parameter
# incorrect registry type
with pytest.raises(
TypeError,
match=('registry must be a mmengine.Registry object, but got '
"<class 'str'>")):
cfg = dict(type='ResNet', depth=50)
model = build_from_cfg(cfg, 'BACKBONES')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment