Skip to content
Snippets Groups Projects
Unverified Commit 65bc9503 authored by Jiazhen Wang's avatar Jiazhen Wang Committed by GitHub
Browse files

[Enhance] Support Custom Runner (#258)

* support custom runner

* change build_runner_from_cfg

* refine docstring

* refine docstring
parent 94c7c3be
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
from .default_scope import DefaultScope
from .registry import Registry, build_from_cfg
from .registry import Registry, build_from_cfg, build_runner_from_cfg
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS,
METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS,
......@@ -14,5 +14,6 @@ __all__ = [
'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'DefaultScope',
'traverse_registry_tree', 'count_registered_modules'
'traverse_registry_tree', 'count_registered_modules',
'build_runner_from_cfg'
]
......@@ -10,6 +10,73 @@ from ..utils import ManagerMixin, is_seq_of
from .default_scope import DefaultScope
def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
registry: 'Registry') -> Any:
"""Build a Runner object.
Examples:
>>> from mmengine import Registry, build_runner_from_cfg
>>> RUNNERS = Registry('runners', build_func=build_runner_from_cfg)
>>> @RUNNERS.register_module()
>>> class CustomRunner(Runner):
>>> def setup_env(env_cfg):
>>> pass
>>> cfg = dict(runner_type='CustomRunner', ...)
>>> custom_runner = RUNNERS.build(cfg)
Args:
cfg (dict or ConfigDict or Config): Config dict. If "runner_type" key
exists, it will be used to build a custom runner. Otherwise, it
will be used to build a default runner.
registry (:obj:`Registry`): The registry to search the type from.
Returns:
object: The constructed runner object.
"""
from ..logging.logger import MMLogger
assert isinstance(
cfg,
(dict, ConfigDict, Config
)), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
assert isinstance(
registry, Registry), ('registry should be a mmengine.Registry object',
f'but got {type(registry)}')
args = cfg.copy()
obj_type = args.pop('runner_type', 'mmengine.Runner')
if isinstance(obj_type, str):
runner_cls = registry.get(obj_type)
if runner_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry. '
f'Please check whether the value of `{obj_type}` is correct or'
' it was registered as expected. More details can be found at'
' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
)
elif inspect.isclass(obj_type):
runner_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
try:
runner = runner_cls.from_cfg(args) # type: ignore
logger: MMLogger = MMLogger.get_current_instance()
logger.info(
f'An `{runner_cls.__name__}` instance is built ' # type: ignore
f'from registry, its implementation can be found in'
f'{runner_cls.__module__}') # type: ignore
return runner
except Exception as e:
# Normal TypeError does not print class name.
cls_location = '/'.join(
runner_cls.__module__.split('.')) # type: ignore
raise type(e)(
f'class `{runner_cls.__name__}` in ' # type: ignore
f'{cls_location}.py: {e}')
def build_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: 'Registry',
......
......@@ -6,10 +6,10 @@ More datails can be found at
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""
from .registry import Registry
from .registry import Registry, build_runner_from_cfg
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry('runner')
RUNNERS = Registry('runner', build_func=build_runner_from_cfg)
# manage runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry('runner constructor')
# manage all kinds of loops like `EpochBasedTrainLoop`
......
......@@ -30,7 +30,7 @@ from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
VISUALIZERS, DefaultScope,
RUNNERS, VISUALIZERS, DefaultScope,
count_registered_modules)
from mmengine.registry.root import LOG_PROCESSORS
from mmengine.utils import (TORCH_VERSION, digit_version,
......@@ -49,6 +49,7 @@ ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]
@RUNNERS.register_module()
class Runner:
"""A training helper for PyTorch.
......
......@@ -24,7 +24,7 @@ from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
Registry)
RUNNERS, Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
Runner, TestLoop, ValLoop)
from mmengine.runner.priority import Priority, get_priority
......@@ -215,6 +215,41 @@ class CustomLogProcessor(LogProcessor):
self._check_custom_cfg()
@RUNNERS.register_module()
class CustomRunner(Runner):
def __init__(self,
model,
work_dir,
train_dataloader=None,
val_dataloader=None,
test_dataloader=None,
train_cfg=None,
val_cfg=None,
test_cfg=None,
optimizer=None,
param_scheduler=None,
val_evaluator=None,
test_evaluator=None,
default_hooks=None,
custom_hooks=None,
load_from=None,
resume=False,
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
log_processor=None,
log_level='INFO',
visualizer=None,
default_scope=None,
randomness=dict(seed=None),
experiment_name=None,
cfg=None):
pass
def setup_env(self, env_cfg):
pass
def collate_fn(data_batch):
return data_batch
......@@ -1511,3 +1546,17 @@ class TestRunner(TestCase):
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
def test_build_runner(self):
# No need to test other cases which have been tested in
# `test_build_from_cfg`
# test custom runner
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_runner1'
cfg.runner_type = 'CustomRunner'
assert isinstance(RUNNERS.build(cfg), CustomRunner)
# test default runner
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_runner2'
assert isinstance(RUNNERS.build(cfg), Runner)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment