From 5e1ed7aaf0166f3913c84bef830456cd25c4a497 Mon Sep 17 00:00:00 2001
From: shufan wu <shufanwu@outlook.com>
Date: Mon, 10 Apr 2023 17:32:36 +0800
Subject: [PATCH] [Enhance] Allow users to customize worker_init_fn of
 Dataloader (#1038)

* customize worker init fn function

* add assert

* narrow worker_init_fn type
---
 mmengine/runner/runner.py        | 37 ++++++++++++++++++++++---------------
 tests/test_runner/test_runner.py | 26 ++++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index f9c6c9af..89004fc6 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -19,7 +19,7 @@ from torch.utils.data import DataLoader
 
 import mmengine
 from mmengine.config import Config, ConfigDict
-from mmengine.dataset import worker_init_fn
+from mmengine.dataset import worker_init_fn as default_worker_init_fn
 from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                            is_distributed, master_only)
@@ -1381,21 +1381,28 @@ class Runner:
 
         # build dataloader
         init_fn: Optional[partial]
-        if seed is not None:
-            disable_subprocess_warning = dataloader_cfg.pop(
-                'disable_subprocess_warning', False)
-            assert isinstance(
-                disable_subprocess_warning,
-                bool), ('disable_subprocess_warning should be a bool, but got '
-                        f'{type(disable_subprocess_warning)}')
-            init_fn = partial(
-                worker_init_fn,
-                num_workers=dataloader_cfg.get('num_workers'),
-                rank=get_rank(),
-                seed=seed,
-                disable_subprocess_warning=disable_subprocess_warning)
+        if 'worker_init_fn' in dataloader_cfg:
+            worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
+            worker_init_fn_type = worker_init_fn_cfg.pop('type')
+            worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
+            assert callable(worker_init_fn)
+            init_fn = partial(worker_init_fn,
+                              **worker_init_fn_cfg)  # type: ignore
         else:
-            init_fn = None
+            if seed is not None:
+                disable_subprocess_warning = dataloader_cfg.pop(
+                    'disable_subprocess_warning', False)
+                assert isinstance(disable_subprocess_warning, bool), (
+                    'disable_subprocess_warning should be a bool, but got '
+                    f'{type(disable_subprocess_warning)}')
+                init_fn = partial(
+                    default_worker_init_fn,
+                    num_workers=dataloader_cfg.get('num_workers'),
+                    rank=get_rank(),
+                    seed=seed,
+                    disable_subprocess_warning=disable_subprocess_warning)
+            else:
+                init_fn = None
 
         # `persistent_workers` requires pytorch version >= 1.7
         if ('persistent_workers' in dataloader_cfg
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 725b511e..fc00efab 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -3,6 +3,7 @@ import copy
 import logging
 import os
 import os.path as osp
+import random
 import shutil
 import tempfile
 from unittest import TestCase, skipIf
@@ -352,6 +353,11 @@ def custom_collate(data_batch, pad_value):
     return pseudo_collate(data_batch)
 
 
+def custom_worker_init(worker_id):
+    np.random.seed(worker_id)
+    random.seed(worker_id)
+
+
 class TestRunner(TestCase):
 
     def setUp(self):
@@ -376,6 +382,7 @@ class TestRunner(TestCase):
         RUNNERS.register_module(module=CustomRunner, force=True)
         EVALUATOR.register_module(module=ToyEvaluator, force=True)
         FUNCTIONS.register_module(module=custom_collate, force=True)
+        FUNCTIONS.register_module(module=custom_worker_init, force=True)
         self.temp_dir = tempfile.mkdtemp()
 
         epoch_based_cfg = dict(
@@ -459,6 +466,7 @@ class TestRunner(TestCase):
         RUNNERS.module_dict.pop('CustomRunner')
         EVALUATOR.module_dict.pop('ToyEvaluator')
         FUNCTIONS.module_dict.pop('custom_collate')
+        FUNCTIONS.module_dict.pop('custom_worker_init')
 
         logging.shutdown()
         MMLogger._instance_dict.clear()
@@ -1245,6 +1253,16 @@ class TestRunner(TestCase):
             cfg, seed=seed, diff_rank_seed=True)
         self.assertNotEqual(dataloader.sampler.seed, seed)
 
+        # custom worker_init_fn
+        cfg = dict(
+            dataset=dict(type='ToyDataset'),
+            sampler=dict(type='DefaultSampler', shuffle=True),
+            worker_init_fn=dict(type='custom_worker_init'),
+            batch_size=1,
+            num_workers=2)
+        dataloader = runner.build_dataloader(cfg)
+        self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
+
     def test_build_train_loop(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_train_loop'
@@ -1689,6 +1707,14 @@ class TestRunner(TestCase):
         runner = Runner.from_cfg(cfg)
         runner.train()
 
+        # 10.3 Test build dataloader with custom worker_init function
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg.experiment_name = 'test_train10.3'
+        cfg.train_dataloader.update(
+            worker_init_fn=dict(type='custom_worker_init'))
+        runner = Runner.from_cfg(cfg)
+        runner.train()
+
         # 11 test build dataloader without default arguments of collate
         # function.
         with self.assertRaises(TypeError):
-- 
GitLab
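
Usage sketch (illustrative, not part of the patch): with this change, a dataloader
config may carry a worker_init_fn dict whose 'type' names a callable registered in
mmengine's FUNCTIONS registry; the Runner pops 'type', looks the callable up, and
binds any remaining keys with functools.partial before passing it to the DataLoader.
The function name my_worker_init and the extra offset argument below are hypothetical,
chosen only to show the keyword forwarding; ToyDataset and DefaultSampler mirror the
test config above.

    import random

    import numpy as np
    from mmengine.registry import FUNCTIONS


    @FUNCTIONS.register_module()
    def my_worker_init(worker_id, offset=0):
        # worker_id is supplied by the DataLoader for each worker process;
        # offset is a hypothetical extra kwarg bound via functools.partial
        # from the config entry below.
        np.random.seed(worker_id + offset)
        random.seed(worker_id + offset)


    train_dataloader = dict(
        dataset=dict(type='ToyDataset'),
        sampler=dict(type='DefaultSampler', shuffle=True),
        worker_init_fn=dict(type='my_worker_init', offset=10),
        batch_size=1,
        num_workers=2)

Note that when worker_init_fn appears in the config, the seed-based default
(default_worker_init_fn) is bypassed, so per-worker seeding becomes the custom
function's responsibility.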