Skip to content
Snippets Groups Projects
Unverified Commit f1de071c authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Enhancement] Refactor Runner (#139)

* [Enhancement] Rename build_from_cfg to from_cfg

* refactor build_logger and build_message_hub

* remove time.sleep from unit tests

* minor fix

* move set_randomness from setup_env

* improve docstring

* refine comments

* print a warning information

* refine comments

* simplify the interface of build_logger
parent 9a61b389
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,8 @@ from mmengine.optim import _ParamScheduler, build_optimizer ...@@ -32,7 +32,8 @@ from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
DefaultScope) DefaultScope)
from mmengine.utils import find_latest_checkpoint, is_list_of, symlink from mmengine.utils import (TORCH_VERSION, digit_version,
find_latest_checkpoint, is_list_of, symlink)
from mmengine.visualization import ComposedWriter from mmengine.visualization import ComposedWriter
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
...@@ -47,6 +48,20 @@ ConfigType = Union[Dict, Config, ConfigDict] ...@@ -47,6 +48,20 @@ ConfigType = Union[Dict, Config, ConfigDict]
class Runner: class Runner:
"""A training helper for PyTorch. """A training helper for PyTorch.
Runner object can be built from config by ``runner = Runner.from_cfg(cfg)``
where the ``cfg`` usually contains training, validation, and test-related
configurations to build corresponding components. We usually use the
same config to launch training, testing, and validation tasks. However,
only some of these components are necessary at the same time, e.g.,
testing a model does not need training or validation-related components.
To avoid repeatedly modifying config, the construction of ``Runner`` adopts
lazy initialization to only initialize components when they are going to be
used. Therefore, the model is always initialized at the beginning, and
training, validation, and, testing related components are only initialized
when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``,
respectively.
Args: Args:
model (:obj:`torch.nn.Module` or dict): The model to be run. It can be model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
a dict used for build a model. a dict used for build a model.
...@@ -114,23 +129,21 @@ class Runner: ...@@ -114,23 +129,21 @@ class Runner:
non-distributed environment will be launched. non-distributed environment will be launched.
env_cfg (dict): A dict used for setting environment. Defaults to env_cfg (dict): A dict used for setting environment. Defaults to
dict(dist_cfg=dict(backend='nccl')). dict(dist_cfg=dict(backend='nccl')).
logger (MMLogger or dict, optional): A MMLogger object or a dict to log_level (int or str): The log level of MMLogger handlers.
build MMLogger object. Defaults to None. If not specified, default Defaults to 'INFO'.
config will be used.
message_hub (MessageHub or dict, optional): A Messagehub object or a
dict to build MessageHub object. Defaults to None. If not
specified, default config will be used.
writer (ComposedWriter or dict, optional): A ComposedWriter object or a writer (ComposedWriter or dict, optional): A ComposedWriter object or a
dict build ComposedWriter object. Defaults to None. If not dict build ComposedWriter object. Defaults to None. If not
specified, default config will be used. specified, default config will be used.
default_scope (str, optional): Used to reset registries location. default_scope (str, optional): Used to reset registries location.
Defaults to None. Defaults to None.
seed (int, optional): A number to set random modules. If not specified, randomness (dict): Some settings to make the experiment as reproducible
a random number will be set as seed. Defaults to None. as possible like seed and deterministic.
deterministic (bool): Whether cudnn to select deterministic algorithms. Defaults to ``dict(seed=None)``. If seed is None, a random number
Defaults to False. will be generated and it will be broadcasted to all other processes
See https://pytorch.org/docs/stable/notes/randomness.html for if in distributed environment. If ``cudnn_benchmarch`` is
more details. ``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
``randomness``, the value of ``torch.backends.cudnn.benchmark``
will be ``False`` finally.
experiment_name (str, optional): Name of current experiment. If not experiment_name (str, optional): Name of current experiment. If not
specified, timestamp will be used as ``experiment_name``. specified, timestamp will be used as ``experiment_name``.
Defaults to None. Defaults to None.
...@@ -173,13 +186,11 @@ class Runner: ...@@ -173,13 +186,11 @@ class Runner:
param_scheduler=dict(type='ParamSchedulerHook')), param_scheduler=dict(type='ParamSchedulerHook')),
launcher='none', launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')), env_cfg=dict(dist_cfg=dict(backend='nccl')),
logger=dict(log_level='INFO'),
message_hub=None,
writer=dict( writer=dict(
name='composed_writer', name='composed_writer',
writers=[dict(type='LocalWriter', save_dir='temp_dir')]) writers=[dict(type='LocalWriter', save_dir='temp_dir')])
) )
>>> runner = Runner.build_from_cfg(cfg) >>> runner = Runner.from_cfg(cfg)
>>> runner.train() >>> runner.train()
>>> runner.test() >>> runner.test()
""" """
...@@ -208,19 +219,17 @@ class Runner: ...@@ -208,19 +219,17 @@ class Runner:
resume: bool = False, resume: bool = False,
launcher: str = 'none', launcher: str = 'none',
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')), env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
logger: Optional[Union[MMLogger, Dict]] = None, log_level: str = 'INFO',
message_hub: Optional[Union[MessageHub, Dict]] = None,
writer: Optional[Union[ComposedWriter, Dict]] = None, writer: Optional[Union[ComposedWriter, Dict]] = None,
default_scope: Optional[str] = None, default_scope: Optional[str] = None,
seed: Optional[int] = None, randomness: Dict = dict(seed=None),
deterministic: bool = False,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
cfg: Optional[ConfigType] = None, cfg: Optional[ConfigType] = None,
): ):
self._work_dir = osp.abspath(work_dir) self._work_dir = osp.abspath(work_dir)
mmengine.mkdir_or_exist(self._work_dir) mmengine.mkdir_or_exist(self._work_dir)
# recursively copy the ``cfg`` because `self.cfg` will be modified # recursively copy the `cfg` because `self.cfg` will be modified
# everywhere. # everywhere.
if cfg is not None: if cfg is not None:
self.cfg = copy.deepcopy(cfg) self.cfg = copy.deepcopy(cfg)
...@@ -231,21 +240,25 @@ class Runner: ...@@ -231,21 +240,25 @@ class Runner:
self._iter = 0 self._iter = 0
# lazy initialization # lazy initialization
training_related = [ training_related = [train_dataloader, train_cfg, optimizer]
train_dataloader, train_cfg, optimizer, param_scheduler
]
if not (all(item is None for item in training_related) if not (all(item is None for item in training_related)
or all(item is not None for item in training_related)): or all(item is not None for item in training_related)):
raise ValueError( raise ValueError(
'train_dataloader, train_cfg, optimizer and param_scheduler ' 'train_dataloader, train_cfg, and optimizer should be either '
'should be either all None or not None, but got ' 'all None or not None, but got '
f'train_dataloader={train_dataloader}, ' f'train_dataloader={train_dataloader}, '
f'train_cfg={train_cfg}, ' f'train_cfg={train_cfg}, '
f'optimizer={optimizer}, ' f'optimizer={optimizer}.')
f'param_scheduler={param_scheduler}.')
self.train_dataloader = train_dataloader self.train_dataloader = train_dataloader
self.train_loop = train_cfg self.train_loop = train_cfg
self.optimizer = optimizer self.optimizer = optimizer
# If there is no need to adjust learning rate, momentum or other
# parameters of optimizer, param_scheduler can be None
if param_scheduler is not None and self.optimizer is None:
raise ValueError(
'param_scheduler should be None when optimizer is None, '
f'but got {param_scheduler}')
if not isinstance(param_scheduler, Sequence): if not isinstance(param_scheduler, Sequence):
self.param_schedulers = [param_scheduler] self.param_schedulers = [param_scheduler]
else: else:
...@@ -256,7 +269,7 @@ class Runner: ...@@ -256,7 +269,7 @@ class Runner:
for item in val_related) or all(item is not None for item in val_related) or all(item is not None
for item in val_related)): for item in val_related)):
raise ValueError( raise ValueError(
'val_dataloader, val_cfg and val_evaluator should be either ' 'val_dataloader, val_cfg, and val_evaluator should be either '
'all None or not None, but got ' 'all None or not None, but got '
f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, ' f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
f'val_evaluator={val_evaluator}') f'val_evaluator={val_evaluator}')
...@@ -268,8 +281,8 @@ class Runner: ...@@ -268,8 +281,8 @@ class Runner:
if not (all(item is None for item in test_related) if not (all(item is None for item in test_related)
or all(item is not None for item in test_related)): or all(item is not None for item in test_related)):
raise ValueError( raise ValueError(
'test_dataloader, test_cfg and test_evaluator should be either' 'test_dataloader, test_cfg, and test_evaluator should be '
' all None or not None, but got ' 'either all None or not None, but got '
f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, ' f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
f'test_evaluator={test_evaluator}') f'test_evaluator={test_evaluator}')
self.test_dataloader = test_dataloader self.test_dataloader = test_dataloader
...@@ -282,10 +295,13 @@ class Runner: ...@@ -282,10 +295,13 @@ class Runner:
else: else:
self._distributed = True self._distributed = True
# self._deterministic, self._seed and self._timestamp will be set in # self._timestamp will be set in the `setup_env` method. Besides,
# the `setup_env`` method. Besides, it also will initialize # it also will initialize multi-process and (or) distributed
# multi-process and (or) distributed environment. # environment.
self.setup_env(env_cfg, seed, deterministic) self.setup_env(env_cfg)
# self._deterministic and self._seed will be set in the
# `set_randomness`` method
self.set_randomness(**randomness)
if experiment_name is not None: if experiment_name is not None:
self._experiment_name = f'{experiment_name}_{self._timestamp}' self._experiment_name = f'{experiment_name}_{self._timestamp}'
...@@ -296,9 +312,9 @@ class Runner: ...@@ -296,9 +312,9 @@ class Runner:
else: else:
self._experiment_name = self.timestamp self._experiment_name = self.timestamp
self.logger = self.build_logger(logger) self.logger = self.build_logger(log_level=log_level)
# message hub used for component interaction # message hub used for component interaction
self.message_hub = self.build_message_hub(message_hub) self.message_hub = self.build_message_hub()
# writer used for writing log or visualizing all kinds of data # writer used for writing log or visualizing all kinds of data
self.writer = self.build_writer(writer) self.writer = self.build_writer(writer)
# Used to reset registries location. See :meth:`Registry.build` for # Used to reset registries location. See :meth:`Registry.build` for
...@@ -333,7 +349,7 @@ class Runner: ...@@ -333,7 +349,7 @@ class Runner:
self.dump_config() self.dump_config()
@classmethod @classmethod
def build_from_cfg(cls, cfg: ConfigType) -> 'Runner': def from_cfg(cls, cfg: ConfigType) -> 'Runner':
"""Build a runner from config. """Build a runner from config.
Args: Args:
...@@ -363,12 +379,11 @@ class Runner: ...@@ -363,12 +379,11 @@ class Runner:
resume=cfg.get('resume', False), resume=cfg.get('resume', False),
launcher=cfg.get('launcher', 'none'), launcher=cfg.get('launcher', 'none'),
env_cfg=cfg.get('env_cfg'), # type: ignore env_cfg=cfg.get('env_cfg'), # type: ignore
logger=cfg.get('log_cfg'), log_level=cfg.get('log_level', 'INFO'),
message_hub=cfg.get('message_hub'),
writer=cfg.get('writer'), writer=cfg.get('writer'),
default_scope=cfg.get('default_scope'), default_scope=cfg.get('default_scope'),
seed=cfg.get('seed'), randomness=cfg.get('randomness', dict(seed=None)),
deterministic=cfg.get('deterministic', False), experiment_name=cfg.get('experiment_name'),
cfg=cfg, cfg=cfg,
) )
...@@ -439,10 +454,7 @@ class Runner: ...@@ -439,10 +454,7 @@ class Runner:
"""list[:obj:`Hook`]: A list of registered hooks.""" """list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks return self._hooks
def setup_env(self, def setup_env(self, env_cfg: Dict) -> None:
env_cfg: Dict,
seed: Optional[int],
deterministic: bool = False) -> None:
"""Setup environment. """Setup environment.
An example of ``env_cfg``:: An example of ``env_cfg``::
...@@ -458,17 +470,7 @@ class Runner: ...@@ -458,17 +470,7 @@ class Runner:
Args: Args:
env_cfg (dict): Config for setting environment. env_cfg (dict): Config for setting environment.
seed (int, optional): A number to set random modules. If not
specified, a random number will be set as seed.
Defaults to None.
deterministic (bool): Whether cudnn to select deterministic
algorithms. Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
""" """
self._deterministic = deterministic
self._seed = seed
if env_cfg.get('cudnn_benchmark'): if env_cfg.get('cudnn_benchmark'):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
...@@ -490,9 +492,6 @@ class Runner: ...@@ -490,9 +492,6 @@ class Runner:
self._timestamp = time.strftime('%Y%m%d_%H%M%S', self._timestamp = time.strftime('%Y%m%d_%H%M%S',
time.localtime(timestamp.item())) time.localtime(timestamp.item()))
# set random seeds
self._set_random_seed()
def _set_multi_processing(self, def _set_multi_processing(self,
mp_start_method: str = 'fork', mp_start_method: str = 'fork',
opencv_num_threads: int = 0) -> None: opencv_num_threads: int = 0) -> None:
...@@ -546,16 +545,20 @@ class Runner: ...@@ -546,16 +545,20 @@ class Runner:
'optimal performance in your application as needed.') 'optimal performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
def _set_random_seed(self) -> None: def set_randomness(self, seed, deterministic: bool = False) -> None:
"""Set random seed to guarantee reproducible results. """Set random seed to guarantee reproducible results.
Warning: Args:
Results can not be guaranteed to resproducible if ``self.seed`` is seed (int): A number to set random modules.
None because :meth:`_set_random_seed` will generate a random seed deterministic (bool): Whether to set the deterministic option for
when launching a new experiment. CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
See https://pytorch.org/docs/stable/notes/randomness.html for details. Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
""" """
self._deterministic = deterministic
self._seed = seed
if self._seed is None: if self._seed is None:
self._seed = sync_random_seed() self._seed = sync_random_seed()
...@@ -563,69 +566,62 @@ class Runner: ...@@ -563,69 +566,62 @@ class Runner:
np.random.seed(self._seed) np.random.seed(self._seed)
torch.manual_seed(self._seed) torch.manual_seed(self._seed)
torch.cuda.manual_seed_all(self._seed) torch.cuda.manual_seed_all(self._seed)
if self._deterministic: if deterministic:
torch.backends.cudnn.deterministic = True if torch.backends.cudnn.benchmark:
warnings.warn(
'torch.backends.cudnn.benchmark is going to be set as '
'`False` to cause cuDNN to deterministically select an '
'algorithm')
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
torch.use_deterministic_algorithms(True)
def build_logger(self, def build_logger(self,
logger: Optional[Union[MMLogger, log_level: Union[int, str] = 'INFO',
Dict]] = None) -> MMLogger: log_file: str = None,
**kwargs) -> MMLogger:
"""Build a global asscessable MMLogger. """Build a global asscessable MMLogger.
Args: Args:
logger (MMLogger or dict, optional): A MMLogger object or a dict to log_level (int or str): The log level of MMLogger handlers.
build MMLogger object. If ``logger`` is a MMLogger object, just Defaults to 'INFO'.
returns itself. If not specified, default config will be used log_file (str, optional): Path of filename to save log.
to build MMLogger object. Defaults to None. Defaults to None.
**kwargs: Remaining parameters passed to ``MMLogger``.
Returns: Returns:
MMLogger: A MMLogger object build from ``logger``. MMLogger: A MMLogger object build from ``logger``.
""" """
if isinstance(logger, MMLogger): if log_file is None:
return logger log_file = osp.join(self.work_dir, f'{self._experiment_name}.log')
elif logger is None:
logger = dict(
name=self._experiment_name,
log_level='INFO',
log_file=osp.join(self.work_dir,
f'{self._experiment_name}.log'))
elif isinstance(logger, dict):
# ensure logger containing name key
logger.setdefault('name', self._experiment_name)
else:
raise TypeError(
'logger should be MMLogger object, a dict or None, '
f'but got {logger}')
return MMLogger.get_instance(**logger) log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
log_cfg.setdefault('name', self._experiment_name)
def build_message_hub( return MMLogger.get_instance(**log_cfg) # type: ignore
self,
message_hub: Optional[Union[MessageHub, def build_message_hub(self,
Dict]] = None) -> MessageHub: message_hub: Optional[Dict] = None) -> MessageHub:
"""Build a global asscessable MessageHub. """Build a global asscessable MessageHub.
Args: Args:
message_hub (MessageHub or dict, optional): A MessageHub object or message_hub (dict, optional): A dict to build MessageHub object.
a dict to build MessageHub object. If ``message_hub`` is a If not specified, default config will be used to build
MessageHub object, just returns itself. If not specified, MessageHub object. Defaults to None.
default config will be used to build MessageHub object.
Defaults to None.
Returns: Returns:
MessageHub: A MessageHub object build from ``message_hub``. MessageHub: A MessageHub object build from ``message_hub``.
""" """
if isinstance(message_hub, MessageHub): if message_hub is None:
return message_hub
elif message_hub is None:
message_hub = dict(name=self._experiment_name) message_hub = dict(name=self._experiment_name)
elif isinstance(message_hub, dict): elif isinstance(message_hub, dict):
# ensure message_hub containing name key # ensure message_hub containing name key
message_hub.setdefault('name', self._experiment_name) message_hub.setdefault('name', self._experiment_name)
else: else:
raise TypeError( raise TypeError(
'message_hub should be MessageHub object, a dict or None, ' f'message_hub should be dict or None, but got {message_hub}')
f'but got {message_hub}')
return MessageHub.get_instance(**message_hub) return MessageHub.get_instance(**message_hub)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import logging
import os.path as osp import os.path as osp
import shutil import shutil
import tempfile import tempfile
import time
from unittest import TestCase from unittest import TestCase
import torch import torch
...@@ -229,8 +227,6 @@ class TestRunner(TestCase): ...@@ -229,8 +227,6 @@ class TestRunner(TestCase):
optimizer=dict(type='OptimizerHook', grad_clip=None), optimizer=dict(type='OptimizerHook', grad_clip=None),
param_scheduler=dict(type='ParamSchedulerHook')) param_scheduler=dict(type='ParamSchedulerHook'))
time.sleep(1)
def tearDown(self): def tearDown(self):
shutil.rmtree(self.temp_dir) shutil.rmtree(self.temp_dir)
...@@ -238,89 +234,99 @@ class TestRunner(TestCase): ...@@ -238,89 +234,99 @@ class TestRunner(TestCase):
# 1. test arguments # 1. test arguments
# 1.1 train_dataloader, train_cfg, optimizer and param_scheduler # 1.1 train_dataloader, train_cfg, optimizer and param_scheduler
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init1'
cfg.pop('train_cfg') cfg.pop('train_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'): with self.assertRaisesRegex(ValueError, 'either all None or not None'):
Runner(**cfg) Runner(**cfg)
# all of training related configs are None # all of training related configs are None and param_scheduler should
# also be None
cfg.experiment_name = 'test_init2'
cfg.pop('train_dataloader') cfg.pop('train_dataloader')
cfg.pop('optimizer') cfg.pop('optimizer')
cfg.pop('param_scheduler') cfg.pop('param_scheduler')
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
# avoid different runners having same timestamp
time.sleep(1)
# all of training related configs are not None # all of training related configs are not None
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init3'
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# all of training related configs are not None and param_scheduler
# can be None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init4'
cfg.pop('param_scheduler')
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
# param_scheduler should be None when optimizer is None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init5'
cfg.pop('train_cfg')
cfg.pop('train_dataloader')
cfg.pop('optimizer')
with self.assertRaisesRegex(ValueError, 'should be None'):
runner = Runner(**cfg)
# 1.2 val_dataloader, val_evaluator, val_cfg # 1.2 val_dataloader, val_evaluator, val_cfg
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init6'
cfg.pop('val_cfg') cfg.pop('val_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'): with self.assertRaisesRegex(ValueError, 'either all None or not None'):
Runner(**cfg) Runner(**cfg)
time.sleep(1) cfg.experiment_name = 'test_init7'
cfg.pop('val_dataloader') cfg.pop('val_dataloader')
cfg.pop('val_evaluator') cfg.pop('val_evaluator')
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
time.sleep(1)
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init8'
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
# 1.3 test_dataloader, test_evaluator and test_cfg # 1.3 test_dataloader, test_evaluator and test_cfg
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init9'
cfg.pop('test_cfg') cfg.pop('test_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'): with self.assertRaisesRegex(ValueError, 'either all None or not None'):
runner = Runner(**cfg) runner = Runner(**cfg)
time.sleep(1) cfg.experiment_name = 'test_init10'
cfg.pop('test_dataloader') cfg.pop('test_dataloader')
cfg.pop('test_evaluator') cfg.pop('test_evaluator')
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
time.sleep(1)
# 1.4 test env params # 1.4 test env params
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init11'
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertFalse(runner.distributed) self.assertFalse(runner.distributed)
self.assertFalse(runner.deterministic) self.assertFalse(runner.deterministic)
time.sleep(1)
# 1.5 message_hub, logger and writer # 1.5 message_hub, logger and writer
# they are all not specified # they are all not specified
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init12'
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.logger, MMLogger)
self.assertIsInstance(runner.message_hub, MessageHub) self.assertIsInstance(runner.message_hub, MessageHub)
self.assertIsInstance(runner.writer, ComposedWriter) self.assertIsInstance(runner.writer, ComposedWriter)
time.sleep(1)
# they are all specified # they are all specified
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.logger = dict(name='test_logger') cfg.experiment_name = 'test_init13'
cfg.message_hub = dict(name='test_message_hub') cfg.log_level = 'INFO'
cfg.writer = dict(name='test_writer') cfg.writer = dict(name='test_writer')
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.logger, MMLogger)
self.assertEqual(runner.logger.instance_name, 'test_logger')
self.assertIsInstance(runner.message_hub, MessageHub) self.assertIsInstance(runner.message_hub, MessageHub)
self.assertEqual(runner.message_hub.instance_name, 'test_message_hub')
self.assertIsInstance(runner.writer, ComposedWriter) self.assertIsInstance(runner.writer, ComposedWriter)
self.assertEqual(runner.writer.instance_name, 'test_writer')
assert runner.distributed is False assert runner.distributed is False
assert runner.seed is not None assert runner.seed is not None
...@@ -358,7 +364,6 @@ class TestRunner(TestCase): ...@@ -358,7 +364,6 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.test_loop.dataloader, DataLoader) self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1) self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1)
time.sleep(1)
# 4. initialize runner with objects rather than config # 4. initialize runner with objects rather than config
model = ToyModel() model = ToyModel()
optimizer = SGD( optimizer = SGD(
...@@ -385,84 +390,66 @@ class TestRunner(TestCase): ...@@ -385,84 +390,66 @@ class TestRunner(TestCase):
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
test_evaluator=ToyEvaluator1(), test_evaluator=ToyEvaluator1(),
default_hooks=dict(param_scheduler=toy_hook), default_hooks=dict(param_scheduler=toy_hook),
custom_hooks=[toy_hook2]) custom_hooks=[toy_hook2],
experiment_name='test_init14')
runner.train() runner.train()
runner.test() runner.test()
# 5. test `dump_config` # 5. test `dump_config`
# TODO # TODO
def test_build_from_cfg(self): def test_from_cfg(self):
runner = Runner.build_from_cfg(cfg=self.epoch_based_cfg) runner = Runner.from_cfg(cfg=self.epoch_based_cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
def test_setup_env(self): def test_setup_env(self):
# TODO # TODO
pass pass
def test_logger(self): def test_build_logger(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) self.epoch_based_cfg.experiment_name = 'test_build_logger1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.logger, MMLogger)
self.assertEqual(runner.experiment_name, runner.logger.instance_name) self.assertEqual(runner.experiment_name, runner.logger.instance_name)
self.assertEqual(runner.logger.level, logging.NOTSET)
# input is a MMLogger object
self.assertEqual(
id(runner.build_logger(runner.logger)), id(runner.logger))
# input is None
runner._experiment_name = 'logger_name1'
logger = runner.build_logger(None)
self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'logger_name1')
# input is a dict # input is a dict
log_cfg = dict(name='logger_name2') logger = runner.build_logger(name='test_build_logger2')
logger = runner.build_logger(log_cfg)
self.assertIsInstance(logger, MMLogger) self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'logger_name2') self.assertEqual(logger.instance_name, 'test_build_logger2')
# input is a dict but does not contain name key # input is a dict but does not contain name key
runner._experiment_name = 'logger_name3' runner._experiment_name = 'test_build_logger3'
log_cfg = dict() logger = runner.build_logger()
logger = runner.build_logger(log_cfg)
self.assertIsInstance(logger, MMLogger) self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'logger_name3') self.assertEqual(logger.instance_name, 'test_build_logger3')
# input is not a valid type
with self.assertRaisesRegex(TypeError, 'logger should be'):
runner.build_logger('invalid-type')
def test_build_message_hub(self): def test_build_message_hub(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) self.epoch_based_cfg.experiment_name = 'test_build_message_hub1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.message_hub, MessageHub) self.assertIsInstance(runner.message_hub, MessageHub)
self.assertEqual(runner.message_hub.instance_name, self.assertEqual(runner.message_hub.instance_name,
runner.experiment_name) runner.experiment_name)
# input is a MessageHub object
self.assertEqual(
id(runner.build_message_hub(runner.message_hub)),
id(runner.message_hub))
# input is a dict # input is a dict
message_hub_cfg = dict(name='message_hub_name1') message_hub_cfg = dict(name='test_build_message_hub2')
message_hub = runner.build_message_hub(message_hub_cfg) message_hub = runner.build_message_hub(message_hub_cfg)
self.assertIsInstance(message_hub, MessageHub) self.assertIsInstance(message_hub, MessageHub)
self.assertEqual(message_hub.instance_name, 'message_hub_name1') self.assertEqual(message_hub.instance_name, 'test_build_message_hub2')
# input is a dict but does not contain name key # input is a dict but does not contain name key
runner._experiment_name = 'message_hub_name2' runner._experiment_name = 'test_build_message_hub3'
message_hub_cfg = dict() message_hub_cfg = dict()
message_hub = runner.build_message_hub(message_hub_cfg) message_hub = runner.build_message_hub(message_hub_cfg)
self.assertIsInstance(message_hub, MessageHub) self.assertIsInstance(message_hub, MessageHub)
self.assertEqual(message_hub.instance_name, 'message_hub_name2') self.assertEqual(message_hub.instance_name, 'test_build_message_hub3')
# input is not a valid type # input is not a valid type
with self.assertRaisesRegex(TypeError, 'message_hub should be'): with self.assertRaisesRegex(TypeError, 'message_hub should be'):
runner.build_message_hub('invalid-type') runner.build_message_hub('invalid-type')
def test_build_writer(self): def test_build_writer(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) self.epoch_based_cfg.experiment_name = 'test_build_writer1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.writer, ComposedWriter) self.assertIsInstance(runner.writer, ComposedWriter)
self.assertEqual(runner.experiment_name, runner.writer.instance_name) self.assertEqual(runner.experiment_name, runner.writer.instance_name)
...@@ -471,17 +458,17 @@ class TestRunner(TestCase): ...@@ -471,17 +458,17 @@ class TestRunner(TestCase):
id(runner.build_writer(runner.writer)), id(runner.writer)) id(runner.build_writer(runner.writer)), id(runner.writer))
# input is a dict # input is a dict
writer_cfg = dict(name='writer_name1') writer_cfg = dict(name='test_build_writer2')
writer = runner.build_writer(writer_cfg) writer = runner.build_writer(writer_cfg)
self.assertIsInstance(writer, ComposedWriter) self.assertIsInstance(writer, ComposedWriter)
self.assertEqual(writer.instance_name, 'writer_name1') self.assertEqual(writer.instance_name, 'test_build_writer2')
# input is a dict but does not contain name key # input is a dict but does not contain name key
runner._experiment_name = 'writer_name2' runner._experiment_name = 'test_build_writer3'
writer_cfg = dict() writer_cfg = dict()
writer = runner.build_writer(writer_cfg) writer = runner.build_writer(writer_cfg)
self.assertIsInstance(writer, ComposedWriter) self.assertIsInstance(writer, ComposedWriter)
self.assertEqual(writer.instance_name, 'writer_name2') self.assertEqual(writer.instance_name, 'test_build_writer3')
# input is not a valid type # input is not a valid type
with self.assertRaisesRegex(TypeError, 'writer should be'): with self.assertRaisesRegex(TypeError, 'writer should be'):
...@@ -501,12 +488,16 @@ class TestRunner(TestCase): ...@@ -501,12 +488,16 @@ class TestRunner(TestCase):
type='ToyScheduler', milestones=[1, 2]) type='ToyScheduler', milestones=[1, 2])
self.epoch_based_cfg.default_scope = 'toy' self.epoch_based_cfg.default_scope = 'toy'
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_default_scope'
runner = Runner.from_cfg(cfg)
runner.train() runner.train()
self.assertIsInstance(runner.param_schedulers[0], ToyScheduler) self.assertIsInstance(runner.param_schedulers[0], ToyScheduler)
def test_build_model(self): def test_build_model(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_model'
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, ToyModel) self.assertIsInstance(runner.model, ToyModel)
# input should be a nn.Module object or dict # input should be a nn.Module object or dict
...@@ -526,12 +517,15 @@ class TestRunner(TestCase): ...@@ -526,12 +517,15 @@ class TestRunner(TestCase):
# TODO: test on distributed environment # TODO: test on distributed environment
# custom model wrapper # custom model wrapper
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_wrap_model'
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper') cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
runner = Runner.build_from_cfg(cfg) runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, CustomModelWrapper) self.assertIsInstance(runner.model, CustomModelWrapper)
def test_build_optimizer(self): def test_build_optimizer(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_optimizer'
runner = Runner.from_cfg(cfg)
# input should be an Optimizer object or dict # input should be an Optimizer object or dict
with self.assertRaisesRegex(TypeError, 'optimizer should be'): with self.assertRaisesRegex(TypeError, 'optimizer should be'):
...@@ -547,7 +541,9 @@ class TestRunner(TestCase): ...@@ -547,7 +541,9 @@ class TestRunner(TestCase):
self.assertIsInstance(optimizer, SGD) self.assertIsInstance(optimizer, SGD)
def test_build_param_scheduler(self): def test_build_param_scheduler(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_param_scheduler'
runner = Runner.from_cfg(cfg)
# `build_optimizer` should be called before `build_param_scheduler` # `build_optimizer` should be called before `build_param_scheduler`
cfg = dict(type='MultiStepLR', milestones=[1, 2]) cfg = dict(type='MultiStepLR', milestones=[1, 2])
...@@ -584,7 +580,9 @@ class TestRunner(TestCase): ...@@ -584,7 +580,9 @@ class TestRunner(TestCase):
self.assertIsInstance(param_schedulers[1], StepLR) self.assertIsInstance(param_schedulers[1], StepLR)
def test_build_evaluator(self): def test_build_evaluator(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_evaluator'
runner = Runner.from_cfg(cfg)
# input is a BaseEvaluator or ComposedEvaluator object # input is a BaseEvaluator or ComposedEvaluator object
evaluator = ToyEvaluator1() evaluator = ToyEvaluator1()
...@@ -603,7 +601,9 @@ class TestRunner(TestCase): ...@@ -603,7 +601,9 @@ class TestRunner(TestCase):
runner.build_evaluator(evaluator), ComposedEvaluator) runner.build_evaluator(evaluator), ComposedEvaluator)
def test_build_dataloader(self): def test_build_dataloader(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_dataloader'
runner = Runner.from_cfg(cfg)
cfg = dict( cfg = dict(
dataset=dict(type='ToyDataset'), dataset=dict(type='ToyDataset'),
...@@ -616,8 +616,11 @@ class TestRunner(TestCase): ...@@ -616,8 +616,11 @@ class TestRunner(TestCase):
self.assertIsInstance(dataloader.sampler, DefaultSampler) self.assertIsInstance(dataloader.sampler, DefaultSampler)
def test_build_train_loop(self): def test_build_train_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_train_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict # input should be a Loop object or dict
runner = Runner.build_from_cfg(self.epoch_based_cfg)
with self.assertRaisesRegex(TypeError, 'should be'): with self.assertRaisesRegex(TypeError, 'should be'):
runner.build_train_loop('invalid-type') runner.build_train_loop('invalid-type')
...@@ -653,7 +656,9 @@ class TestRunner(TestCase): ...@@ -653,7 +656,9 @@ class TestRunner(TestCase):
self.assertIsInstance(loop, CustomTrainLoop) self.assertIsInstance(loop, CustomTrainLoop)
def test_build_val_loop(self): def test_build_val_loop(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_val_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict # input should be a Loop object or dict
with self.assertRaisesRegex(TypeError, 'should be'): with self.assertRaisesRegex(TypeError, 'should be'):
...@@ -678,7 +683,9 @@ class TestRunner(TestCase): ...@@ -678,7 +683,9 @@ class TestRunner(TestCase):
self.assertIsInstance(loop, CustomValLoop) self.assertIsInstance(loop, CustomValLoop)
def test_build_test_loop(self): def test_build_test_loop(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_test_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict # input should be a Loop object or dict
with self.assertRaisesRegex(TypeError, 'should be'): with self.assertRaisesRegex(TypeError, 'should be'):
...@@ -705,16 +712,15 @@ class TestRunner(TestCase): ...@@ -705,16 +712,15 @@ class TestRunner(TestCase):
def test_train(self): def test_train(self):
# 1. test `self.train_loop` is None # 1. test `self.train_loop` is None
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train1'
cfg.pop('train_dataloader') cfg.pop('train_dataloader')
cfg.pop('train_cfg') cfg.pop('train_cfg')
cfg.pop('optimizer') cfg.pop('optimizer')
cfg.pop('param_scheduler') cfg.pop('param_scheduler')
runner = Runner.build_from_cfg(cfg) runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'): with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.train() runner.train()
time.sleep(1)
# 2. test iter and epoch counter of EpochBasedTrainLoop # 2. test iter and epoch counter of EpochBasedTrainLoop
epoch_results = [] epoch_results = []
epoch_targets = [i for i in range(3)] epoch_targets = [i for i in range(3)]
...@@ -733,10 +739,10 @@ class TestRunner(TestCase): ...@@ -733,10 +739,10 @@ class TestRunner(TestCase):
iter_results.append(runner.iter) iter_results.append(runner.iter)
batch_idx_results.append(batch_idx) batch_idx_results.append(batch_idx)
self.epoch_based_cfg.custom_hooks = [ cfg = copy.deepcopy(self.epoch_based_cfg)
dict(type='TestEpochHook', priority=50) cfg.experiment_name = 'test_train2'
] cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
runner = Runner.build_from_cfg(self.epoch_based_cfg) runner = Runner.from_cfg(cfg)
runner.train() runner.train()
...@@ -749,8 +755,6 @@ class TestRunner(TestCase): ...@@ -749,8 +755,6 @@ class TestRunner(TestCase):
for result, target, in zip(batch_idx_results, batch_idx_targets): for result, target, in zip(batch_idx_results, batch_idx_targets):
self.assertEqual(result, target) self.assertEqual(result, target)
time.sleep(1)
# 3. test iter and epoch counter of IterBasedTrainLoop # 3. test iter and epoch counter of IterBasedTrainLoop
epoch_results = [] epoch_results = []
iter_results = [] iter_results = []
...@@ -768,11 +772,11 @@ class TestRunner(TestCase): ...@@ -768,11 +772,11 @@ class TestRunner(TestCase):
iter_results.append(runner.iter) iter_results.append(runner.iter)
batch_idx_results.append(batch_idx) batch_idx_results.append(batch_idx)
self.iter_based_cfg.custom_hooks = [ cfg = copy.deepcopy(self.iter_based_cfg)
dict(type='TestIterHook', priority=50) cfg.experiment_name = 'test_train3'
] cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
self.iter_based_cfg.val_cfg = dict(interval=4) cfg.val_cfg = dict(interval=4)
runner = Runner.build_from_cfg(self.iter_based_cfg) runner = Runner.from_cfg(cfg)
runner.train() runner.train()
assert isinstance(runner.train_loop, IterBasedTrainLoop) assert isinstance(runner.train_loop, IterBasedTrainLoop)
...@@ -786,32 +790,38 @@ class TestRunner(TestCase): ...@@ -786,32 +790,38 @@ class TestRunner(TestCase):
def test_val(self): def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1'
cfg.pop('val_dataloader') cfg.pop('val_dataloader')
cfg.pop('val_cfg') cfg.pop('val_cfg')
cfg.pop('val_evaluator') cfg.pop('val_evaluator')
runner = Runner.build_from_cfg(cfg) runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'): with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.val() runner.val()
time.sleep(1) cfg = copy.deepcopy(self.epoch_based_cfg)
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg.experiment_name = 'test_val2'
runner = Runner.from_cfg(cfg)
runner.val() runner.val()
def test_test(self): def test_test(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_test1'
cfg.pop('test_dataloader') cfg.pop('test_dataloader')
cfg.pop('test_cfg') cfg.pop('test_cfg')
cfg.pop('test_evaluator') cfg.pop('test_evaluator')
runner = Runner.build_from_cfg(cfg) runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'): with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.test() runner.test()
time.sleep(1) cfg = copy.deepcopy(self.epoch_based_cfg)
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg.experiment_name = 'test_test2'
runner = Runner.from_cfg(cfg)
runner.test() runner.test()
def test_register_hook(self): def test_register_hook(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_register_hook'
runner = Runner.from_cfg(cfg)
runner._hooks = [] runner._hooks = []
# 1. test `hook` parameter # 1. test `hook` parameter
...@@ -870,7 +880,9 @@ class TestRunner(TestCase): ...@@ -870,7 +880,9 @@ class TestRunner(TestCase):
get_priority(runner._hooks[3].priority), get_priority('VERY_LOW')) get_priority(runner._hooks[3].priority), get_priority('VERY_LOW'))
def test_default_hooks(self): def test_default_hooks(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_default_hooks'
runner = Runner.from_cfg(cfg)
runner._hooks = [] runner._hooks = []
# register five hooks by default # register five hooks by default
...@@ -893,7 +905,10 @@ class TestRunner(TestCase): ...@@ -893,7 +905,10 @@ class TestRunner(TestCase):
self.assertTrue(isinstance(runner._hooks[5], ToyHook)) self.assertTrue(isinstance(runner._hooks[5], ToyHook))
def test_custom_hooks(self): def test_custom_hooks(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_custom_hooks'
runner = Runner.from_cfg(cfg)
self.assertEqual(len(runner._hooks), 5) self.assertEqual(len(runner._hooks), 5)
custom_hooks = [dict(type='ToyHook')] custom_hooks = [dict(type='ToyHook')]
runner.register_custom_hooks(custom_hooks) runner.register_custom_hooks(custom_hooks)
...@@ -901,7 +916,10 @@ class TestRunner(TestCase): ...@@ -901,7 +916,10 @@ class TestRunner(TestCase):
self.assertTrue(isinstance(runner._hooks[5], ToyHook)) self.assertTrue(isinstance(runner._hooks[5], ToyHook))
def test_register_hooks(self): def test_register_hooks(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_register_hooks'
runner = Runner.from_cfg(cfg)
runner._hooks = [] runner._hooks = []
custom_hooks = [dict(type='ToyHook')] custom_hooks = [dict(type='ToyHook')]
runner.register_hooks(custom_hooks=custom_hooks) runner.register_hooks(custom_hooks=custom_hooks)
...@@ -975,7 +993,8 @@ class TestRunner(TestCase): ...@@ -975,7 +993,8 @@ class TestRunner(TestCase):
self.iter_based_cfg.custom_hooks = [ self.iter_based_cfg.custom_hooks = [
dict(type='TestWarmupHook', priority=50) dict(type='TestWarmupHook', priority=50)
] ]
runner = Runner.build_from_cfg(self.iter_based_cfg) self.iter_based_cfg.experiment_name = 'test_custom_loop'
runner = Runner.from_cfg(self.iter_based_cfg)
runner.train() runner.train()
self.assertIsInstance(runner.train_loop, CustomTrainLoop2) self.assertIsInstance(runner.train_loop, CustomTrainLoop2)
...@@ -990,7 +1009,9 @@ class TestRunner(TestCase): ...@@ -990,7 +1009,9 @@ class TestRunner(TestCase):
def test_checkpoint(self): def test_checkpoint(self):
# 1. test epoch based # 1. test epoch based
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint1'
runner = Runner.from_cfg(cfg)
runner.train() runner.train()
# 1.1 test `save_checkpoint` which called by `CheckpointHook` # 1.1 test `save_checkpoint` which called by `CheckpointHook`
...@@ -1006,17 +1027,19 @@ class TestRunner(TestCase): ...@@ -1006,17 +1027,19 @@ class TestRunner(TestCase):
assert isinstance(ckpt['optimizer'], dict) assert isinstance(ckpt['optimizer'], dict)
assert isinstance(ckpt['param_schedulers'], list) assert isinstance(ckpt['param_schedulers'], list)
time.sleep(1)
# 1.2 test `load_checkpoint` # 1.2 test `load_checkpoint`
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint2'
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path) runner.load_checkpoint(path)
self.assertEqual(runner.epoch, 0) self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 0) self.assertEqual(runner.iter, 0)
self.assertTrue(runner._has_loaded) self.assertTrue(runner._has_loaded)
time.sleep(1)
# 1.3 test `resume` # 1.3 test `resume`
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint3'
runner = Runner.from_cfg(cfg)
runner.resume(path) runner.resume(path)
self.assertEqual(runner.epoch, 3) self.assertEqual(runner.epoch, 3)
self.assertEqual(runner.iter, 12) self.assertEqual(runner.iter, 12)
...@@ -1025,8 +1048,9 @@ class TestRunner(TestCase): ...@@ -1025,8 +1048,9 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2. test iter based # 2. test iter based
time.sleep(1) cfg = copy.deepcopy(self.iter_based_cfg)
runner = Runner.build_from_cfg(self.iter_based_cfg) cfg.experiment_name = 'test_checkpoint4'
runner = Runner.from_cfg(cfg)
runner.train() runner.train()
# 2.1 test `save_checkpoint` which called by `CheckpointHook` # 2.1 test `save_checkpoint` which called by `CheckpointHook`
...@@ -1043,16 +1067,18 @@ class TestRunner(TestCase): ...@@ -1043,16 +1067,18 @@ class TestRunner(TestCase):
assert isinstance(ckpt['param_schedulers'], list) assert isinstance(ckpt['param_schedulers'], list)
# 2.2 test `load_checkpoint` # 2.2 test `load_checkpoint`
time.sleep(1) cfg = copy.deepcopy(self.iter_based_cfg)
runner = Runner.build_from_cfg(self.iter_based_cfg) cfg.experiment_name = 'test_checkpoint5'
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path) runner.load_checkpoint(path)
self.assertEqual(runner.epoch, 0) self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 0) self.assertEqual(runner.iter, 0)
self.assertTrue(runner._has_loaded) self.assertTrue(runner._has_loaded)
time.sleep(1)
# 2.3 test `resume` # 2.3 test `resume`
runner = Runner.build_from_cfg(self.iter_based_cfg) cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint6'
runner = Runner.from_cfg(cfg)
runner.resume(path) runner.resume(path)
self.assertEqual(runner.epoch, 0) self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 12) self.assertEqual(runner.iter, 12)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment