From 5e1ed7aaf0166f3913c84bef830456cd25c4a497 Mon Sep 17 00:00:00 2001
From: shufan wu <shufanwu@outlook.com>
Date: Mon, 10 Apr 2023 17:32:36 +0800
Subject: [PATCH] [Enhance] Allow users to customize worker_init_fn of
 Dataloader (#1038)

* customize worker init fn function

* add assert

* narrow worker_init_fn type
---
 mmengine/runner/runner.py        | 37 +++++++++++++++++++-------------
 tests/test_runner/test_runner.py | 26 ++++++++++++++++++++++
 2 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index f9c6c9af..89004fc6 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -19,7 +19,7 @@ from torch.utils.data import DataLoader
 
 import mmengine
 from mmengine.config import Config, ConfigDict
-from mmengine.dataset import worker_init_fn
+from mmengine.dataset import worker_init_fn as default_worker_init_fn
 from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                            is_distributed, master_only)
@@ -1381,21 +1381,28 @@ class Runner:
         # build dataloader
         init_fn: Optional[partial]
 
-        if seed is not None:
-            disable_subprocess_warning = dataloader_cfg.pop(
-                'disable_subprocess_warning', False)
-            assert isinstance(
-                disable_subprocess_warning,
-                bool), ('disable_subprocess_warning should be a bool, but got '
-                        f'{type(disable_subprocess_warning)}')
-            init_fn = partial(
-                worker_init_fn,
-                num_workers=dataloader_cfg.get('num_workers'),
-                rank=get_rank(),
-                seed=seed,
-                disable_subprocess_warning=disable_subprocess_warning)
+        if 'worker_init_fn' in dataloader_cfg:
+            worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
+            worker_init_fn_type = worker_init_fn_cfg.pop('type')
+            worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
+            assert callable(worker_init_fn)
+            init_fn = partial(worker_init_fn,
+                              **worker_init_fn_cfg)  # type: ignore
         else:
-            init_fn = None
+            if seed is not None:
+                disable_subprocess_warning = dataloader_cfg.pop(
+                    'disable_subprocess_warning', False)
+                assert isinstance(disable_subprocess_warning, bool), (
+                    'disable_subprocess_warning should be a bool, but got '
+                    f'{type(disable_subprocess_warning)}')
+                init_fn = partial(
+                    default_worker_init_fn,
+                    num_workers=dataloader_cfg.get('num_workers'),
+                    rank=get_rank(),
+                    seed=seed,
+                    disable_subprocess_warning=disable_subprocess_warning)
+            else:
+                init_fn = None
 
         # `persistent_workers` requires pytorch version >= 1.7
         if ('persistent_workers' in dataloader_cfg
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 725b511e..fc00efab 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -3,6 +3,7 @@ import copy
 import logging
 import os
 import os.path as osp
+import random
 import shutil
 import tempfile
 from unittest import TestCase, skipIf
@@ -352,6 +353,11 @@ def custom_collate(data_batch, pad_value):
     return pseudo_collate(data_batch)
 
 
+def custom_worker_init(worker_id):
+    np.random.seed(worker_id)
+    random.seed(worker_id)
+
+
 class TestRunner(TestCase):
 
     def setUp(self):
@@ -376,6 +382,7 @@ class TestRunner(TestCase):
         RUNNERS.register_module(module=CustomRunner, force=True)
         EVALUATOR.register_module(module=ToyEvaluator, force=True)
         FUNCTIONS.register_module(module=custom_collate, force=True)
+        FUNCTIONS.register_module(module=custom_worker_init, force=True)
 
         self.temp_dir = tempfile.mkdtemp()
         epoch_based_cfg = dict(
@@ -459,6 +466,7 @@ class TestRunner(TestCase):
         RUNNERS.module_dict.pop('CustomRunner')
         EVALUATOR.module_dict.pop('ToyEvaluator')
         FUNCTIONS.module_dict.pop('custom_collate')
+        FUNCTIONS.module_dict.pop('custom_worker_init')
 
         logging.shutdown()
         MMLogger._instance_dict.clear()
@@ -1245,6 +1253,16 @@ class TestRunner(TestCase):
             cfg, seed=seed, diff_rank_seed=True)
         self.assertNotEqual(dataloader.sampler.seed, seed)
 
+        # custom worker_init_fn
+        cfg = dict(
+            dataset=dict(type='ToyDataset'),
+            sampler=dict(type='DefaultSampler', shuffle=True),
+            worker_init_fn=dict(type='custom_worker_init'),
+            batch_size=1,
+            num_workers=2)
+        dataloader = runner.build_dataloader(cfg)
+        self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
+
     def test_build_train_loop(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_train_loop'
@@ -1689,6 +1707,14 @@ class TestRunner(TestCase):
         runner = Runner.from_cfg(cfg)
         runner.train()
 
+        # 10.3 Test build dataloader with custom worker_init function
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg.experiment_name = 'test_train10.3'
+        cfg.train_dataloader.update(
+            worker_init_fn=dict(type='custom_worker_init'))
+        runner = Runner.from_cfg(cfg)
+        runner.train()
+
         # 11 test build dataloader without default arguments of collate
         # function.
         with self.assertRaises(TypeError):
-- 
GitLab