Unverified Commit f1de071c authored by Zaida Zhou, committed by GitHub
[Enhancement] Refactor Runner (#139)

* [Enhancement] Rename build_from_cfg to from_cfg

* refactor build_logger and build_message_hub

* remove time.sleep from unit tests

* minor fix

* move set_randomness from setup_env

* improve docstring

* refine comments

* print a warning message

* refine comments

* simplify the interface of build_logger
parent 9a61b389
@@ -32,7 +32,8 @@ from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
DefaultScope)
from mmengine.utils import find_latest_checkpoint, is_list_of, symlink
from mmengine.utils import (TORCH_VERSION, digit_version,
find_latest_checkpoint, is_list_of, symlink)
from mmengine.visualization import ComposedWriter
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@@ -47,6 +48,20 @@ ConfigType = Union[Dict, Config, ConfigDict]
class Runner:
"""A training helper for PyTorch.
Runner object can be built from config by ``runner = Runner.from_cfg(cfg)``,
where ``cfg`` usually contains training, validation, and test-related
configurations to build the corresponding components. The same config is
usually used to launch training, testing, and validation tasks. However,
only some of these components are necessary at a time, e.g., testing a
model does not need training or validation-related components.
To avoid repeatedly modifying the config, the construction of ``Runner``
adopts lazy initialization: components are only initialized when they are
going to be used. Therefore, the model is always initialized at the
beginning, while training, validation, and testing related components are
only initialized when calling ``runner.train()``, ``runner.val()``, and
``runner.test()``, respectively.
Args:
model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
a dict used to build a model.
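For illustration, a minimal sketch of the two accepted forms of ``model`` (``ToyModel`` is a hypothetical registered name, not part of this diff):

>>> import torch.nn as nn
>>> # a dict is built through the MODELS registry
>>> runner = Runner(model=dict(type='ToyModel'), work_dir='./work_dir')
>>> # an nn.Module instance is used as-is
>>> runner = Runner(model=nn.Linear(2, 2), work_dir='./work_dir')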
@@ -114,23 +129,21 @@ class Runner:
non-distributed environment will be launched.
env_cfg (dict): A dict used for setting environment. Defaults to
dict(dist_cfg=dict(backend='nccl')).
logger (MMLogger or dict, optional): An MMLogger object or a dict to
build an MMLogger object. Defaults to None. If not specified, the
default config will be used.
message_hub (MessageHub or dict, optional): A MessageHub object or a
dict to build a MessageHub object. Defaults to None. If not
specified, the default config will be used.
log_level (int or str): The log level of MMLogger handlers.
Defaults to 'INFO'.
writer (ComposedWriter or dict, optional): A ComposedWriter object or a
dict to build a ComposedWriter object. Defaults to None. If not
specified, the default config will be used.
default_scope (str, optional): Used to reset registries location.
Defaults to None.
seed (int, optional): A number to set random modules. If not specified,
a random number will be set as seed. Defaults to None.
deterministic (bool): Whether cuDNN selects deterministic algorithms.
Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
randomness (dict): Settings to make the experiment as reproducible
as possible, such as seed and deterministic.
Defaults to ``dict(seed=None)``. If seed is None, a random number
will be generated and broadcast to all other processes
in a distributed environment. If ``cudnn_benchmark`` is
``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
``randomness``, the value of ``torch.backends.cudnn.benchmark``
will finally be ``False`` (see the sketch after this block).
experiment_name (str, optional): Name of current experiment. If not
specified, timestamp will be used as ``experiment_name``.
Defaults to None.
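A hedged sketch of the ``randomness`` interplay described above (the seed value is illustrative):

>>> # fix the seed and request deterministic cuDNN algorithms; with
>>> # deterministic=True, torch.backends.cudnn.benchmark ends up False
>>> # even if cudnn_benchmark=True was set in env_cfg
>>> randomness = dict(seed=2022, deterministic=True)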
@@ -173,13 +186,11 @@ class Runner:
param_scheduler=dict(type='ParamSchedulerHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
logger=dict(log_level='INFO'),
message_hub=None,
writer=dict(
name='composed_writer',
writers=[dict(type='LocalWriter', save_dir='temp_dir')])
)
>>> runner = Runner.build_from_cfg(cfg)
>>> runner = Runner.from_cfg(cfg)
>>> runner.train()
>>> runner.test()
"""
@@ -208,19 +219,17 @@ class Runner:
resume: bool = False,
launcher: str = 'none',
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
logger: Optional[Union[MMLogger, Dict]] = None,
message_hub: Optional[Union[MessageHub, Dict]] = None,
log_level: str = 'INFO',
writer: Optional[Union[ComposedWriter, Dict]] = None,
default_scope: Optional[str] = None,
seed: Optional[int] = None,
deterministic: bool = False,
randomness: Dict = dict(seed=None),
experiment_name: Optional[str] = None,
cfg: Optional[ConfigType] = None,
):
self._work_dir = osp.abspath(work_dir)
mmengine.mkdir_or_exist(self._work_dir)
# recursively copy the ``cfg`` because `self.cfg` will be modified
# recursively copy the `cfg` because `self.cfg` will be modified
# everywhere.
if cfg is not None:
self.cfg = copy.deepcopy(cfg)
@@ -231,21 +240,25 @@ class Runner:
self._iter = 0
# lazy initialization
training_related = [
train_dataloader, train_cfg, optimizer, param_scheduler
]
training_related = [train_dataloader, train_cfg, optimizer]
if not (all(item is None for item in training_related)
or all(item is not None for item in training_related)):
raise ValueError(
'train_dataloader, train_cfg, optimizer and param_scheduler '
'should be either all None or not None, but got '
'train_dataloader, train_cfg, and optimizer should be either '
'all None or not None, but got '
f'train_dataloader={train_dataloader}, '
f'train_cfg={train_cfg}, '
f'optimizer={optimizer}, '
f'param_scheduler={param_scheduler}.')
f'optimizer={optimizer}.')
self.train_dataloader = train_dataloader
self.train_loop = train_cfg
self.optimizer = optimizer
# If there is no need to adjust the learning rate, momentum, or other
# parameters of the optimizer, param_scheduler can be None
if param_scheduler is not None and self.optimizer is None:
raise ValueError(
'param_scheduler should be None when optimizer is None, '
f'but got {param_scheduler}')
if not isinstance(param_scheduler, Sequence):
self.param_schedulers = [param_scheduler]
else:
@@ -256,7 +269,7 @@ class Runner:
for item in val_related) or all(item is not None
for item in val_related)):
raise ValueError(
'val_dataloader, val_cfg and val_evaluator should be either '
'val_dataloader, val_cfg, and val_evaluator should be either '
'all None or not None, but got '
f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
f'val_evaluator={val_evaluator}')
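The same all-or-none validation guards the train, val, and test triples; a standalone sketch of the pattern (``_check_all_or_none`` is a hypothetical helper, not part of this diff):

from typing import Any, Sequence

def _check_all_or_none(names: Sequence[str], values: Sequence[Any]) -> None:
    # raise if some, but not all, of the components were provided
    if not (all(v is None for v in values)
            or all(v is not None for v in values)):
        got = ', '.join(f'{n}={v}' for n, v in zip(names, values))
        raise ValueError(f'{", ".join(names)} should be either all None '
                         f'or all not None, but got {got}')

# e.g. _check_all_or_none(
#     ('val_dataloader', 'val_cfg', 'val_evaluator'),
#     (val_dataloader, val_cfg, val_evaluator))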
@@ -268,8 +281,8 @@ class Runner:
if not (all(item is None for item in test_related)
or all(item is not None for item in test_related)):
raise ValueError(
'test_dataloader, test_cfg and test_evaluator should be either'
' all None or not None, but got '
'test_dataloader, test_cfg, and test_evaluator should be '
'either all None or not None, but got '
f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
f'test_evaluator={test_evaluator}')
self.test_dataloader = test_dataloader
@@ -282,10 +295,13 @@ class Runner:
else:
self._distributed = True
# self._deterministic, self._seed and self._timestamp will be set in
# the ``setup_env`` method. Besides, it will also initialize the
# multi-process and (or) distributed environment.
self.setup_env(env_cfg, seed, deterministic)
# self._timestamp will be set in the ``setup_env`` method. Besides,
# it will also initialize the multi-process and (or) distributed
# environment.
self.setup_env(env_cfg)
# self._deterministic and self._seed will be set in the
# ``set_randomness`` method
self.set_randomness(**randomness)
if experiment_name is not None:
self._experiment_name = f'{experiment_name}_{self._timestamp}'
@@ -296,9 +312,9 @@ class Runner:
else:
self._experiment_name = self.timestamp
self.logger = self.build_logger(logger)
self.logger = self.build_logger(log_level=log_level)
# message hub used for component interaction
self.message_hub = self.build_message_hub(message_hub)
self.message_hub = self.build_message_hub()
# writer used for writing log or visualizing all kinds of data
self.writer = self.build_writer(writer)
# Used to reset registries location. See :meth:`Registry.build` for
@@ -333,7 +349,7 @@ class Runner:
self.dump_config()
@classmethod
def build_from_cfg(cls, cfg: ConfigType) -> 'Runner':
def from_cfg(cls, cfg: ConfigType) -> 'Runner':
"""Build a runner from config.
Args:
@@ -363,12 +379,11 @@ class Runner:
resume=cfg.get('resume', False),
launcher=cfg.get('launcher', 'none'),
env_cfg=cfg.get('env_cfg'), # type: ignore
logger=cfg.get('log_cfg'),
message_hub=cfg.get('message_hub'),
log_level=cfg.get('log_level', 'INFO'),
writer=cfg.get('writer'),
default_scope=cfg.get('default_scope'),
seed=cfg.get('seed'),
deterministic=cfg.get('deterministic', False),
randomness=cfg.get('randomness', dict(seed=None)),
experiment_name=cfg.get('experiment_name'),
cfg=cfg,
)
@@ -439,10 +454,7 @@ class Runner:
"""list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
def setup_env(self,
env_cfg: Dict,
seed: Optional[int],
deterministic: bool = False) -> None:
def setup_env(self, env_cfg: Dict) -> None:
"""Setup environment.
An example of ``env_cfg``::
@@ -458,17 +470,7 @@ class Runner:
Args:
env_cfg (dict): Config for setting environment.
seed (int, optional): A number to set random modules. If not
specified, a random number will be set as seed.
Defaults to None.
deterministic (bool): Whether cuDNN selects deterministic
algorithms. Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
"""
self._deterministic = deterministic
self._seed = seed
if env_cfg.get('cudnn_benchmark'):
torch.backends.cudnn.benchmark = True
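A hedged sketch of a full ``env_cfg``, covering the fields consumed here and in ``_set_multi_processing`` below (values are illustrative):

>>> env_cfg = dict(
...     cudnn_benchmark=True,  # sets torch.backends.cudnn.benchmark
...     mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
...     dist_cfg=dict(backend='nccl'))
>>> runner.setup_env(env_cfg)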
@@ -490,9 +492,6 @@ class Runner:
self._timestamp = time.strftime('%Y%m%d_%H%M%S',
time.localtime(timestamp.item()))
# set random seeds
self._set_random_seed()
def _set_multi_processing(self,
mp_start_method: str = 'fork',
opencv_num_threads: int = 0) -> None:
@@ -546,16 +545,20 @@ class Runner:
'optimal performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
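The thread caps above are applied defensively, only when the user has not already exported them; a minimal standalone sketch of that pattern (the helper name is hypothetical):

import os

def _cap_worker_threads(omp: int = 1, mkl: int = 1) -> None:
    # keep any value the user already set; otherwise cap the thread
    # pools so multi-worker data loading does not oversubscribe CPUs
    os.environ.setdefault('OMP_NUM_THREADS', str(omp))
    os.environ.setdefault('MKL_NUM_THREADS', str(mkl))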
def _set_random_seed(self) -> None:
def set_randomness(self, seed, deterministic: bool = False) -> None:
"""Set random seed to guarantee reproducible results.
Warning:
Results cannot be guaranteed to be reproducible if ``self.seed`` is
None, because :meth:`_set_random_seed` will generate a random seed
when launching a new experiment.
See https://pytorch.org/docs/stable/notes/randomness.html for details.
Args:
seed (int): A number to set random modules.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
"""
self._deterministic = deterministic
self._seed = seed
if self._seed is None:
self._seed = sync_random_seed()
@@ -563,69 +566,62 @@ class Runner:
np.random.seed(self._seed)
torch.manual_seed(self._seed)
torch.cuda.manual_seed_all(self._seed)
if self._deterministic:
torch.backends.cudnn.deterministic = True
if deterministic:
if torch.backends.cudnn.benchmark:
warnings.warn(
'torch.backends.cudnn.benchmark is going to be set to '
'`False` so that cuDNN deterministically selects an '
'algorithm')
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
torch.use_deterministic_algorithms(True)
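Usage is a single call, typically made right after ``setup_env``; a hedged sketch (the seed value is illustrative):

>>> runner.set_randomness(seed=2022, deterministic=True)
>>> # with deterministic=True, cudnn.benchmark is forced to False and,
>>> # on PyTorch >= 1.10, torch.use_deterministic_algorithms(True) is set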
def build_logger(self,
logger: Optional[Union[MMLogger,
Dict]] = None) -> MMLogger:
log_level: Union[int, str] = 'INFO',
log_file: Optional[str] = None,
**kwargs) -> MMLogger:
"""Build a global asscessable MMLogger.
Args:
logger (MMLogger or dict, optional): An MMLogger object or a dict to
build an MMLogger object. If ``logger`` is an MMLogger object, it
is returned directly. If not specified, the default config will be
used to build an MMLogger object. Defaults to None.
log_level (int or str): The log level of MMLogger handlers.
Defaults to 'INFO'.
log_file (str, optional): Path of the log file.
Defaults to None.
**kwargs: Remaining parameters passed to ``MMLogger``.
Returns:
MMLogger: An MMLogger object built from the given arguments.
"""
if isinstance(logger, MMLogger):
return logger
elif logger is None:
logger = dict(
name=self._experiment_name,
log_level='INFO',
log_file=osp.join(self.work_dir,
f'{self._experiment_name}.log'))
elif isinstance(logger, dict):
# ensure ``logger`` contains the 'name' key
logger.setdefault('name', self._experiment_name)
else:
raise TypeError(
'logger should be an MMLogger object, a dict or None, '
f'but got {logger}')
if log_file is None:
log_file = osp.join(self.work_dir, f'{self._experiment_name}.log')
return MMLogger.get_instance(**logger)
log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
log_cfg.setdefault('name', self._experiment_name)
def build_message_hub(
self,
message_hub: Optional[Union[MessageHub,
Dict]] = None) -> MessageHub:
return MMLogger.get_instance(**log_cfg) # type: ignore
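With the simplified interface, callers only pass logging options; a hedged usage sketch (the custom path is illustrative):

>>> logger = runner.build_logger()  # INFO level, log file under work_dir
>>> logger = runner.build_logger(log_level='DEBUG',
...                              log_file='path/of/custom.log')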
def build_message_hub(self,
message_hub: Optional[Dict] = None) -> MessageHub:
"""Build a global asscessable MessageHub.
Args:
message_hub (MessageHub or dict, optional): A MessageHub object or
a dict to build a MessageHub object. If ``message_hub`` is a
MessageHub object, it is returned directly. If not specified, the
default config will be used to build a MessageHub object.
Defaults to None.
message_hub (dict, optional): A dict to build a MessageHub object.
If not specified, the default config will be used to build the
MessageHub object. Defaults to None.
Returns:
MessageHub: A MessageHub object built from ``message_hub``.
"""
if isinstance(message_hub, MessageHub):
return message_hub
elif message_hub is None:
if message_hub is None:
message_hub = dict(name=self._experiment_name)
elif isinstance(message_hub, dict):
# ensure ``message_hub`` contains the 'name' key
message_hub.setdefault('name', self._experiment_name)
else:
raise TypeError(
'message_hub should be a MessageHub object, a dict or None, '
f'but got {message_hub}')
f'message_hub should be dict or None, but got {message_hub}')
return MessageHub.get_instance(**message_hub)
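A hedged usage sketch of the narrowed interface (the custom name is illustrative):

>>> message_hub = runner.build_message_hub()  # named after the experiment
>>> message_hub = runner.build_message_hub(dict(name='custom_hub'))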