diff --git a/mmengine/runner/__init__.py b/mmengine/runner/__init__.py
index 439ce8017eaab62bcc15551d111fcfe4b7a3f2c4..e4f5dfbc8cb3a0e0375d7f203734227ac2102880 100644
--- a/mmengine/runner/__init__.py
+++ b/mmengine/runner/__init__.py
@@ -10,6 +10,7 @@ from .log_processor import LogProcessor
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
 from .priority import Priority, get_priority
 from .runner import Runner
+from .utils import set_random_seed
 
 __all__ = [
     'BaseLoop', 'load_state_dict', 'get_torchvision_models',
@@ -17,5 +18,5 @@ __all__ = [
     'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
     'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
     'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint',
-    'autocast', 'LogProcessor'
+    'autocast', 'LogProcessor', 'set_random_seed'
 ]
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 4eae559df577926eab99920a3083e09dddd402a1..bdfe156c93c47dc1c727671f2a9f7e401a49c584 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -4,14 +4,12 @@ import logging
 import os
 import os.path as osp
 import platform
-import random
 import time
 import warnings
 from collections import OrderedDict
 from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -23,7 +21,7 @@ from mmengine.config import Config, ConfigDict
 from mmengine.dataset import COLLATE_FUNCTIONS, worker_init_fn
 from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
-                           is_distributed, master_only, sync_random_seed)
+                           is_distributed, master_only)
 from mmengine.evaluator import Evaluator
 from mmengine.fileio import FileClient
 from mmengine.hooks import Hook
@@ -48,6 +46,7 @@ from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
 from .log_processor import LogProcessor
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
 from .priority import Priority, get_priority
+from .utils import set_random_seed
 
 ConfigType = Union[Dict, Config, ConfigDict]
 ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
@@ -683,28 +682,10 @@ class Runner:
                 more details.
         """
         self._deterministic = deterministic
-        self._seed = seed
-        if self._seed is None:
-            self._seed = sync_random_seed()
-
-        if diff_rank_seed:
-            # set different seeds for different ranks
-            self._seed = self._seed + get_rank()
-        random.seed(self._seed)
-        np.random.seed(self._seed)
-        torch.manual_seed(self._seed)
-        torch.cuda.manual_seed_all(self._seed)
-        if deterministic:
-            if torch.backends.cudnn.benchmark:
-                warnings.warn(
-                    'torch.backends.cudnn.benchmark is going to be set as '
-                    '`False` to cause cuDNN to deterministically select an '
-                    'algorithm')
-
-            torch.backends.cudnn.benchmark = False
-            torch.backends.cudnn.deterministic = True
-            if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
-                torch.use_deterministic_algorithms(True)
+        self._seed = set_random_seed(
+            seed=seed,
+            deterministic=deterministic,
+            diff_rank_seed=diff_rank_seed)
 
     def build_logger(self,
                      log_level: Union[int, str] = 'INFO',
diff --git a/mmengine/runner/utils.py b/mmengine/runner/utils.py
index a8563a1d624d27fcbafc68fdce21d7f37be6c528..db034df7ada8f4d0eedb51d250ceafa32f101906 100644
--- a/mmengine/runner/utils.py
+++ b/mmengine/runner/utils.py
@@ -1,7 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import random
 from typing import List, Optional, Tuple
 
-from mmengine.utils import is_list_of
+import numpy as np
+import torch
+
+from mmengine.dist import get_rank, sync_random_seed
+from mmengine.logging import print_log
+from mmengine.utils import digit_version, is_list_of
+from mmengine.utils.dl_utils import TORCH_VERSION
 
 
 def calc_dynamic_intervals(
@@ -33,3 +41,50 @@
     dynamic_intervals.extend(
         [dynamic_interval[1] for dynamic_interval in dynamic_interval_list])
     return dynamic_milestones, dynamic_intervals
+
+
+def set_random_seed(seed: Optional[int] = None,
+                    deterministic: bool = False,
+                    diff_rank_seed: bool = False) -> int:
+    """Set random seed.
+
+    Args:
+        seed (int, optional): Seed to be used. If ``None``, a random seed is
+            generated and synchronized across all ranks. Default: None.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+        diff_rank_seed (bool): Whether to add the rank number to the seed so
+            that different ranks use different random seeds. Default: False.
+
+    Returns:
+        int: The seed that was actually used.
+    """
+    if seed is None:
+        seed = sync_random_seed()
+
+    if diff_rank_seed:
+        rank = get_rank()
+        seed += rank
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    # torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # os.environ['PYTHONHASHSEED'] = str(seed)
+    if deterministic:
+        if torch.backends.cudnn.benchmark:
+            print_log(
+                'torch.backends.cudnn.benchmark is going to be set as '
+                '`False` to cause cuDNN to deterministically select an '
+                'algorithm',
+                logger='current',
+                level=logging.WARNING)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+        if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
+            torch.use_deterministic_algorithms(True)
+    return seed
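Usage sketch: since this change exports `set_random_seed` from `mmengine.runner`, the seeding logic can be reused outside of `Runner.set_randomness`, for example in a standalone training or evaluation script. The snippet below is illustrative only; the seed value is arbitrary.

    from mmengine.runner import set_random_seed

    # Seeds Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices),
    # and makes cuDNN deterministic. Passing seed=None would instead generate
    # a seed and synchronize it across ranks via sync_random_seed().
    used_seed = set_random_seed(seed=2022, deterministic=True)
    print(f'Seeded everything with {used_seed}')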