diff --git a/mmengine/runner/base_loop.py b/mmengine/runner/base_loop.py index ec3f880dd42b539ecf484f55152525154292bb5c..e3916a24bf78c54a0702d3acef4a0516f49de576 100644 --- a/mmengine/runner/base_loop.py +++ b/mmengine/runner/base_loop.py @@ -19,9 +19,9 @@ class BaseLoop(metaclass=ABCMeta): def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None: self._runner = runner - if isinstance(dataloader, dict): - self.dataloader = runner.build_dataloader(dataloader) + self.dataloader = runner.build_dataloader( + dataloader, seed=runner.seed) else: self.dataloader = dataloader diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index c9e1893be8f47b92073b4393e663491226ba9957..ba4b0089feb564cdef77e92d144d7dd9cfe4d065 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -21,8 +21,8 @@ from torch.utils.data import DataLoader import mmengine from mmengine.config import Config, ConfigDict from mmengine.data import pseudo_collate, worker_init_fn -from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only, - sync_random_seed) +from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, + master_only, sync_random_seed) from mmengine.evaluator import Evaluator from mmengine.hooks import Hook from mmengine.logging import LogProcessor, MessageHub, MMLogger @@ -893,9 +893,8 @@ class Runner: ] Args: - evaluator (Evaluator or dict or list): - An Evaluator object or a config dict or list of config dict - used to build an Evaluator. + evaluator (Evaluator or dict or list): An Evaluator object or a + config dict or list of config dict used to build an Evaluator. Returns: Evaluator: Evaluator build from ``evaluator``. @@ -909,8 +908,9 @@ class Runner: 'evaluator should be one of dict, list of dict, and Evaluator' f', but got {evaluator}') - def build_dataloader(self, dataloader: Union[DataLoader, - Dict]) -> DataLoader: + @staticmethod + def build_dataloader(dataloader: Union[DataLoader, Dict], + seed: Optional[int] = None) -> DataLoader: """Build dataloader. The method builds three components: @@ -932,6 +932,7 @@ class Runner: dataloader (DataLoader or dict): A Dataloader object or a dict to build Dataloader object. If ``dataloader`` is a Dataloader object, just returns itself. + seed (int, optional): Random seed. Defaults to None. Returns: Dataloader: DataLoader build from ``dataloader_cfg``. @@ -979,12 +980,12 @@ class Runner: # build dataloader init_fn: Optional[partial] - if self.seed is not None: + if seed is not None: init_fn = partial( worker_init_fn, num_workers=dataloader_cfg.get('num_workers'), - rank=self.rank, - seed=self.seed) + rank=get_rank(), + seed=seed) else: init_fn = None