Skip to content
Snippets Groups Projects
Unverified Commit ad590e45 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhance] Disable warning of subprocess launched by dataloader (#870)

* Disable warning of subprocess launched by dataloader

* Add type hint
parent 0b59a90a
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import random import random
import warnings
from typing import Any, Mapping, Sequence from typing import Any, Mapping, Sequence
import numpy as np import numpy as np
...@@ -13,8 +14,11 @@ from mmengine.structures import BaseDataElement ...@@ -13,8 +14,11 @@ from mmengine.structures import BaseDataElement
COLLATE_FUNCTIONS = Registry('Collate Functions') COLLATE_FUNCTIONS = Registry('Collate Functions')
def worker_init_fn(worker_id: int, num_workers: int, rank: int, def worker_init_fn(worker_id: int,
seed: int) -> None: num_workers: int,
rank: int,
seed: int,
disable_subprocess_warning: bool = False) -> None:
"""This function will be called on each worker subprocess after seeding and """This function will be called on each worker subprocess after seeding and
before data loading. before data loading.
...@@ -31,6 +35,8 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, ...@@ -31,6 +35,8 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int,
np.random.seed(worker_seed) np.random.seed(worker_seed)
random.seed(worker_seed) random.seed(worker_seed)
torch.manual_seed(worker_seed) torch.manual_seed(worker_seed)
if disable_subprocess_warning and worker_id != 0:
warnings.simplefilter('ignore')
@COLLATE_FUNCTIONS.register_module() @COLLATE_FUNCTIONS.register_module()
......
...@@ -1367,12 +1367,20 @@ class Runner: ...@@ -1367,12 +1367,20 @@ class Runner:
# build dataloader # build dataloader
init_fn: Optional[partial] init_fn: Optional[partial]
if seed is not None: if seed is not None:
disable_subprocess_warning = dataloader_cfg.pop(
'disable_subprocess_warning', False)
assert isinstance(
disable_subprocess_warning,
bool), ('disable_subprocess_warning should be a bool, but got '
f'{type(disable_subprocess_warning)}')
init_fn = partial( init_fn = partial(
worker_init_fn, worker_init_fn,
num_workers=dataloader_cfg.get('num_workers'), num_workers=dataloader_cfg.get('num_workers'),
rank=get_rank(), rank=get_rank(),
seed=seed) seed=seed,
disable_subprocess_warning=disable_subprocess_warning)
else: else:
init_fn = None init_fn = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment