Skip to content
Snippets Groups Projects
Unverified Commit 03d5c17b authored by Yuan Liu's avatar Yuan Liu Committed by GitHub
Browse files

[Feature]: Set different seed to different rank (#298)

* [Feature]: Set different seed for diff rank

* [Feature]: Add log

* [Fix]: Fix lint

* [Fix]: Fix docstring

* [Fix]: Fix sampler seed

* [Fix]: Fix log bug

* [Fix]: Change diff_seed to diff_rank_seed

* [Fix]: Fix lint
parent 12f7d3a0
No related branches found
No related tags found
No related merge requests found
......@@ -20,8 +20,11 @@ class BaseLoop(metaclass=ABCMeta):
def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None:
self._runner = runner
if isinstance(dataloader, dict):
# Determine whether or not different ranks use different seeds.
diff_rank_seed = runner._randomness_cfg.get(
'diff_rank_seed', False)
self.dataloader = runner.build_dataloader(
dataloader, seed=runner.seed)
dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
else:
self.dataloader = dataloader
......
......@@ -649,11 +649,16 @@ class Runner:
resource.setrlimit(resource.RLIMIT_NOFILE,
(soft_limit, hard_limit))
def set_randomness(self, seed, deterministic: bool = False) -> None:
def set_randomness(self,
seed,
diff_rank_seed: bool = False,
deterministic: bool = False) -> None:
"""Set random seed to guarantee reproducible results.
Args:
seed (int): A number to set random modules.
diff_rank_seed (bool): Whether or not to set different seeds
according to global rank. Defaults to False.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
......@@ -666,6 +671,9 @@ class Runner:
if self._seed is None:
self._seed = sync_random_seed()
if diff_rank_seed:
# set different seeds for different ranks
self._seed = self._seed + get_rank()
random.seed(self._seed)
np.random.seed(self._seed)
torch.manual_seed(self._seed)
......@@ -1254,7 +1262,8 @@ class Runner:
@staticmethod
def build_dataloader(dataloader: Union[DataLoader, Dict],
seed: Optional[int] = None) -> DataLoader:
seed: Optional[int] = None,
diff_rank_seed: bool = False) -> DataLoader:
"""Build dataloader.
The method builds three components:
......@@ -1277,6 +1286,11 @@ class Runner:
build Dataloader object. If ``dataloader`` is a Dataloader
object, just returns itself.
seed (int, optional): Random seed. Defaults to None.
diff_rank_seed (bool): Whether or not to set different seeds for
different ranks. If True, the seed passed to the sampler is set
to None, in order to synchronize the seeds used in samplers
across different ranks. Defaults to False.
Returns:
Dataloader: DataLoader build from ``dataloader_cfg``.
......@@ -1300,8 +1314,10 @@ class Runner:
# build sampler
sampler_cfg = dataloader_cfg.pop('sampler')
if isinstance(sampler_cfg, dict):
sampler_seed = None if diff_rank_seed else seed
sampler = DATA_SAMPLERS.build(
sampler_cfg, default_args=dict(dataset=dataset, seed=seed))
sampler_cfg,
default_args=dict(dataset=dataset, seed=sampler_seed))
else:
# fallback to raise error in dataloader
# if `sampler_cfg` is not a valid type
......
......@@ -998,6 +998,11 @@ class TestRunner(TestCase):
self.assertIsInstance(dataloader.sampler, DefaultSampler)
self.assertEqual(dataloader.sampler.seed, seed)
# diff_rank_seed is True
dataloader = runner.build_dataloader(
cfg, seed=seed, diff_rank_seed=True)
self.assertNotEqual(dataloader.sampler.seed, seed)
def test_build_train_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_train_loop'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment