From 38ae566632693125849a11ff2b4357f309573817 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Fri, 26 Aug 2022 11:33:14 +0800
Subject: [PATCH] [Fix] Add `set_random_seed` function in MMEngine (#464)

* add set random seed fun

* fix conflict

* align with the previous version
---
 mmengine/runner/__init__.py |  3 ++-
 mmengine/runner/runner.py   | 28 +++-----------------
 mmengine/runner/utils.py    | 53 ++++++++++++++++++++++++++++++++++++-
 3 files changed, 57 insertions(+), 27 deletions(-)

diff --git a/mmengine/runner/__init__.py b/mmengine/runner/__init__.py
index 439ce801..e4f5dfbc 100644
--- a/mmengine/runner/__init__.py
+++ b/mmengine/runner/__init__.py
@@ -10,6 +10,7 @@ from .log_processor import LogProcessor
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
 from .priority import Priority, get_priority
 from .runner import Runner
+from .utils import set_random_seed
 
 __all__ = [
     'BaseLoop', 'load_state_dict', 'get_torchvision_models',
@@ -17,5 +18,5 @@ __all__ = [
     'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
     'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
     'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint',
-    'autocast', 'LogProcessor'
+    'autocast', 'LogProcessor', 'set_random_seed'
 ]
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 4eae559d..bdfe156c 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -4,14 +4,12 @@ import logging
 import os
 import os.path as osp
 import platform
-import random
 import time
 import warnings
 from collections import OrderedDict
 from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -23,7 +21,7 @@ from mmengine.config import Config, ConfigDict
 from mmengine.dataset import COLLATE_FUNCTIONS, worker_init_fn
 from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
-                           is_distributed, master_only, sync_random_seed)
+                           is_distributed, master_only)
 from mmengine.evaluator import Evaluator
 from mmengine.fileio import FileClient
 from mmengine.hooks import Hook
@@ -48,6 +46,7 @@ from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
 from .log_processor import LogProcessor
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
 from .priority import Priority, get_priority
+from .utils import set_random_seed
 
 ConfigType = Union[Dict, Config, ConfigDict]
 ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
@@ -683,28 +682,7 @@ class Runner:
                 more details.
         """
         self._deterministic = deterministic
-        self._seed = seed
-        if self._seed is None:
-            self._seed = sync_random_seed()
-
-        if diff_rank_seed:
-            # set different seeds for different ranks
-            self._seed = self._seed + get_rank()
-        random.seed(self._seed)
-        np.random.seed(self._seed)
-        torch.manual_seed(self._seed)
-        torch.cuda.manual_seed_all(self._seed)
-        if deterministic:
-            if torch.backends.cudnn.benchmark:
-                warnings.warn(
-                    'torch.backends.cudnn.benchmark is going to be set as '
-                    '`False` to cause cuDNN to deterministically select an '
-                    'algorithm')
-
-            torch.backends.cudnn.benchmark = False
-            torch.backends.cudnn.deterministic = True
-            if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
-                torch.use_deterministic_algorithms(True)
+        self._seed = set_random_seed(seed, diff_rank_seed=diff_rank_seed, deterministic=deterministic)
 
     def build_logger(self,
                      log_level: Union[int, str] = 'INFO',
diff --git a/mmengine/runner/utils.py b/mmengine/runner/utils.py
index a8563a1d..db034df7 100644
--- a/mmengine/runner/utils.py
+++ b/mmengine/runner/utils.py
@@ -1,7 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import random
 from typing import List, Optional, Tuple
 
-from mmengine.utils import is_list_of
+import numpy as np
+import torch
+
+from mmengine.dist import get_rank, sync_random_seed
+from mmengine.logging import print_log
+from mmengine.utils import digit_version, is_list_of
+from mmengine.utils.dl_utils import TORCH_VERSION
 
 
 def calc_dynamic_intervals(
@@ -33,3 +41,46 @@ def calc_dynamic_intervals(
     dynamic_intervals.extend(
         [dynamic_interval[1] for dynamic_interval in dynamic_interval_list])
     return dynamic_milestones, dynamic_intervals
+
+
+def set_random_seed(seed: Optional[int] = None,
+                    deterministic: bool = False,
+                    diff_rank_seed: bool = False) -> int:
+    """Set random seed.
+
+    Args:
+        seed (int, optional): Seed to be used.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+        diff_rank_seed (bool): Whether to add rank number to the random seed to
+            have different random seeds in different ranks. Default: False.
+    """
+    if seed is None:
+        seed = sync_random_seed()
+
+    if diff_rank_seed:
+        rank = get_rank()
+        seed += rank
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    # torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # os.environ['PYTHONHASHSEED'] = str(seed)
+    if deterministic:
+        if torch.backends.cudnn.benchmark:
+            print_log(
+                'torch.backends.cudnn.benchmark is going to be set as '
+                '`False` to cause cuDNN to deterministically select an '
+                'algorithm',
+                logger='current',
+                level=logging.WARNING)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+        if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
+            torch.use_deterministic_algorithms(True)
+    return seed
-- 
GitLab