From 38ae566632693125849a11ff2b4357f309573817 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Fri, 26 Aug 2022 11:33:14 +0800 Subject: [PATCH] [Fix] Add `set_random_seed` function in MMEngine (#464) * add set random seed fun * fix conflict * allign the previous version --- mmengine/runner/__init__.py | 3 ++- mmengine/runner/runner.py | 28 +++----------------- mmengine/runner/utils.py | 53 ++++++++++++++++++++++++++++++++++++- 3 files changed, 57 insertions(+), 27 deletions(-) diff --git a/mmengine/runner/__init__.py b/mmengine/runner/__init__.py index 439ce801..e4f5dfbc 100644 --- a/mmengine/runner/__init__.py +++ b/mmengine/runner/__init__.py @@ -10,6 +10,7 @@ from .log_processor import LogProcessor from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .priority import Priority, get_priority from .runner import Runner +from .utils import set_random_seed __all__ = [ 'BaseLoop', 'load_state_dict', 'get_torchvision_models', @@ -17,5 +18,5 @@ __all__ = [ 'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict', 'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop', 'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint', - 'autocast', 'LogProcessor' + 'autocast', 'LogProcessor', 'set_random_seed' ] diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 4eae559d..bdfe156c 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -4,14 +4,12 @@ import logging import os import os.path as osp import platform -import random import time import warnings from collections import OrderedDict from functools import partial from typing import Callable, Dict, List, Optional, Sequence, Union -import numpy as np import torch import torch.nn as nn from torch.nn.parallel.distributed import DistributedDataParallel @@ -23,7 +21,7 @@ from mmengine.config import Config, ConfigDict from mmengine.dataset import COLLATE_FUNCTIONS, worker_init_fn from mmengine.device import get_device from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, - is_distributed, master_only, sync_random_seed) + is_distributed, master_only) from mmengine.evaluator import Evaluator from mmengine.fileio import FileClient from mmengine.hooks import Hook @@ -48,6 +46,7 @@ from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, from .log_processor import LogProcessor from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .priority import Priority, get_priority +from .utils import set_random_seed ConfigType = Union[Dict, Config, ConfigDict] ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, @@ -683,28 +682,7 @@ class Runner: more details. """ self._deterministic = deterministic - self._seed = seed - if self._seed is None: - self._seed = sync_random_seed() - - if diff_rank_seed: - # set different seeds for different ranks - self._seed = self._seed + get_rank() - random.seed(self._seed) - np.random.seed(self._seed) - torch.manual_seed(self._seed) - torch.cuda.manual_seed_all(self._seed) - if deterministic: - if torch.backends.cudnn.benchmark: - warnings.warn( - 'torch.backends.cudnn.benchmark is going to be set as ' - '`False` to cause cuDNN to deterministically select an ' - 'algorithm') - - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - if digit_version(TORCH_VERSION) >= digit_version('1.10.0'): - torch.use_deterministic_algorithms(True) + self._seed = set_random_seed(seed, diff_rank_seed, deterministic) def build_logger(self, log_level: Union[int, str] = 'INFO', diff --git a/mmengine/runner/utils.py b/mmengine/runner/utils.py index a8563a1d..db034df7 100644 --- a/mmengine/runner/utils.py +++ b/mmengine/runner/utils.py @@ -1,7 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import logging +import random from typing import List, Optional, Tuple -from mmengine.utils import is_list_of +import numpy as np +import torch + +from mmengine.dist import get_rank, sync_random_seed +from mmengine.logging import print_log +from mmengine.utils import digit_version, is_list_of +from mmengine.utils.dl_utils import TORCH_VERSION def calc_dynamic_intervals( @@ -33,3 +41,46 @@ def calc_dynamic_intervals( dynamic_intervals.extend( [dynamic_interval[1] for dynamic_interval in dynamic_interval_list]) return dynamic_milestones, dynamic_intervals + + +def set_random_seed(seed: Optional[int] = None, + deterministic: bool = False, + diff_rank_seed: bool = False) -> int: + """Set random seed. + + Args: + seed (int, optional): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + diff_rank_seed (bool): Whether to add rank number to the random seed to + have different random seed in different threads. Default: False. + """ + if seed is None: + seed = sync_random_seed() + + if diff_rank_seed: + rank = get_rank() + seed += rank + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + # torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # os.environ['PYTHONHASHSEED'] = str(seed) + if deterministic: + if torch.backends.cudnn.benchmark: + print_log( + 'torch.backends.cudnn.benchmark is going to be set as ' + '`False` to cause cuDNN to deterministically select an ' + 'algorithm', + logger='current', + level=logging.WARNING) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + if digit_version(TORCH_VERSION) >= digit_version('1.10.0'): + torch.use_deterministic_algorithms(True) + return seed -- GitLab