Skip to content
Snippets Groups Projects
Unverified Commit b2ad2210 authored by Qian Zhao's avatar Qian Zhao Committed by GitHub
Browse files

[Feature] Support registering partial functions and more (#595)


* support registering partial functions

* Update mmengine/registry/build_functions.py

Co-authored-by: default avatarMashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Update mmengine/registry/registry.py

Co-authored-by: default avatarMashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Revert unit test and refine

* add current logger and set log level

---------

Co-authored-by: default avatarMashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: default avatarHAOCHENYE <21724054@zju.edu.cn>
parent f76218a4
No related branches found
No related tags found
No related merge requests found
......@@ -104,7 +104,8 @@ def build_from_cfg(
'can be found at '
'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501
)
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
# this will include classes, functions, partial functions and more
elif callable(obj_type):
obj_cls = obj_type
else:
raise TypeError(
......@@ -120,12 +121,20 @@ def build_from_cfg(
else:
obj = obj_cls(**args) # type: ignore
print_log(
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, its implementation can be found in '
f'{obj_cls.__module__}', # type: ignore
logger='current',
level=logging.DEBUG)
if (inspect.isclass(obj_cls) or inspect.isfunction(obj_cls)
or inspect.ismethod(obj_cls)):
print_log(
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, and its implementation can be found in '
f'{obj_cls.__module__}', # type: ignore
logger='current',
level=logging.DEBUG)
else:
print_log(
'An instance is built from registry, and its constructor '
f'is {obj_cls}',
logger='current',
level=logging.DEBUG)
return obj
except Exception as e:
......
......@@ -487,8 +487,11 @@ class Registry:
obj_cls = root.get(key)
if obj_cls is not None:
# For some rare cases (e.g. obj_cls is a partial function), obj_cls
# doesn't have `__name__`. Use default value to prevent error
cls_name = getattr(obj_cls, '__name__', str(obj_cls))
print_log(
f'Get class `{obj_cls.__name__}` from "{registry_name}"'
f'Get class `{cls_name}` from "{registry_name}"'
f' registry in "{scope_name}"',
logger='current',
level=logging.DEBUG)
......@@ -565,16 +568,16 @@ class Registry:
"""Register a module.
Args:
module (type): Module class or function to be registered.
module (type): Module to be registered. Typically a class or a
function, but generally all ``Callable`` are acceptable.
module_name (str or list of str, optional): The module name to be
registered. If not specified, the class name will be used.
Defaults to None.
force (bool): Whether to override an existing class with the same
name. Defaults to False.
"""
if not inspect.isclass(module) and not inspect.isfunction(module):
raise TypeError('module must be a class or a function, '
f'but got {type(module)}')
if not callable(module):
raise TypeError(f'module must be Callable, but got {type(module)}')
if module_name is None:
module_name = module.__name__
......
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import time
import pytest
......@@ -59,23 +60,12 @@ class TestRegistry:
CATS = Registry('cat')
@CATS.register_module()
def muchkin():
def muchkin(size):
pass
assert CATS.get('muchkin') is muchkin
assert 'muchkin' in CATS
# can only decorate a class or a function
with pytest.raises(TypeError):
class Demo:
def some_method(self):
pass
method = Demo().some_method
CATS.register_module(name='some_method', module=method)
# test `name` parameter which must be either of None, a string or a
# sequence of string
# `name` is None
......@@ -146,7 +136,7 @@ class TestRegistry:
# decorator, which must be a class
with pytest.raises(
TypeError,
match='module must be a class or a function,'
match='module must be Callable,'
" but got <class 'str'>"):
CATS.register_module(module='string')
......@@ -166,6 +156,17 @@ class TestRegistry:
assert CATS.get('Sphynx3') is SphynxCat
assert len(CATS) == 9
# partial functions can be registered
muchkin0 = functools.partial(muchkin, size=0)
CATS.register_module('muchkin0', False, muchkin0)
# lambda functions can be registered
CATS.register_module(name='unknown cat', module=lambda: 'unknown')
assert CATS.get('muchkin0') is muchkin0
assert 'unknown cat' in CATS
assert 'muchkin0' in CATS
assert len(CATS) == 11
def _build_registry(self):
"""A helper function to build a Hierarchical Registry."""
# Hierarchical Registry
......@@ -227,12 +228,21 @@ class TestRegistry:
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]
@DOGS.register_module()
def bark(word, times):
return [word] * times
dog_bark = functools.partial(bark, 'woof')
DOGS.register_module('dog_bark', False, dog_bark)
@DOGS.register_module()
class GoldenRetriever:
pass
assert len(DOGS) == 1
assert len(DOGS) == 3
assert DOGS.get('GoldenRetriever') is GoldenRetriever
assert DOGS.get('bark') is bark
assert DOGS.get('dog_bark') is dog_bark
@HOUNDS.register_module()
class BloodHound:
......@@ -249,6 +259,8 @@ class TestRegistry:
# If the key is not found in the current registry, then look for its
# parent
assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
assert HOUNDS.get('bark') is bark
assert HOUNDS.get('dog_bark') is dog_bark
@LITTLE_HOUNDS.register_module()
class Dachshund:
......@@ -340,11 +352,14 @@ class TestRegistry:
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
@DOGS.register_module()
def bark(times=1):
return ' '.join(['woof'] * times)
def bark(word, times):
return ' '.join([word] * times)
bark_cfg = cfg_type(dict(type='bark', times=3))
assert DOGS.build(bark_cfg) == 'woof woof woof'
dog_bark = functools.partial(bark, word='woof')
DOGS.register_module('dog_bark', False, dog_bark)
bark_cfg = cfg_type(dict(type='bark', word='meow', times=3))
dog_bark_cfg = cfg_type(dict(type='dog_bark', times=3))
@DOGS.register_module()
class GoldenRetriever:
......@@ -352,6 +367,8 @@ class TestRegistry:
gr_cfg = cfg_type(dict(type='GoldenRetriever'))
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
assert DOGS.build(bark_cfg) == 'meow meow meow'
assert DOGS.build(dog_bark_cfg) == 'woof woof woof'
@HOUNDS.register_module()
class BloodHound:
......@@ -360,6 +377,8 @@ class TestRegistry:
bh_cfg = cfg_type(dict(type='BloodHound'))
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
assert HOUNDS.build(bark_cfg) == 'meow meow meow'
assert HOUNDS.build(dog_bark_cfg) == 'woof woof woof'
@LITTLE_HOUNDS.register_module()
class Dachshund:
......@@ -419,6 +438,18 @@ class TestRegistry:
assert isinstance(dog.friend, YourSamoyed)
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
# build an instance by lambda or partial function.
lambda_dog = lambda name: name # noqa: E731
DOGS.register_module(name='lambda_dog', module=lambda_dog)
lambda_cfg = cfg_type(dict(type='lambda_dog', name='unknown'))
assert DOGS.build(lambda_cfg) == 'unknown'
DOGS.register_module(
name='patial dog',
module=functools.partial(lambda_dog, name='patial'))
unknown_cfg = cfg_type(dict(type='patial dog'))
assert DOGS.build(unknown_cfg) == 'patial'
def test_switch_scope_and_registry(self):
DOGS = Registry('dogs')
HOUNDS = Registry('hounds', scope='hound', parent=DOGS)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment