Unverified commit f1de071c authored by Zaida Zhou, committed by GitHub

[Enhancement] Refactor Runner (#139)

* [Enhancement] Rename build_from_cfg to from_cfg

* refactor build_logger and build_message_hub

* remove time.sleep from unit tests

* minor fix

* move set_randomness out of setup_env

* improve docstring

* refine comments

* print a warning message

* refine comments

* simplify the interface of build_logger
parent 9a61b389
```diff
@@ -32,7 +32,8 @@ from mmengine.optim import _ParamScheduler, build_optimizer
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
                                MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
                                DefaultScope)
-from mmengine.utils import find_latest_checkpoint, is_list_of, symlink
+from mmengine.utils import (TORCH_VERSION, digit_version,
+                            find_latest_checkpoint, is_list_of, symlink)
 from mmengine.visualization import ComposedWriter
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@@ -47,6 +48,20 @@ ConfigType = Union[Dict, Config, ConfigDict]
 class Runner:
     """A training helper for PyTorch.
 
+    Runner object can be built from config by ``runner = Runner.from_cfg(cfg)``
+    where the ``cfg`` usually contains training, validation, and test-related
+    configurations to build corresponding components. We usually use the
+    same config to launch training, testing, and validation tasks. However,
+    only some of these components are necessary at the same time, e.g.,
+    testing a model does not need training or validation-related components.
+
+    To avoid repeatedly modifying config, the construction of ``Runner`` adopts
+    lazy initialization to only initialize components when they are going to be
+    used. Therefore, the model is always initialized at the beginning, and
+    training, validation, and testing-related components are only initialized
+    when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``,
+    respectively.
+
     Args:
         model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
             a dict used for build a model.
```
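To make the lazy-initialization contract above concrete, here is a minimal sketch of a test-only runner. `ToyModel`, `ToyDataset`, and `ToyMetric` are hypothetical components that would have to be registered in the corresponding registries first, and the exact dataloader fields may differ in this version of mmengine:

```python
from mmengine.runner import Runner  # import path assumed

# Only test-related arguments are supplied; the all-or-none checks in
# __init__ accept this, and no training/validation component is ever built.
runner = Runner(
    model=dict(type='ToyModel'),          # hypothetical registered model
    work_dir='./work_dir',
    test_dataloader=dict(
        dataset=dict(type='ToyDataset'),  # hypothetical registered dataset
        sampler=dict(type='DefaultSampler', shuffle=False),
        batch_size=1,
        num_workers=0),
    test_cfg=dict(),
    test_evaluator=dict(type='ToyMetric'),  # hypothetical registered metric
)
runner.test()  # test components are built lazily, here
```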
```diff
@@ -114,23 +129,21 @@ class Runner:
             non-distributed environment will be launched.
         env_cfg (dict): A dict used for setting environment. Defaults to
             dict(dist_cfg=dict(backend='nccl')).
-        logger (MMLogger or dict, optional): A MMLogger object or a dict to
-            build MMLogger object. Defaults to None. If not specified, default
-            config will be used.
-        message_hub (MessageHub or dict, optional): A Messagehub object or a
-            dict to build MessageHub object. Defaults to None. If not
-            specified, default config will be used.
+        log_level (int or str): The log level of MMLogger handlers.
+            Defaults to 'INFO'.
         writer (ComposedWriter or dict, optional): A ComposedWriter object or a
             dict build ComposedWriter object. Defaults to None. If not
             specified, default config will be used.
         default_scope (str, optional): Used to reset registries location.
             Defaults to None.
-        seed (int, optional): A number to set random modules. If not specified,
-            a random number will be set as seed. Defaults to None.
-        deterministic (bool): Whether cudnn to select deterministic algorithms.
-            Defaults to False.
-            See https://pytorch.org/docs/stable/notes/randomness.html for
-            more details.
+        randomness (dict): Some settings to make the experiment as reproducible
+            as possible like seed and deterministic.
+            Defaults to ``dict(seed=None)``. If seed is None, a random number
+            will be generated and it will be broadcasted to all other processes
+            if in distributed environment. If ``cudnn_benchmark`` is
+            ``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
+            ``randomness``, the value of ``torch.backends.cudnn.benchmark``
+            will be ``False`` finally.
         experiment_name (str, optional): Name of current experiment. If not
             specified, timestamp will be used as ``experiment_name``.
             Defaults to None.
```
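The precedence rule described above (``deterministic`` beating ``cudnn_benchmark``) can be exercised directly. A sketch reusing the hypothetical `ToyModel` from the earlier example:

```python
runner = Runner(
    model=dict(type='ToyModel'),   # hypothetical registered model
    work_dir='./work_dir',
    env_cfg=dict(
        cudnn_benchmark=True,      # setup_env enables benchmark mode...
        dist_cfg=dict(backend='nccl')),
    randomness=dict(
        seed=2022,                 # fixed seed; None would draw one and
                                   # broadcast it across processes
        deterministic=True))       # ...but set_randomness warns and forces
                                   # torch.backends.cudnn.benchmark to False
```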
```diff
@@ -173,13 +186,11 @@ class Runner:
             param_scheduler=dict(type='ParamSchedulerHook')),
         launcher='none',
         env_cfg=dict(dist_cfg=dict(backend='nccl')),
-        logger=dict(log_level='INFO'),
-        message_hub=None,
         writer=dict(
             name='composed_writer',
             writers=[dict(type='LocalWriter', save_dir='temp_dir')])
         )
-        >>> runner = Runner.build_from_cfg(cfg)
+        >>> runner = Runner.from_cfg(cfg)
         >>> runner.train()
         >>> runner.test()
     """
@@ -208,19 +219,17 @@ class Runner:
         resume: bool = False,
         launcher: str = 'none',
         env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
-        logger: Optional[Union[MMLogger, Dict]] = None,
-        message_hub: Optional[Union[MessageHub, Dict]] = None,
+        log_level: str = 'INFO',
         writer: Optional[Union[ComposedWriter, Dict]] = None,
         default_scope: Optional[str] = None,
-        seed: Optional[int] = None,
-        deterministic: bool = False,
+        randomness: Dict = dict(seed=None),
         experiment_name: Optional[str] = None,
         cfg: Optional[ConfigType] = None,
     ):
         self._work_dir = osp.abspath(work_dir)
         mmengine.mkdir_or_exist(self._work_dir)
 
-        # recursively copy the ``cfg`` because `self.cfg` will be modified
+        # recursively copy the `cfg` because `self.cfg` will be modified
         # everywhere.
         if cfg is not None:
             self.cfg = copy.deepcopy(cfg)
@@ -231,21 +240,25 @@ class Runner:
         self._iter = 0
 
         # lazy initialization
-        training_related = [
-            train_dataloader, train_cfg, optimizer, param_scheduler
-        ]
+        training_related = [train_dataloader, train_cfg, optimizer]
         if not (all(item is None for item in training_related)
                 or all(item is not None for item in training_related)):
             raise ValueError(
-                'train_dataloader, train_cfg, optimizer and param_scheduler '
-                'should be either all None or not None, but got '
+                'train_dataloader, train_cfg, and optimizer should be either '
+                'all None or not None, but got '
                 f'train_dataloader={train_dataloader}, '
                 f'train_cfg={train_cfg}, '
-                f'optimizer={optimizer}, '
-                f'param_scheduler={param_scheduler}.')
+                f'optimizer={optimizer}.')
         self.train_dataloader = train_dataloader
         self.train_loop = train_cfg
         self.optimizer = optimizer
+
+        # If there is no need to adjust learning rate, momentum or other
+        # parameters of optimizer, param_scheduler can be None
+        if param_scheduler is not None and self.optimizer is None:
+            raise ValueError(
+                'param_scheduler should be None when optimizer is None, '
+                f'but got {param_scheduler}')
         if not isinstance(param_scheduler, Sequence):
             self.param_schedulers = [param_scheduler]
         else:
```
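The guard applied to the training triple here (and to the val/test triples just below) is a small reusable idiom; a standalone illustration:

```python
from typing import Sequence

def all_or_none(items: Sequence) -> bool:
    """True if every item is None or every item is not None."""
    return (all(x is None for x in items)
            or all(x is not None for x in items))

assert all_or_none([None, None, None])          # nothing configured: OK
assert all_or_none(['loader', 'cfg', 'optim'])  # fully configured: OK
assert not all_or_none(['loader', None, None])  # partial config: rejected
```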
```diff
@@ -256,7 +269,7 @@ class Runner:
                     for item in val_related) or all(item is not None
                                                     for item in val_related)):
             raise ValueError(
-                'val_dataloader, val_cfg and val_evaluator should be either '
+                'val_dataloader, val_cfg, and val_evaluator should be either '
                 'all None or not None, but got '
                 f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
                 f'val_evaluator={val_evaluator}')
@@ -268,8 +281,8 @@ class Runner:
         if not (all(item is None for item in test_related)
                 or all(item is not None for item in test_related)):
             raise ValueError(
-                'test_dataloader, test_cfg and test_evaluator should be either'
-                ' all None or not None, but got '
+                'test_dataloader, test_cfg, and test_evaluator should be '
+                'either all None or not None, but got '
                 f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
                 f'test_evaluator={test_evaluator}')
         self.test_dataloader = test_dataloader
@@ -282,10 +295,13 @@ class Runner:
         else:
             self._distributed = True
 
-        # self._deterministic, self._seed and self._timestamp will be set in
-        # the `setup_env` method. Besides, it also will initialize
-        # multi-process and (or) distributed environment.
-        self.setup_env(env_cfg, seed, deterministic)
+        # self._timestamp will be set in the `setup_env` method. Besides,
+        # it also will initialize multi-process and (or) distributed
+        # environment.
+        self.setup_env(env_cfg)
+        # self._deterministic and self._seed will be set in the
+        # `set_randomness` method
+        self.set_randomness(**randomness)
 
         if experiment_name is not None:
             self._experiment_name = f'{experiment_name}_{self._timestamp}'
@@ -296,9 +312,9 @@ class Runner:
         else:
             self._experiment_name = self.timestamp
 
-        self.logger = self.build_logger(logger)
+        self.logger = self.build_logger(log_level=log_level)
         # message hub used for component interaction
-        self.message_hub = self.build_message_hub(message_hub)
+        self.message_hub = self.build_message_hub()
         # writer used for writing log or visualizing all kinds of data
         self.writer = self.build_writer(writer)
         # Used to reset registries location. See :meth:`Registry.build` for
@@ -333,7 +349,7 @@ class Runner:
         self.dump_config()
 
     @classmethod
-    def build_from_cfg(cls, cfg: ConfigType) -> 'Runner':
+    def from_cfg(cls, cfg: ConfigType) -> 'Runner':
         """Build a runner from config.
 
         Args:
@@ -363,12 +379,11 @@ class Runner:
             resume=cfg.get('resume', False),
             launcher=cfg.get('launcher', 'none'),
             env_cfg=cfg.get('env_cfg'),  # type: ignore
-            logger=cfg.get('log_cfg'),
-            message_hub=cfg.get('message_hub'),
+            log_level=cfg.get('log_level', 'INFO'),
             writer=cfg.get('writer'),
             default_scope=cfg.get('default_scope'),
-            seed=cfg.get('seed'),
-            deterministic=cfg.get('deterministic', False),
+            randomness=cfg.get('randomness', dict(seed=None)),
+            experiment_name=cfg.get('experiment_name'),
             cfg=cfg,
         )
```
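With the rename, the same entry point is driven entirely by config; note the new `log_level` and `randomness` keys replacing `log_cfg`, `message_hub`, `seed`, and `deterministic`. A sketch (the toy component name is an assumption, as before):

```python
cfg = dict(
    model=dict(type='ToyModel'),  # hypothetical registered model
    work_dir='./work_dir',
    log_level='INFO',                              # was: log_cfg=dict(...)
    randomness=dict(seed=0, deterministic=False),  # was: seed=..., deterministic=...
)
runner = Runner.from_cfg(cfg)  # formerly Runner.build_from_cfg(cfg)
```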
```diff
@@ -439,10 +454,7 @@ class Runner:
         """list[:obj:`Hook`]: A list of registered hooks."""
         return self._hooks
 
-    def setup_env(self,
-                  env_cfg: Dict,
-                  seed: Optional[int],
-                  deterministic: bool = False) -> None:
+    def setup_env(self, env_cfg: Dict) -> None:
         """Setup environment.
 
         An example of ``env_cfg``::
@@ -458,17 +470,7 @@ class Runner:
         Args:
             env_cfg (dict): Config for setting environment.
-            seed (int, optional): A number to set random modules. If not
-                specified, a random number will be set as seed.
-                Defaults to None.
-            deterministic (bool): Whether cudnn to select deterministic
-                algorithms. Defaults to False.
-                See https://pytorch.org/docs/stable/notes/randomness.html for
-                more details.
         """
-        self._deterministic = deterministic
-        self._seed = seed
-
         if env_cfg.get('cudnn_benchmark'):
             torch.backends.cudnn.benchmark = True
@@ -490,9 +492,6 @@ class Runner:
         self._timestamp = time.strftime('%Y%m%d_%H%M%S',
                                         time.localtime(timestamp.item()))
 
-        # set random seeds
-        self._set_random_seed()
-
     def _set_multi_processing(self,
                               mp_start_method: str = 'fork',
                               opencv_num_threads: int = 0) -> None:
@@ -546,16 +545,20 @@ class Runner:
                 'optimal performance in your application as needed.')
             os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
 
-    def _set_random_seed(self) -> None:
+    def set_randomness(self, seed, deterministic: bool = False) -> None:
         """Set random seed to guarantee reproducible results.
 
-        Warning:
-            Results can not be guaranteed to resproducible if ``self.seed`` is
-            None because :meth:`_set_random_seed` will generate a random seed
-            when launching a new experiment.
-
-        See https://pytorch.org/docs/stable/notes/randomness.html for details.
+        Args:
+            seed (int): A number to set random modules.
+            deterministic (bool): Whether to set the deterministic option for
+                CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+                to True and `torch.backends.cudnn.benchmark` to False.
+                Defaults to False.
+                See https://pytorch.org/docs/stable/notes/randomness.html for
+                more details.
         """
+        self._deterministic = deterministic
+        self._seed = seed
         if self._seed is None:
             self._seed = sync_random_seed()
@@ -563,69 +566,62 @@ class Runner:
         np.random.seed(self._seed)
         torch.manual_seed(self._seed)
         torch.cuda.manual_seed_all(self._seed)
-        if self._deterministic:
-            torch.backends.cudnn.deterministic = True
+        if deterministic:
+            if torch.backends.cudnn.benchmark:
+                warnings.warn(
+                    'torch.backends.cudnn.benchmark is going to be set as '
+                    '`False` to cause cuDNN to deterministically select an '
+                    'algorithm')
             torch.backends.cudnn.benchmark = False
+            torch.backends.cudnn.deterministic = True
+            if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
+                torch.use_deterministic_algorithms(True)
```
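Outside the Runner, the same recipe can be applied directly. A minimal standalone sketch of what `set_randomness` does, minus the distributed broadcast performed by `sync_random_seed` (the version gate mirrors the `digit_version(TORCH_VERSION)` check above):

```python
import random
import warnings

import numpy as np
import torch

def set_randomness(seed: int, deterministic: bool = False) -> None:
    # Seed every RNG a training loop typically touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        if torch.backends.cudnn.benchmark:
            warnings.warn('disabling cudnn.benchmark for determinism')
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Stable since PyTorch 1.10; may raise later if an op has no
        # deterministic implementation.
        torch.use_deterministic_algorithms(True)

set_randomness(2022, deterministic=True)
```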
```diff
     def build_logger(self,
-                     logger: Optional[Union[MMLogger,
-                                            Dict]] = None) -> MMLogger:
+                     log_level: Union[int, str] = 'INFO',
+                     log_file: str = None,
+                     **kwargs) -> MMLogger:
         """Build a global asscessable MMLogger.
 
         Args:
-            logger (MMLogger or dict, optional): A MMLogger object or a dict to
-                build MMLogger object. If ``logger`` is a MMLogger object, just
-                returns itself. If not specified, default config will be used
-                to build MMLogger object. Defaults to None.
+            log_level (int or str): The log level of MMLogger handlers.
+                Defaults to 'INFO'.
+            log_file (str, optional): Path of filename to save log.
+                Defaults to None.
+            **kwargs: Remaining parameters passed to ``MMLogger``.
 
         Returns:
             MMLogger: A MMLogger object build from ``logger``.
         """
-        if isinstance(logger, MMLogger):
-            return logger
-        elif logger is None:
-            logger = dict(
-                name=self._experiment_name,
-                log_level='INFO',
-                log_file=osp.join(self.work_dir,
-                                  f'{self._experiment_name}.log'))
-        elif isinstance(logger, dict):
-            # ensure logger containing name key
-            logger.setdefault('name', self._experiment_name)
-        else:
-            raise TypeError(
-                'logger should be MMLogger object, a dict or None, '
-                f'but got {logger}')
+        if log_file is None:
+            log_file = osp.join(self.work_dir, f'{self._experiment_name}.log')
 
-        return MMLogger.get_instance(**logger)
+        log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
+        log_cfg.setdefault('name', self._experiment_name)
+
+        return MMLogger.get_instance(**log_cfg)  # type: ignore
```
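After the simplification, callers pass plain logging options rather than a pre-built MMLogger or config dict; a sketch, reusing `runner` from the earlier examples:

```python
# Defaults: level 'INFO', log file '<work_dir>/<experiment_name>.log'.
logger = runner.build_logger()

# Override level and destination; extra kwargs are forwarded to MMLogger,
# and 'name' falls back to the experiment name unless given explicitly.
debug_logger = runner.build_logger(log_level='DEBUG', log_file='/tmp/debug.log')
```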
```diff
-    def build_message_hub(
-            self,
-            message_hub: Optional[Union[MessageHub,
-                                        Dict]] = None) -> MessageHub:
+    def build_message_hub(self,
+                          message_hub: Optional[Dict] = None) -> MessageHub:
         """Build a global asscessable MessageHub.
 
         Args:
-            message_hub (MessageHub or dict, optional): A MessageHub object or
-                a dict to build MessageHub object. If ``message_hub`` is a
-                MessageHub object, just returns itself. If not specified,
-                default config will be used to build MessageHub object.
-                Defaults to None.
+            message_hub (dict, optional): A dict to build MessageHub object.
+                If not specified, default config will be used to build
+                MessageHub object. Defaults to None.
 
         Returns:
             MessageHub: A MessageHub object build from ``message_hub``.
         """
-        if isinstance(message_hub, MessageHub):
-            return message_hub
-        elif message_hub is None:
+        if message_hub is None:
             message_hub = dict(name=self._experiment_name)
         elif isinstance(message_hub, dict):
             # ensure message_hub containing name key
             message_hub.setdefault('name', self._experiment_name)
         else:
             raise TypeError(
-                'message_hub should be MessageHub object, a dict or None, '
-                f'but got {message_hub}')
+                f'message_hub should be dict or None, but got {message_hub}')
 
         return MessageHub.get_instance(**message_hub)
```
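`build_message_hub` now accepts only a dict or None; a sketch of the three branches:

```python
# None: a MessageHub named after the experiment.
hub = runner.build_message_hub()

# dict: forwarded to MessageHub.get_instance(); 'name' is filled in if missing.
hub = runner.build_message_hub(dict(name='my_hub'))

# Anything else is rejected with the new, shorter error message.
try:
    runner.build_message_hub('not-a-dict')
except TypeError as err:
    print(err)  # message_hub should be dict or None, but got not-a-dict
```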
(The remainder of the diff is collapsed.)