From 63ea47460c8778e38050821a5dbb9c96e8940061 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Thu, 12 May 2022 18:36:44 +0800 Subject: [PATCH] [Refactor] Change Runner.build_dataloader to a static method (#219) * [Refactor] Change some methods to static methods * only change build_dataloader to static method --- mmengine/runner/base_loop.py | 4 ++-- mmengine/runner/runner.py | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mmengine/runner/base_loop.py b/mmengine/runner/base_loop.py index ec3f880d..e3916a24 100644 --- a/mmengine/runner/base_loop.py +++ b/mmengine/runner/base_loop.py @@ -19,9 +19,9 @@ class BaseLoop(metaclass=ABCMeta): def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None: self._runner = runner - if isinstance(dataloader, dict): - self.dataloader = runner.build_dataloader(dataloader) + self.dataloader = runner.build_dataloader( + dataloader, seed=runner.seed) else: self.dataloader = dataloader diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index c9e1893b..ba4b0089 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -21,8 +21,8 @@ from torch.utils.data import DataLoader import mmengine from mmengine.config import Config, ConfigDict from mmengine.data import pseudo_collate, worker_init_fn -from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only, - sync_random_seed) +from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, + master_only, sync_random_seed) from mmengine.evaluator import Evaluator from mmengine.hooks import Hook from mmengine.logging import LogProcessor, MessageHub, MMLogger @@ -893,9 +893,8 @@ class Runner: ] Args: - evaluator (Evaluator or dict or list): - An Evaluator object or a config dict or list of config dict - used to build an Evaluator. + evaluator (Evaluator or dict or list): An Evaluator object or a + config dict or list of config dict used to build an Evaluator. Returns: Evaluator: Evaluator build from ``evaluator``. @@ -909,8 +908,9 @@ class Runner: 'evaluator should be one of dict, list of dict, and Evaluator' f', but got {evaluator}') - def build_dataloader(self, dataloader: Union[DataLoader, - Dict]) -> DataLoader: + @staticmethod + def build_dataloader(dataloader: Union[DataLoader, Dict], + seed: Optional[int] = None) -> DataLoader: """Build dataloader. The method builds three components: @@ -932,6 +932,7 @@ class Runner: dataloader (DataLoader or dict): A Dataloader object or a dict to build Dataloader object. If ``dataloader`` is a Dataloader object, just returns itself. + seed (int, optional): Random seed. Defaults to None. Returns: Dataloader: DataLoader build from ``dataloader_cfg``. @@ -979,12 +980,12 @@ class Runner: # build dataloader init_fn: Optional[partial] - if self.seed is not None: + if seed is not None: init_fn = partial( worker_init_fn, num_workers=dataloader_cfg.get('num_workers'), - rank=self.rank, - seed=self.seed) + rank=get_rank(), + seed=seed) else: init_fn = None -- GitLab