Unverified Commit 1aa14b45 authored by Alexander Pacha's avatar Alexander Pacha Committed by GitHub

[Enhancement] Enable timeout in dist training (#877)


* Adding missing pre-commit requirement to tests.txt

* Added support for setting a timeout for distributed learning

* Adding documentation about how to change the runtime timeout into the distributed manual.

* Fixed type in documentation to correctly specify an integer

* Removing type-cast after checking the correct type already before

* Update mmengine/dist/utils.py

Adding an explicit `is not None` to the check

Co-authored-by: default avatarMashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Removing the explicit type check and replacing it with the more pythonic approach of assuming the type is right and handling the exception if it isn't.

* Removing pre-commit from test requirements again

* Simplified the code according to suggestions from PR

* Update distributed.md

---------

Co-authored-by: default avatarMashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent d1d4609f
......@@ -23,6 +23,17 @@ We will detail on these APIs in the following chapters.
- [init_dist](mmengine.dist.init_dist): Launch function of distributed training. It currently supports 3 launchers: pytorch, slurm, and MPI. It also sets up the given communication backend, defaulting to NCCL.
  If you need to change the runtime timeout (default: 30 minutes) for distributed operations that take very long, you can specify a different timeout in the `env_cfg` configuration passed to [Runner](mmengine.runner.Runner) like this:
```python
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl', timeout=10800),  # Sets the timeout to 3h (10800 seconds)
)
runner = Runner(xxx, env_cfg=env_cfg)
```
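The `timeout` value in `dist_cfg` is a plain integer number of seconds; under the hood it is converted to the `datetime.timedelta` that PyTorch's `init_process_group` expects. A minimal sketch of that correspondence (the `dist_cfg` dict here is just an illustrative fragment, not a full config):

```python
import datetime

# Hypothetical config fragment: timeout is given in seconds.
dist_cfg = dict(backend='nccl', timeout=10800)

# init_process_group expects a timedelta, so 10800 s maps to 3 hours.
timeout = datetime.timedelta(seconds=dist_cfg['timeout'])
print(timeout)  # 3:00:00
```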
## Query and control
The query and control functions are all argument free.
......
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import functools
import os
import subprocess
......@@ -50,6 +51,19 @@ def init_dist(launcher, backend='nccl', **kwargs) -> None:
'gloo' and 'mpi'. Defaults to 'nccl'.
**kwargs: keyword arguments are passed to ``init_process_group``.
"""
    timeout = kwargs.get('timeout', None)
    if timeout is not None:
        # If a timeout (in seconds) is specified, it must be converted
        # to a timedelta object before forwarding the call to
        # the respective backend, because they expect a timedelta object.
        try:
            kwargs['timeout'] = datetime.timedelta(seconds=timeout)
        except TypeError as exception:
            raise TypeError(
                f'Timeout for distributed training must be provided as '
                f"timeout in seconds, but we've received the type "
                f'{type(timeout)}. Please specify the timeout like this: '
                f"dist_cfg=dict(backend='nccl', timeout=1800)") from exception
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
......
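The conversion logic added above can be exercised in isolation. A minimal sketch, with `convert_timeout` as a hypothetical stand-in for the inline code in `init_dist` (the function name and standalone shape are assumptions for illustration):

```python
import datetime


def convert_timeout(kwargs):
    """Convert a 'timeout' given in seconds to a datetime.timedelta,
    raising a clearer error if the value has an unsupported type."""
    timeout = kwargs.get('timeout', None)
    if timeout is not None:
        try:
            kwargs['timeout'] = datetime.timedelta(seconds=timeout)
        except TypeError as exception:
            raise TypeError(
                'Timeout for distributed training must be provided in '
                f'seconds, but received the type {type(timeout)}. Please '
                "specify the timeout like this: "
                "dist_cfg=dict(backend='nccl', timeout=1800)") from exception
    return kwargs


# An integer is accepted and converted; a string raises TypeError.
print(convert_timeout({'timeout': 1800})['timeout'])  # 0:30:00
```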
......@@ -625,7 +625,7 @@ class Runner:
            mp_start_method='fork',
            opencv_num_threads=0
        ),
        dist_cfg=dict(backend='nccl', timeout=1800),
        resource_limit=4096
    )
......