# Copyright (c) OpenMMLab. All rights reserved.
import time

import pytest

from mmengine.config import Config, ConfigDict  # type: ignore
from mmengine.registry import DefaultScope, Registry, build_from_cfg
from mmengine.utils import ManagerMixin
class TestRegistry:
    """Unit tests for :class:`mmengine.registry.Registry`."""

    def test_init(self):
        """Check the attributes set by ``Registry.__init__`` and the
        validation of its ``build_func``, ``parent`` and ``scope``
        arguments."""
        CATS = Registry('cat')
        assert CATS.name == 'cat'
        assert CATS.module_dict == {}
        assert CATS.build_func is build_from_cfg
        assert len(CATS) == 0

        # test `build_func` parameter
        def build_func(cfg, registry, default_args):
            pass

        CATS = Registry('cat', build_func=build_func)
        assert CATS.build_func is build_func

        # test `parent` parameter
        # `parent` is either None or a `Registry` instance
        with pytest.raises(AssertionError):
            CATS = Registry('little_cat', parent='cat', scope='little_cat')

        LITTLECATS = Registry('little_cat', parent=CATS, scope='little_cat')
        assert LITTLECATS.parent is CATS
        # the child is reachable from its parent under the child's scope
        assert CATS._children.get('little_cat') is LITTLECATS

        # test `scope` parameter
        # `scope` is either None or a string
        with pytest.raises(AssertionError):
            CATS = Registry('cat', scope=1)

        # without an explicit `scope` the expected scope is 'test_registry';
        # presumably derived from the defining module's name -- verify
        CATS = Registry('cat')
        assert CATS.scope == 'test_registry'

        CATS = Registry('cat', scope='cat')
        assert CATS.scope == 'cat'

    def test_split_scope_key(self):
        """Check that ``split_scope_key`` separates only the first dotted
        component as the scope and keeps the remainder as the key."""
        DOGS = Registry('dogs')

        scope, key = DOGS.split_scope_key('BloodHound')
        assert scope is None and key == 'BloodHound'

        scope, key = DOGS.split_scope_key('hound.BloodHound')
        assert scope == 'hound' and key == 'BloodHound'

        # only the first component is treated as the scope
        scope, key = DOGS.split_scope_key('hound.little_hound.Dachshund')
        assert scope == 'hound' and key == 'little_hound.Dachshund'

    def test_register_module(self):
        """Check ``register_module`` used as a decorator and as a plain
        method, including validation of ``name``, ``force`` and
        ``module``."""
        CATS = Registry('cat')

        # can only decorate a class
        with pytest.raises(TypeError):

            @CATS.register_module()
            def some_method():
                pass

        # test `name` parameter which must be either of None, a string or a
        # sequence of string
        # `name` is None
        @CATS.register_module()
        class BritishShorthair:
            pass

        assert len(CATS) == 1
        assert CATS.get('BritishShorthair') is BritishShorthair

        # `name` is a string
        @CATS.register_module(name='Munchkin')
        class Munchkin:
            pass

        assert len(CATS) == 2
        assert CATS.get('Munchkin') is Munchkin
        assert 'Munchkin' in CATS

        # `name` is a sequence of string; each alias maps to the same class
        # and counts as a separate entry
        @CATS.register_module(name=['Siamese', 'Siamese2'])
        class SiameseCat:
            pass

        assert CATS.get('Siamese') is SiameseCat
        assert CATS.get('Siamese2') is SiameseCat
        assert len(CATS) == 4

        # `name` is an invalid type
        with pytest.raises(
                TypeError,
                match=('name must be None, an instance of str, or a sequence '
                       "of str, but got <class 'int'>")):

            @CATS.register_module(name=7474741)
            class SiameseCat:
                pass

        # test `force` parameter, which must be a boolean
        # force is not a boolean
        with pytest.raises(
                TypeError,
                match="force must be a boolean, but got <class 'int'>"):

            @CATS.register_module(force=1)
            class BritishShorthair:
                pass

        # force=False: re-registering an existing name is rejected
        with pytest.raises(
                KeyError,
                match='BritishShorthair is already registered in cat'):

            @CATS.register_module()
            class BritishShorthair:
                pass

        # force=True: the existing entry is replaced, so the size is unchanged
        @CATS.register_module(force=True)
        class BritishShorthair:
            pass

        assert len(CATS) == 4

        # test `module` parameter, which is either None or a class
        # when `register_module` is called as a method rather than a
        # decorator, `module` must be a class
        with pytest.raises(
                TypeError,
                match="module must be a class, but got <class 'str'>"):
            CATS.register_module(module='string')

        class SphynxCat:
            pass

        CATS.register_module(module=SphynxCat)
        assert CATS.get('SphynxCat') is SphynxCat
        assert len(CATS) == 5

        CATS.register_module(name='Sphynx1', module=SphynxCat)
        assert CATS.get('Sphynx1') is SphynxCat
        assert len(CATS) == 6

        CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
        assert CATS.get('Sphynx2') is SphynxCat
        assert CATS.get('Sphynx3') is SphynxCat
        assert len(CATS) == 8

    def _build_registry(self):
        """A helper function to build a Hierarchical Registry.

        Returns:
            list: ``[DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS,
            LITTLE_SAMOYEDS]``, in that order.
        """
        # Hierarchical Registry
        #                           DOGS
        #                   _________|_________
        #                  |                   |
        #          HOUNDS (hound)       SAMOYEDS (samoyed)
        #         _______|_______               |
        #        |               |              |
        #  LITTLE_HOUNDS     MID_HOUNDS    LITTLE_SAMOYEDS
        #  (little_hound)    (mid_hound)   (little_samoyed)
        registries = []
        DOGS = Registry('dogs')
        registries.append(DOGS)
        HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
        registries.append(HOUNDS)
        LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound')
        registries.append(LITTLE_HOUNDS)
        MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
        registries.append(MID_HOUNDS)
        SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
        registries.append(SAMOYEDS)
        LITTLE_SAMOYEDS = Registry(
            'dogs', parent=SAMOYEDS, scope='little_samoyed')
        registries.append(LITTLE_SAMOYEDS)

        return registries

    def test_get_root_registry(self):
        """Check that every registry in the hierarchy resolves to DOGS as
        its root."""
        #                           DOGS
        #                   _________|_________
        #                  |                   |
        #          HOUNDS (hound)       SAMOYEDS (samoyed)
        #         _______|_______               |
        #        |               |              |
        #  LITTLE_HOUNDS     MID_HOUNDS    LITTLE_SAMOYEDS
        #  (little_hound)    (mid_hound)   (little_samoyed)
        registries = self._build_registry()
        DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]

        assert DOGS._get_root_registry() is DOGS
        assert HOUNDS._get_root_registry() is DOGS
        assert LITTLE_HOUNDS._get_root_registry() is DOGS
        assert MID_HOUNDS._get_root_registry() is DOGS

    def test_get(self):
        """Check key lookup across the registry hierarchy.

        Unscoped keys are looked up in the current registry and then in its
        ancestors; scoped keys (``'scope.Name'``) can reach children,
        siblings and other branches of the tree. Unknown keys return
        ``None``.
        """
        #                           DOGS
        #                   _________|_________
        #                  |                   |
        #          HOUNDS (hound)       SAMOYEDS (samoyed)
        #         _______|_______               |
        #        |               |              |
        #  LITTLE_HOUNDS     MID_HOUNDS    LITTLE_SAMOYEDS
        #  (little_hound)    (mid_hound)   (little_samoyed)
        registries = self._build_registry()
        DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
        MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]

        @DOGS.register_module()
        class GoldenRetriever:
            pass

        assert len(DOGS) == 1
        assert DOGS.get('GoldenRetriever') is GoldenRetriever

        @HOUNDS.register_module()
        class BloodHound:
            pass

        assert len(HOUNDS) == 1
        # get key from current registry
        assert HOUNDS.get('BloodHound') is BloodHound
        # get key from its children
        assert DOGS.get('hound.BloodHound') is BloodHound
        # get key from current registry
        assert HOUNDS.get('hound.BloodHound') is BloodHound

        # If the key is not found in the current registry, then look for its
        # parent
        assert HOUNDS.get('GoldenRetriever') is GoldenRetriever

        @LITTLE_HOUNDS.register_module()
        class Dachshund:
            pass

        assert len(LITTLE_HOUNDS) == 1
        # get key from current registry
        assert LITTLE_HOUNDS.get('Dachshund') is Dachshund
        # get key from its parent
        assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound
        # get key from its children
        assert HOUNDS.get('little_hound.Dachshund') is Dachshund
        # get key from its descendants
        assert DOGS.get('hound.little_hound.Dachshund') is Dachshund

        # If the key is not found in the current registry, then look for its
        # parent
        assert LITTLE_HOUNDS.get('BloodHound') is BloodHound
        assert LITTLE_HOUNDS.get('GoldenRetriever') is GoldenRetriever

        @MID_HOUNDS.register_module()
        class Beagle:
            pass

        # get key from its sibling registries
        assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle

        @SAMOYEDS.register_module()
        class PedigreeSamoyed:
            pass

        assert len(SAMOYEDS) == 1
        # get key from its uncle
        assert LITTLE_HOUNDS.get('samoyed.PedigreeSamoyed') is PedigreeSamoyed

        @LITTLE_SAMOYEDS.register_module()
        class LittlePedigreeSamoyed:
            pass

        # get key from its cousin
        assert LITTLE_HOUNDS.get(
            'samoyed.little_samoyed.LittlePedigreeSamoyed'
        ) is LittlePedigreeSamoyed

        # get key from its nephews
        assert HOUNDS.get(
            'samoyed.little_samoyed.LittlePedigreeSamoyed'
        ) is LittlePedigreeSamoyed

        # invalid keys
        # GoldenRetrieverererer can not be found at LITTLE_HOUNDS modules
        assert LITTLE_HOUNDS.get('GoldenRetrieverererer') is None
        # samoyedddd is not a child of DOGS
        assert DOGS.get('samoyedddd.PedigreeSamoyed') is None
        # samoyed is a child of DOGS but LittlePedigreeSamoyed can not be found
        # at SAMOYEDS modules
        assert DOGS.get('samoyed.LittlePedigreeSamoyed') is None
        assert LITTLE_HOUNDS.get('mid_hound.PedigreeSamoyedddddd') is None

    def test_search_child(self):
        """Check that ``_search_child`` finds registries in the caller's
        subtree by scope name and returns ``None`` for any other scope."""
        #                           DOGS
        #                   _________|_________
        #                  |                   |
        #          HOUNDS (hound)       SAMOYEDS (samoyed)
        #         _______|_______               |
        #        |               |              |
        #  LITTLE_HOUNDS     MID_HOUNDS    LITTLE_SAMOYEDS
        #  (little_hound)    (mid_hound)   (little_samoyed)
        registries = self._build_registry()
        DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]

        assert DOGS._search_child('hound') is HOUNDS
        assert DOGS._search_child('not a child') is None
        # grandchildren are found as well
        assert DOGS._search_child('little_hound') is LITTLE_HOUNDS
        # neither the parent nor a sibling is part of the subtree
        assert LITTLE_HOUNDS._search_child('hound') is None
        assert LITTLE_HOUNDS._search_child('mid_hound') is None

    @pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
    def test_build(self, cfg_type):
        """Check ``Registry.build`` on every level of the hierarchy.

        This covers falling back to the parent registry and the effect of
        ``DefaultScope`` on which registry is searched.
        """
        # Hierarchical Registry
        #                           DOGS
        #                   _________|_________
        #                  |                   |
        #          HOUNDS (hound)       SAMOYEDS (samoyed)
        #         _______|_______               |
        #        |               |              |
        #  LITTLE_HOUNDS     MID_HOUNDS    LITTLE_SAMOYEDS
        #  (little_hound)    (mid_hound)   (little_samoyed)
        registries = self._build_registry()
        DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]

        @DOGS.register_module()
        class GoldenRetriever:
            pass

        gr_cfg = cfg_type(dict(type='GoldenRetriever'))
        assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)

        @HOUNDS.register_module()
        class BloodHound:
            pass

        bh_cfg = cfg_type(dict(type='BloodHound'))
        assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
        # GoldenRetriever is only registered in the parent (DOGS)
        assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)

        @LITTLE_HOUNDS.register_module()
        class Dachshund:
            pass

        d_cfg = cfg_type(dict(type='Dachshund'))
        assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund)

        @MID_HOUNDS.register_module()
        class Beagle:
            pass

        b_cfg = cfg_type(dict(type='Beagle'))
        assert isinstance(MID_HOUNDS.build(b_cfg), Beagle)

        # test `default_scope`
        # switch the current registry to another registry: with 'mid_hound'
        # as the default scope, LITTLE_HOUNDS is expected to build Beagle,
        # which is registered only in its sibling MID_HOUNDS.
        # The instance name is timestamped -- presumably so repeated
        # (parametrized) runs never reuse an existing instance; TODO confirm.
        DefaultScope.get_instance(
            f'test-{time.time()}', scope_name='mid_hound')
        dog = LITTLE_HOUNDS.build(b_cfg)
        assert isinstance(dog, Beagle)

        # `default_scope` can not be found: building from the registry that
        # owns the class is still expected to succeed
        DefaultScope.get_instance(
            f'test2-{time.time()}', scope_name='scope-not-found')
        dog = MID_HOUNDS.build(b_cfg)
        assert isinstance(dog, Beagle)

    def test_repr(self):
        """Check the string representation lists the name and all items."""
        CATS = Registry('cat')

        @CATS.register_module()
        class BritishShorthair:
            pass

        @CATS.register_module()
        class Munchkin:
            pass

        # the expected qualnames assume this file is imported as the module
        # `test_registry`
        repr_str = 'Registry(name=cat, items={'
        repr_str += (
            "'BritishShorthair': <class 'test_registry.TestRegistry.test_repr."
            "<locals>.BritishShorthair'>, ")
        repr_str += (
            "'Munchkin': <class 'test_registry.TestRegistry.test_repr."
            "<locals>.Munchkin'>")
        repr_str += '})'
        assert repr(CATS) == repr_str
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
def test_build_from_cfg(cfg_type):
    """Check ``build_from_cfg``: validation of ``cfg``, ``default_args`` and
    ``registry``, string and class values for ``type``, and building a
    ``ManagerMixin`` subclass."""
    BACKBONES = Registry('backbone')

    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # test `cfg` parameter
    # `cfg` should be a dict, ConfigDict or Config object
    with pytest.raises(
            TypeError,
            match=('cfg should be a dict, ConfigDict or Config, but got '
                   "<class 'str'>")):
        cfg = 'ResNet'
        build_from_cfg(cfg, BACKBONES)

    # `cfg` is a dict, ConfigDict or Config object
    cfg = cfg_type(dict(type='ResNet', depth=50))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # `cfg` is a dict but it does not contain the key "type"
    with pytest.raises(KeyError, match='must contain the key "type"'):
        cfg = dict(depth=50, stages=4)
        build_from_cfg(cfg, BACKBONES)

    # cfg['type'] should be a str or class
    with pytest.raises(
            TypeError,
            match="type must be a str or valid type, but got <class 'int'>"):
        cfg = dict(type=1000)
        build_from_cfg(cfg, BACKBONES)

    cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # `type` given as the class itself rather than its registered name
    cfg = cfg_type(dict(type=ResNet, depth=50))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # non-registered class
    with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
        cfg = cfg_type(dict(type='VGG'))
        build_from_cfg(cfg, BACKBONES)

    # `cfg` contains unexpected arguments
    with pytest.raises(TypeError):
        cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
        build_from_cfg(cfg, BACKBONES)

    # test `default_args` parameter: values not present in `cfg` are
    # filled from `default_args`
    cfg = cfg_type(dict(type='ResNet', depth=50))
    model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    # default_args must be a dict or None
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES, default_args=1)

    # cfg or default_args should contain the key "type"
    with pytest.raises(KeyError, match='must contain the key "type"'):
        cfg = cfg_type(dict(depth=50))
        build_from_cfg(cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))

    # "type" supplied through default_args, as a name ...
    cfg = cfg_type(dict(depth=50))
    model = build_from_cfg(
        cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # ... and as a class
    cfg = cfg_type(dict(depth=50))
    model = build_from_cfg(
        cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # test `registry` parameter
    # incorrect registry type
    with pytest.raises(
            TypeError,
            match=('registry must be a mmengine.Registry object, but got '
                   "<class 'str'>")):
        cfg = cfg_type(dict(type='ResNet', depth=50))
        build_from_cfg(cfg, 'BACKBONES')

    # building a ManagerMixin subclass: before any instance exists,
    # `get_current_instance` raises RuntimeError; after `build_from_cfg`
    # has created one, the call is expected to succeed
    VISUALIZER = Registry('visualizer')

    @VISUALIZER.register_module()
    class Visualizer(ManagerMixin):

        def __init__(self, name):
            super().__init__(name)

    with pytest.raises(RuntimeError):
        Visualizer.get_current_instance()
    cfg = dict(type='Visualizer', name='visualizer')
    build_from_cfg(cfg, VISUALIZER)
    Visualizer.get_current_instance()