# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
import platform
import time
import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.dataset import worker_init_fn as default_worker_init_fn
from mmengine.device import get_device
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                           is_distributed, master_only)
from mmengine.evaluator import Evaluator
from mmengine.fileio import FileClient, join_path
from mmengine.hooks import Hook
from mmengine.logging import MessageHub, MMLogger, print_log
from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm,
                            is_model_wrapper, revert_sync_batchnorm)
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
                            build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
                               HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
                               MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
                               RUNNERS, VISUALIZERS, DefaultScope)
from mmengine.utils import digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
                                     set_multi_processing)
from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
                         find_latest_checkpoint, get_state_dict,
                         save_checkpoint, weights_to_cpu)
from .log_processor import LogProcessor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
from .utils import set_random_seed
# A config may be given as a plain dict or as a Config / ConfigDict object.
ConfigType = Union[Dict, Config, ConfigDict]
# Either a single list of schedulers, or a dict mapping a string key to a
# list of schedulers.
ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
List[_ParamScheduler]]]
# A single optimizer wrapper, or an OptimWrapperDict holding several of them.
OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]
class Runner:
"""A training helper for PyTorch.
Runner object can be built from config by ``runner = Runner.from_cfg(cfg)``
where the ``cfg`` usually contains training, validation, and test-related
configurations to build corresponding components. We usually use the
same config to launch training, testing, and validation tasks. However,
only some of these components are necessary at the same time, e.g.,
testing a model does not need training or validation-related components.
To avoid repeatedly modifying config, the construction of ``Runner`` adopts
lazy initialization to only initialize components when they are going to be
used. Therefore, the model is always initialized at the beginning, and
training, validation, and, testing related components are only initialized
when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``,
respectively.
Args:
model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
a dict used for build a model.
work_dir (str): The working directory to save checkpoints. The logs
will be saved in the subdirectory of `work_dir` named
:attr:`timestamp`.
train_dataloader (Dataloader or dict, optional): A dataloader object or
a dict to build a dataloader. If ``None`` is given, it means
skipping training steps. Defaults to None.
See :meth:`build_dataloader` for more details.
val_dataloader (Dataloader or dict, optional): A dataloader object or
a dict to build a dataloader. If ``None`` is given, it means
skipping validation steps. Defaults to None.
See :meth:`build_dataloader` for more details.
test_dataloader (Dataloader or dict, optional): A dataloader object or
a dict to build a dataloader. If ``None`` is given, it means
skipping test steps. Defaults to None.
See :meth:`build_dataloader` for more details.
train_cfg (dict, optional): A dict to build a training loop. If it does
not provide "type" key, it should contain "by_epoch" to decide
which type of training loop :class:`EpochBasedTrainLoop` or
:class:`IterBasedTrainLoop` should be used. If ``train_cfg``
specified, :attr:`train_dataloader` should also be specified.
Defaults to None. See :meth:`build_train_loop` for more details.
val_cfg (dict, optional): A dict to build a validation loop. If it does
not provide "type" key, :class:`ValLoop` will be used by default.
If ``val_cfg`` specified, :attr:`val_dataloader` should also be
specified. If ``ValLoop`` is built with ``fp16=True``,
``runner.val()`` will be performed under fp16 precision.
Defaults to None. See :meth:`build_val_loop` for more details.
test_cfg (dict, optional): A dict to build a test loop. If it does
not provide "type" key, :class:`TestLoop` will be used by default.
If ``test_cfg`` specified, :attr:`test_dataloader` should also be
specified. If ``TestLoop`` is built with ``fp16=True``,
``runner.test()`` will be performed under fp16 precision.
Defaults to None. See :meth:`build_test_loop` for more details.
auto_scale_lr (dict, Optional): Config to scale the learning rate
automatically. It includes ``base_batch_size`` and ``enable``.
``base_batch_size`` is the batch size that the optimizer lr is
based on. ``enable`` is the switch to turn on and off the feature.
optim_wrapper (OptimWrapper or dict, optional):
Computing gradient of model parameters. If specified,
:attr:`train_dataloader` should also be specified. If automatic
mixed precision or gradient accumulation
training is required, the type of ``optim_wrapper`` should be
AmpOptimizerWrapper. See :meth:`build_optim_wrapper` for
examples. Defaults to None.
param_scheduler (_ParamScheduler or dict or list, optional):
Parameter scheduler for updating optimizer parameters. If
specified, :attr:`optimizer` should also be specified.
Defaults to None.
See :meth:`build_param_scheduler` for examples.
val_evaluator (Evaluator or dict or list, optional): A evaluator object
used for computing metrics for validation. It can be a dict or a
list of dict to build a evaluator. If specified,
:attr:`val_dataloader` should also be specified. Defaults to None.
test_evaluator (Evaluator or dict or list, optional): A evaluator
object used for computing metrics for test steps. It can be a dict
or a list of dict to build a evaluator. If specified,
:attr:`test_dataloader` should also be specified. Defaults to None.
default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks to
execute default actions like updating model parameters and saving
checkpoints. Default hooks are ``OptimizerHook``,
``IterTimerHook``, ``LoggerHook``, ``ParamSchedulerHook`` and
``CheckpointHook``. Defaults to None.
See :meth:`register_default_hooks` for more details.
custom_hooks (list[dict] or list[Hook], optional): Hooks to execute
custom actions like visualizing images processed by pipeline.
Defaults to None.
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`. If the ``model`` argument is a dict
and doesn't contain the key ``data_preprocessor``, set the argument
as the ``data_preprocessor`` of the ``model`` dict.
Defaults to None.
load_from (str, optional): The checkpoint file to load from.
Defaults to None.
resume (bool): Whether to resume training. Defaults to False. If
``resume`` is True and ``load_from`` is None, automatically to
find latest checkpoint from ``work_dir``. If not found, resuming
does nothing.
launcher (str): Way to launcher multi-process. Supported launchers
are 'pytorch', 'mpi', 'slurm' and 'none'. If 'none' is provided,
non-distributed environment will be launched.
env_cfg (dict): A dict used for setting environment. Defaults to
dict(dist_cfg=dict(backend='nccl')).
log_processor (dict, optional): A processor to format logs. Defaults to
None.
log_level (int or str): The log level of MMLogger handlers.
Defaults to 'INFO'.
visualizer (Visualizer or dict, optional): A Visualizer object or a
dict build Visualizer object. Defaults to None. If not
specified, default config will be used.
default_scope (str): Used to reset registries location.
Defaults to "mmengine".
randomness (dict): Some settings to make the experiment as reproducible
as possible like seed and deterministic.
Defaults to ``dict(seed=None)``. If seed is None, a random number
will be generated and it will be broadcasted to all other processes
if in distributed environment. If ``cudnn_benchmark`` is
``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
``randomness``, the value of ``torch.backends.cudnn.benchmark``
will be ``False`` finally.
experiment_name (str, optional): Name of current experiment. If not
specified, timestamp will be used as ``experiment_name``.
Defaults to None.
cfg (dict or Configdict or :obj:`Config`, optional): Full config.
Defaults to None.
Note:
Since PyTorch 2.0.0, you can enable ``torch.compile`` by passing in
`cfg.compile = True`. If you want to control compile options, you
can pass a dict, e.g. ``cfg.compile = dict(backend='eager')``.
Refer to `PyTorch API Documentation <https://pytorch.org/docs/
master/generated/torch.compile.html#torch.compile>`_ for more valid
options.
Examples:
>>> from mmengine.runner import Runner
>>> cfg = dict(
>>>     model=dict(type='ToyModel'),
>>>     work_dir='path/of/work_dir',
>>>     train_dataloader=dict(
>>>         dataset=dict(type='ToyDataset'),
>>>         sampler=dict(type='DefaultSampler', shuffle=True),
>>>         batch_size=1,
>>>         num_workers=0),
>>>     val_dataloader=dict(
>>>         dataset=dict(type='ToyDataset'),
>>>         sampler=dict(type='DefaultSampler', shuffle=False),
>>>         batch_size=1,
>>>         num_workers=0),
>>>     test_dataloader=dict(
>>>         dataset=dict(type='ToyDataset'),
>>>         sampler=dict(type='DefaultSampler', shuffle=False),
>>>         batch_size=1,
>>>         num_workers=0),
>>>     auto_scale_lr=dict(base_batch_size=16, enable=False),
>>>     optim_wrapper=dict(type='OptimWrapper', optimizer=dict(
>>>         type='SGD', lr=0.01)),
>>>     param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
>>>     val_evaluator=dict(type='ToyEvaluator'),
>>>     test_evaluator=dict(type='ToyEvaluator'),
>>>     train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1),
>>>     val_cfg=dict(),
>>>     test_cfg=dict(),
>>>     custom_hooks=[],
>>>     default_hooks=dict(
>>>         timer=dict(type='IterTimerHook'),
>>>         checkpoint=dict(type='CheckpointHook', interval=1),
>>>         logger=dict(type='LoggerHook'),
>>>         optimizer=dict(type='OptimizerHook', grad_clip=False),
>>>         param_scheduler=dict(type='ParamSchedulerHook')),
>>>     launcher='none',
>>>     env_cfg=dict(dist_cfg=dict(backend='nccl')),
>>>     log_processor=dict(window_size=20),
>>>     visualizer=dict(type='Visualizer',
>>>         vis_backends=[dict(type='LocalVisBackend',
>>>                            save_dir='temp_dir')])
>>> )
>>> runner = Runner.from_cfg(cfg)
>>> runner.train()
>>> runner.test()
"""
# Each loop slot holds a built loop, the config dict it will be built from
# lazily on first property access, or None when that stage is not configured.
_train_loop: Optional[Union[BaseLoop, Dict]]
_val_loop: Optional[Union[BaseLoop, Dict]]
_test_loop: Optional[Union[BaseLoop, Dict]]
def __init__(
    self,
    model: Union[nn.Module, Dict],
    work_dir: str,
    train_dataloader: Optional[Union[DataLoader, Dict]] = None,
    val_dataloader: Optional[Union[DataLoader, Dict]] = None,
    test_dataloader: Optional[Union[DataLoader, Dict]] = None,
    train_cfg: Optional[Dict] = None,
    val_cfg: Optional[Dict] = None,
    test_cfg: Optional[Dict] = None,
    auto_scale_lr: Optional[Dict] = None,
    optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None,
    param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
    val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
    test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
    default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
    custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
    data_preprocessor: Union[nn.Module, Dict, None] = None,
    load_from: Optional[str] = None,
    resume: bool = False,
    launcher: str = 'none',
    env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
    log_processor: Optional[Dict] = None,
    log_level: Union[int, str] = 'INFO',
    visualizer: Optional[Union[Visualizer, Dict]] = None,
    default_scope: str = 'mmengine',
    randomness: Dict = dict(seed=None),
    experiment_name: Optional[str] = None,
    cfg: Optional[ConfigType] = None,
):
    """Validate the arguments and eagerly build the shared components.

    The environment, logger, message hub, visualizer, model and hooks are
    set up here. Dataloaders, loops, evaluators, the optimizer wrapper and
    parameter schedulers are only stored as given. See the class docstring
    for the meaning of every argument.

    Raises:
        ValueError: If the training, validation or test related arguments
            are only partially specified, or if ``param_scheduler`` is
            given without ``optim_wrapper``.
    """
    self._work_dir = osp.abspath(work_dir)
    mmengine.mkdir_or_exist(self._work_dir)

    # recursively copy the `cfg` because `self.cfg` will be modified
    if isinstance(cfg, Config):
        self.cfg = copy.deepcopy(cfg)
    elif isinstance(cfg, dict):
        self.cfg = Config(cfg)
    else:
        # No usable config was given: fall back to an empty one.
        self.cfg = Config(dict())

    training_related = [train_dataloader, train_cfg, optim_wrapper]
    if not (all(item is None for item in training_related)
            or all(item is not None for item in training_related)):
        raise ValueError(
            'train_dataloader, train_cfg, and optim_wrapper should be '
            'either all None or not None, but got '
            f'train_dataloader={train_dataloader}, '
            f'train_cfg={train_cfg}, '
            f'optim_wrapper={optim_wrapper}.')
    self._train_dataloader = train_dataloader
    self._train_loop = train_cfg

    self.optim_wrapper: Optional[Union[OptimWrapper, dict]]
    self.optim_wrapper = optim_wrapper

    self.auto_scale_lr = auto_scale_lr

    # If there is no need to adjust learning rate, momentum or other
    # parameters of optimizer, param_scheduler can be None
    if param_scheduler is not None and self.optim_wrapper is None:
        raise ValueError(
            'param_scheduler should be None when optim_wrapper is None, '
            f'but got {param_scheduler}')

    # Parse `param_scheduler` to a list or a dict. If `optim_wrapper` is a
    # `dict` with single optimizer, parsed param_scheduler will be a
    # list of parameter schedulers. If `optim_wrapper` is
    # a `dict` with multiple optimizers, parsed `param_scheduler` will be
    # dict with multiple list of parameter schedulers.
    self._check_scheduler_cfg(param_scheduler)
    self.param_schedulers = param_scheduler

    val_related = [val_dataloader, val_cfg, val_evaluator]
    if not (all(item is None for item in val_related)
            or all(item is not None for item in val_related)):
        raise ValueError(
            'val_dataloader, val_cfg, and val_evaluator should be either '
            'all None or not None, but got '
            f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
            f'val_evaluator={val_evaluator}')
    self._val_dataloader = val_dataloader
    self._val_loop = val_cfg
    self._val_evaluator = val_evaluator

    test_related = [test_dataloader, test_cfg, test_evaluator]
    if not (all(item is None for item in test_related)
            or all(item is not None for item in test_related)):
        raise ValueError(
            'test_dataloader, test_cfg, and test_evaluator should be '
            'either all None or not None, but got '
            f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
            f'test_evaluator={test_evaluator}')
    self._test_dataloader = test_dataloader
    self._test_loop = test_cfg
    self._test_evaluator = test_evaluator

    self._launcher = launcher
    self._distributed = self._launcher != 'none'

    # self._timestamp will be set in the `setup_env` method. Besides,
    # it also will initialize multi-process and (or) distributed
    # environment.
    self.setup_env(env_cfg)
    # self._deterministic and self._seed will be set in the
    # `set_randomness`` method
    self._randomness_cfg = randomness
    self.set_randomness(**randomness)

    if experiment_name is not None:
        self._experiment_name = f'{experiment_name}_{self._timestamp}'
    elif self.cfg.filename is not None:
        filename_no_ext = osp.splitext(osp.basename(self.cfg.filename))[0]
        self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
    else:
        self._experiment_name = self.timestamp
    self._log_dir = osp.join(self.work_dir, self.timestamp)
    mmengine.mkdir_or_exist(self._log_dir)
    # Used to reset registries location. See :meth:`Registry.build` for
    # more details.
    self.default_scope = DefaultScope.get_instance(
        self._experiment_name, scope_name=default_scope)
    # Build log processor to format message.
    log_processor = dict() if log_processor is None else log_processor
    self.log_processor = self.build_log_processor(log_processor)
    # Since `get_instance` could return any subclass of ManagerMixin. The
    # corresponding attribute needs a type hint.
    self.logger = self.build_logger(log_level=log_level)

    # Collect and log environment information.
    self._log_env(env_cfg)

    # Build `message_hub` for communication among components.
    # `message_hub` can store log scalars (loss, learning rate) and
    # runtime information (iter and epoch). Those components that do not
    # have access to the runner can get iteration or epoch information
    # from `message_hub`. For example, models can get the latest created
    # `message_hub` by
    # `self.message_hub=MessageHub.get_current_instance()` and then get
    # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`.
    # See `MessageHub` and `ManagerMixin` for more details.
    self.message_hub = self.build_message_hub()
    # visualizer used for writing log or visualizing all kinds of data
    self.visualizer = self.build_visualizer(visualizer)
    if self.cfg:
        self.visualizer.add_config(self.cfg)

    self._load_from = load_from
    self._resume = resume
    # flag to mark whether checkpoint has been loaded or resumed
    self._has_loaded = False

    # build a model
    if isinstance(model, dict) and data_preprocessor is not None:
        # Merge the data_preprocessor to model config.
        model.setdefault('data_preprocessor', data_preprocessor)
    self.model = self.build_model(model)
    # wrap model
    self.model = self.wrap_model(
        self.cfg.get('model_wrapper_cfg'), self.model)

    # get model name from the model class
    if hasattr(self.model, 'module'):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self._hooks: List[Hook] = []
    # register hooks to `self._hooks`
    self.register_hooks(default_hooks, custom_hooks)
    # log hooks information
    self.logger.info(f'Hooks will be executed in the following '
                     f'order:\n{self.get_hooks_info()}')

    # dump `cfg` to `work_dir`
    self.dump_config()
@classmethod
def from_cfg(cls, cfg: ConfigType) -> 'Runner':
    """Build a runner from config.

    Args:
        cfg (ConfigType): A config used for building runner. Keys of
            ``cfg`` can see :meth:`__init__`.

    Returns:
        Runner: A runner build from ``cfg``.
    """
    # Work on a private copy so the caller's config is never mutated.
    cfg = copy.deepcopy(cfg)
    runner = cls(
        model=cfg['model'],
        work_dir=cfg['work_dir'],
        train_dataloader=cfg.get('train_dataloader'),
        val_dataloader=cfg.get('val_dataloader'),
        test_dataloader=cfg.get('test_dataloader'),
        train_cfg=cfg.get('train_cfg'),
        val_cfg=cfg.get('val_cfg'),
        test_cfg=cfg.get('test_cfg'),
        auto_scale_lr=cfg.get('auto_scale_lr'),
        optim_wrapper=cfg.get('optim_wrapper'),
        param_scheduler=cfg.get('param_scheduler'),
        val_evaluator=cfg.get('val_evaluator'),
        test_evaluator=cfg.get('test_evaluator'),
        default_hooks=cfg.get('default_hooks'),
        custom_hooks=cfg.get('custom_hooks'),
        data_preprocessor=cfg.get('data_preprocessor'),
        load_from=cfg.get('load_from'),
        resume=cfg.get('resume', False),
        launcher=cfg.get('launcher', 'none'),
        # Fall back to the documented default instead of None, which
        # `setup_env` cannot handle (it calls `env_cfg.get`).
        env_cfg=cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl'))),
        log_processor=cfg.get('log_processor'),
        log_level=cfg.get('log_level', 'INFO'),
        # Forward the visualizer config so it is not silently dropped.
        visualizer=cfg.get('visualizer'),
        default_scope=cfg.get('default_scope', 'mmengine'),
        randomness=cfg.get('randomness', dict(seed=None)),
        experiment_name=cfg.get('experiment_name'),
        cfg=cfg,
    )

    return runner
@property
def experiment_name(self):
    """str: The experiment name chosen when the runner was created."""
    return self._experiment_name
@property
def model_name(self):
    """str: Class name of the model (of the wrapped module if wrapped)."""
    return self._model_name
@property
def work_dir(self):
    """str: Absolute path of the directory for checkpoints and logs."""
    return self._work_dir
@property
def log_dir(self):
    """str: The ``work_dir`` subdirectory, named after the timestamp, that
    holds this run's log file."""
    return self._log_dir
@property
def max_epochs(self):
    """int: Total epochs to train, or 0 when no training loop is set."""
    loop = self.train_loop
    return loop.max_epochs if isinstance(loop, BaseLoop) else 0
@property
def max_iters(self):
    """int: Total iterations to train, or 0 when no training loop is set."""
    loop = self.train_loop
    return loop.max_iters if isinstance(loop, BaseLoop) else 0
@property
def epoch(self):
    """int: Current epoch, or 0 when no training loop is set."""
    loop = self.train_loop
    return loop.epoch if isinstance(loop, BaseLoop) else 0
@property
def iter(self):
    """int: Current iteration."""
    # Without a built training loop there is no progress to report.
    if isinstance(self.train_loop, BaseLoop):
        return self.train_loop.iter
    else:
        return 0
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
@property
def launcher(self):
    """str: The launcher used to start the processes ('none' if single)."""
    return self._launcher
@property
def distributed(self):
    """bool: True when a launcher other than 'none' was requested."""
    return self._distributed
@property
def rank(self):
    """int: Rank of the current process, as recorded by ``setup_env``."""
    return self._rank
@property
def world_size(self):
    """int: Number of processes in the job, as recorded by ``setup_env``."""
    return self._world_size
@property
def deterministic(self):
    """bool: Whether deterministic cuDNN algorithms were requested."""
    return self._deterministic
@property
def seed(self):
    """int: The seed applied to the random modules."""
    return self._seed
@property
def timestamp(self):
    """str: Creation time of the experiment, formatted ``%Y%m%d_%H%M%S``."""
    return self._timestamp
@property
def hooks(self):
    """list[:obj:`Hook`]: The hooks registered on this runner."""
    return self._hooks
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
@property
def train_loop(self):
    """:obj:`BaseLoop`: The training loop, built on first access.

    Returns None when no training loop was configured.
    """
    loop = self._train_loop
    if loop is not None and not isinstance(loop, BaseLoop):
        # Still a config: build it once and cache the result.
        loop = self._train_loop = self.build_train_loop(loop)
    return loop
@property
def val_loop(self):
    """:obj:`BaseLoop`: The validation loop, built on first access.

    Returns None when no validation loop was configured.
    """
    loop = self._val_loop
    if loop is not None and not isinstance(loop, BaseLoop):
        # Still a config: build it once and cache the result.
        loop = self._val_loop = self.build_val_loop(loop)
    return loop
@property
def test_loop(self):
    """:obj:`BaseLoop`: The test loop, built on first access.

    Returns None when no test loop was configured.
    """
    loop = self._test_loop
    if loop is not None and not isinstance(loop, BaseLoop):
        # Still a config: build it once and cache the result.
        loop = self._test_loop = self.build_test_loop(loop)
    return loop
@property
def train_dataloader(self):
    """The dataloader held by the training loop."""
    loop = self.train_loop
    return loop.dataloader
@property
def val_dataloader(self):
    """The dataloader held by the validation loop."""
    loop = self.val_loop
    return loop.dataloader
@property
def test_dataloader(self):
    """The dataloader held by the test loop."""
    loop = self.test_loop
    return loop.dataloader
@property
def val_evaluator(self):
    """:obj:`Evaluator`: The evaluator held by the validation loop."""
    loop = self.val_loop
    return loop.evaluator
@property
def test_evaluator(self):
    """:obj:`Evaluator`: The evaluator held by the test loop."""
    loop = self.test_loop
    return loop.evaluator
@property
def val_interval(self):
    """int: Interval to run validation during training."""
    return self.train_loop.val_interval
@property
def val_begin(self):
    """int: The epoch/iteration to start running validation during
    training."""
    return self.train_loop.val_begin
def setup_env(self, env_cfg: Dict) -> None:
    """Setup environment.

    Sets the cuDNN benchmark flag, the multi-processing options, the
    distributed environment (when required), ``self._rank``,
    ``self._world_size``, ``self._timestamp`` and, on non-Windows
    platforms, the soft limit on open file descriptors.

    An example of ``env_cfg``::

        env_cfg = dict(
            cudnn_benchmark=True,
            mp_cfg=dict(
                mp_start_method='fork',
                opencv_num_threads=0
            ),
            dist_cfg=dict(backend='nccl', timeout=1800),
        )

    Args:
        env_cfg (dict): Config for setting environment.
    """
    if env_cfg.get('cudnn_benchmark'):
        torch.backends.cudnn.benchmark = True

    mp_cfg: dict = env_cfg.get('mp_cfg', {})
    set_multi_processing(**mp_cfg, distributed=self.distributed)

    # init distributed env first, since logger depends on the dist info.
    if self.distributed and not is_distributed():
        dist_cfg: dict = env_cfg.get('dist_cfg', {})
        init_dist(self.launcher, **dist_cfg)

    self._rank, self._world_size = get_dist_info()

    timestamp = torch.tensor(time.time(), dtype=torch.float64)
    # broadcast timestamp from 0 process to other processes
    broadcast(timestamp)
    self._timestamp = time.strftime('%Y%m%d_%H%M%S',
                                    time.localtime(timestamp.item()))

    # https://github.com/pytorch/pytorch/issues/973
    # set resource limit
    if platform.system() != 'Windows':
        # `resource` is Unix-only, so import it only on this branch.
        import resource
        rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
        base_soft_limit = rlimit[0]
        hard_limit = rlimit[1]
        # Raise the soft limit to the requested value (never lowering it),
        # capped by the hard limit.
        soft_limit = min(
            max(env_cfg.get('resource_limit', 4096), base_soft_limit),
            hard_limit)
        resource.setrlimit(resource.RLIMIT_NOFILE,
                           (soft_limit, hard_limit))
def set_randomness(self,
                   seed,
                   diff_rank_seed: bool = False,
                   deterministic: bool = False) -> None:
    """Set random seed to guarantee reproducible results.

    Args:
        seed (int): A number to set random modules.
        diff_rank_seed (bool): Whether or not set different seeds according
            to global rank. Defaults to False.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Defaults to False.
            See https://pytorch.org/docs/stable/notes/randomness.html for
            more details.
    """
    # Remember the flag so the `deterministic` property can report it.
    self._deterministic = deterministic
    # Keep the returned seed for the `seed` property.
    self._seed = set_random_seed(
        seed=seed,
        deterministic=deterministic,
        diff_rank_seed=diff_rank_seed)
def build_logger(self,
                 log_level: Union[int, str] = 'INFO',
                 log_file: Optional[str] = None,
                 **kwargs) -> MMLogger:
    """Build a global asscessable MMLogger.

    Args:
        log_level (int or str): The log level of MMLogger handlers.
            Defaults to 'INFO'.
        log_file (str, optional): Path of filename to save log.
            Defaults to None, in which case ``<timestamp>.log`` inside the
            log directory is used.
        **kwargs: Remaining parameters passed to ``MMLogger``.

    Returns:
        MMLogger: A MMLogger object build from ``logger``.
    """
    # Only fall back to the default path when the caller gave none.
    if log_file is None:
        log_file = osp.join(self._log_dir, f'{self.timestamp}.log')

    log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
    log_cfg.setdefault('name', self._experiment_name)
    # `torch.compile` in PyTorch 2.0 could close all user defined handlers
    # unexpectedly. Using file mode 'a' can help prevent abnormal
    # termination of the FileHandler and ensure that the log file could
    # be continuously updated during the lifespan of the runner.
    log_cfg.setdefault('file_mode', 'a')

    return MMLogger.get_instance(**log_cfg)  # type: ignore
def build_message_hub(self,
                      message_hub: Optional[Dict] = None) -> MessageHub:
    """Build a global asscessable MessageHub.

    Args:
        message_hub (dict, optional): A dict to build MessageHub object.
            If not specified, default config will be used to build
            MessageHub object. Defaults to None.

    Returns:
        MessageHub: A MessageHub object build from ``message_hub``.

    Raises:
        TypeError: If ``message_hub`` is neither a dict nor None.
    """
    if message_hub is None:
        message_hub = dict(name=self._experiment_name)
    elif isinstance(message_hub, dict):
        # ensure message_hub containing name key
        message_hub.setdefault('name', self._experiment_name)
    else:
        raise TypeError(
            f'message_hub should be dict or None, but got {message_hub}')

    return MessageHub.get_instance(**message_hub)
def build_visualizer(
        self,
        visualizer: Optional[Union[Visualizer,
                                   Dict]] = None) -> Visualizer:
    """Build a global asscessable Visualizer.

    Args:
        visualizer (Visualizer or dict, optional): A Visualizer object
            or a dict to build Visualizer object. If ``visualizer`` is a
            Visualizer object, just returns itself. If not specified,
            default config will be used to build Visualizer object.
            Defaults to None.

    Returns:
        Visualizer: A Visualizer object build from ``visualizer``.

    Raises:
        TypeError: If ``visualizer`` is not a Visualizer, a dict or None.
    """
    if visualizer is None:
        # Default: a local backend writing into this run's log directory.
        visualizer = dict(
            name=self._experiment_name,
            vis_backends=[dict(type='LocalVisBackend')],
            save_dir=self._log_dir)
        return Visualizer.get_instance(**visualizer)

    if isinstance(visualizer, Visualizer):
        return visualizer

    if isinstance(visualizer, dict):
        # ensure visualizer containing name key
        visualizer.setdefault('name', self._experiment_name)
        visualizer.setdefault('save_dir', self._log_dir)
        return VISUALIZERS.build(visualizer)

    raise TypeError(
        'visualizer should be Visualizer object, a dict or None, '
        f'but got {visualizer}')
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
    """Build model.

    If ``model`` is a dict, it will be used to build a nn.Module object.
    Else, if ``model`` is a nn.Module object it will be returned directly.

    An example of ``model``::

        model = dict(type='ResNet')

    Args:
        model (nn.Module or dict): A ``nn.Module`` object or a dict to
            build nn.Module object. If ``model`` is a nn.Module object,
            just returns itself.

    Note:
        The returned model must implement ``train_step``, ``test_step``
        if ``runner.train`` or ``runner.test`` will be called. If
        ``runner.val`` will be called or ``val_cfg`` is configured,
        model must implement `val_step`.

    Returns:
        nn.Module: Model build from ``model``.

    Raises:
        TypeError: If ``model`` is neither a ``nn.Module`` nor a dict.
    """
    if isinstance(model, nn.Module):
        return model
    elif isinstance(model, dict):
        # Instantiate the module from its config via the model registry.
        model = MODELS.build(model)
        return model  # type: ignore
    else:
        raise TypeError('model should be a nn.Module object or dict, '
                        f'but got {model}')
def wrap_model(
        self, model_wrapper_cfg: Optional[Dict],
        model: nn.Module) -> Union[DistributedDataParallel, nn.Module]:
    """Wrap the model to :obj:``MMDistributedDataParallel`` or other custom
    distributed data-parallel module wrappers.

    An example of ``model_wrapper_cfg``::

        model_wrapper_cfg = dict(
            broadcast_buffers=False,
            find_unused_parameters=False
        )

    Args:
        model_wrapper_cfg (dict, optional): Config to wrap model. If not
            specified, ``DistributedDataParallel`` will be used in
            distributed environment. Defaults to None.
        model (nn.Module): Model to be wrapped.

    Returns:
        nn.Module or DistributedDataParallel: nn.Module or subclass of
        ``DistributedDataParallel``.

    Raises:
        TypeError: If ``model`` is already wrapped but a
            ``model_wrapper_cfg`` is given as well.
        ValueError: Re-raised from ``convert_sync_batchnorm`` when
            ``cfg.sync_bn`` is not supported.
    """
    if is_model_wrapper(model):
        if model_wrapper_cfg is not None:
            raise TypeError(
                'model has been wrapped and "model_wrapper_cfg" should be '
                f'None, but got {model_wrapper_cfg}')

        return model

    # Set `export CUDA_VISIBLE_DEVICES=-1` to enable CPU training.
    model = model.to(get_device())

    if not self.distributed:
        self.logger.info(
            'Distributed training is not used, all SyncBatchNorm (SyncBN) '
            'layers in the model will be automatically reverted to '
            'BatchNormXd layers if they are used.')
        model = revert_sync_batchnorm(model)
        return model  # type: ignore

    sync_bn = self.cfg.get('sync_bn', None)
    if sync_bn is not None:
        try:
            model = convert_sync_batchnorm(model, sync_bn)
        except ValueError:
            self.logger.error('cfg.sync_bn should be "torch" or '
                              f'"mmcv", but got {sync_bn}')
            raise

    if model_wrapper_cfg is None:
        find_unused_parameters = self.cfg.get('find_unused_parameters',
                                              False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        # TODO: may use a more elegant way to get local device ID.
        model = MMDistributedDataParallel(
            module=model,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        # A config without `type` (as in the example above) means the
        # default wrapper.
        model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
        model_wrapper_type = MODEL_WRAPPERS.get(
            model_wrapper_cfg.get('type'))  # type: ignore
        default_args: dict = dict()
        # Only DDP-style wrappers receive `device_ids`.
        if issubclass(
                model_wrapper_type,  # type: ignore
                DistributedDataParallel):
            default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])]
        default_args['module'] = model
        model = MODEL_WRAPPERS.build(
            model_wrapper_cfg, default_args=default_args)
    return model
def _init_model_weights(self) -> None:
    """Call ``init_weights`` on the (unwrapped) model when it defines one,
    then broadcast every entry of its state dict."""
    wrapped = is_model_wrapper(self.model)
    target = self.model.module if wrapped else self.model
    if not hasattr(target, 'init_weights'):
        return
    target.init_weights()
    # sync params and buffers
    for tensor in target.state_dict().values():
        broadcast(tensor)
def scale_lr(self,
             optim_wrapper: OptimWrapper,
             auto_scale_lr: Optional[Dict] = None) -> None:
    """Automatically scaling learning rate in training according to the
    ratio of ``base_batch_size`` in ``autoscalelr_cfg`` and real batch
    size.

    It scales the learning rate linearly according to the
    `paper <https://arxiv.org/abs/1706.02677>`_.

    Note:
        ``scale_lr`` must be called after building optimizer wrappers
        and before building parameter schedulers.

    Args:
        optim_wrapper (OptimWrapper): An OptimWrapper object whose
            parameter groups' learning rate need to be scaled.
        auto_scale_lr (Dict, Optional): Config to scale the learning
            rate automatically. It includes ``base_batch_size`` and
            ``enable``. ``base_batch_size`` is the batch size that the
            optimizer lr is based on. ``enable`` is the switch to turn on
            and off the feature.

    Raises:
        RuntimeError: If any parameter scheduler has already been built.
    """
    # Nothing to do when the feature is absent or switched off.
    if (auto_scale_lr is None or not auto_scale_lr.get('enable', False)):
        return None

    assert 'base_batch_size' in auto_scale_lr, \
        'Lack of `base_batch_size` in `auto_scale_lr`.'
    dataloader: Union[DataLoader, Dict] = self._train_dataloader
    bs = dataloader.batch_size if isinstance(
        dataloader, DataLoader) else dataloader['batch_size']
    # Effective batch size: per-process batch size times process count.
    real_bs = self.world_size * bs
    base_bs = auto_scale_lr['base_batch_size']
    ratio = float(real_bs) / float(base_bs)
    self.logger.info(f'LR is set based on batch size of {base_bs} '
                     f'and the current batch size is {real_bs}. '
                     f'Scaling the original LR by {ratio}.')

    def _is_built(schedulers):
        # A dict with a `type` key is still a config; otherwise recurse
        # into containers looking for a built scheduler instance.
        if isinstance(schedulers, dict):
            return False if 'type' in schedulers else any(
                _is_built(s) for s in schedulers.values())
        if isinstance(schedulers, list):
            return any(_is_built(s) for s in schedulers)
        return isinstance(schedulers, _ParamScheduler)

    if _is_built(self.param_schedulers):
        raise RuntimeError('`scale_lr` should be called before building '
                           'ParamScheduler because ParamScheduler will '
                           'store initial lr from optimizer wrappers')

    assert isinstance(optim_wrapper, OptimWrapper), \
        '`scale_lr should be called after building OptimWrapper'
    wrappers = list(optim_wrapper.values()) if isinstance(
        optim_wrapper, OptimWrapperDict) else [optim_wrapper]
    for wrapper in wrappers:
        for group in wrapper.optimizer.param_groups:
            group['lr'] = group['lr'] * ratio
def build_optim_wrapper(
self, optim_wrapper: Union[Optimizer, OptimWrapper, Dict]
) -> Union[OptimWrapper, OptimWrapperDict]:
"""Build optimizer wrapper.
If ``optim_wrapper`` is a config dict for only one optimizer,
the keys must contain ``optimizer``, and ``type`` is optional.
It will build a :obj:`OptimWrapper` by default.
If ``optim_wrapper`` is a config dict for multiple optimizers, i.e.,
it has multiple keys and each key is for an optimizer wrapper. The
constructor must be specified since
:obj:`DefaultOptimizerConstructor` cannot handle the building of
training with multiple optimizers.
If ``optim_wrapper`` is a dict of pre-built optimizer wrappers, i.e.,
each value of ``optim_wrapper`` represents an ``OptimWrapper``
instance. ``build_optim_wrapper`` will directly build the
:obj:`OptimWrapperDict` instance from ``optim_wrapper``.
optim_wrapper (OptimWrapper or dict): An OptimWrapper object or a
dict to build OptimWrapper objects. If ``optim_wrapper`` is an
OptimWrapper, just return an ``OptimizeWrapper`` instance.
Note:
For single optimizer training, if `optim_wrapper` is a config
dict, `type` is optional(defaults to :obj:`OptimWrapper`) and it
must contain `optimizer` to build the corresponding optimizer.
Examples:
>>> # build an optimizer
>>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
... type='SGD', lr=0.01))
>>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
>>> # is also valid.
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper