Unverified commit c6a8d72c authored by Zaida Zhou, committed by GitHub

[Feature] Add distributed module (#59)

* [Feature] Add distributed module

* fix IS_DIST error

* all_reduce_dict does operations in-place

* support 'mean' operation

* provide local group process

* add tmpdir argument for collect_results

* add unit tests

* refactor unit tests

* simplify steps to create multiple processes

* minor fix

* describe the difference between *gather* in mmengine and pytorch

* minor fix

* add unit tests for nccl

* test nccl backend in multiple gpu

* add get_default_group function to handle different torch versions

* minor fix

* minor fix

* handle torch1.5

* handle torch1.5

* minor fix

* fix typo

* refactor unit tests

* nccl does not support gather and gather_object

* fix gather

* fix collect_results_cpu

* fix collect_results and refactor unit tests

* fix collect_results unit tests

* handle torch.cat in torch1.5

* refine docstring

* refine docstring

* fix comments

* fix comments
parent 817eb89a
@@ -10,7 +10,6 @@ __pycache__/
 .Python
 build/
 develop-eggs/
-dist/
 downloads/
 eggs/
 .eggs/
@@ -7,3 +7,8 @@ Data
 --------
 .. automodule:: mmengine.data
    :members:
+
+Distributed
+-----------
+.. automodule:: mmengine.dist
+   :members:
@@ -7,3 +7,8 @@ Data
 --------
 .. automodule:: mmengine.data
    :members:
+
+Distributed
+-----------
+.. automodule:: mmengine.dist
+   :members:
# Copyright (c) OpenMMLab. All rights reserved.
from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict,
collect_results, gather, broadcast, gather_object,
sync_random_seed, broadcast_object_list,
collect_results_cpu, collect_results_gpu)
from .utils import (get_dist_info, init_dist, init_local_group, get_backend,
get_world_size, get_rank, get_local_size, get_local_rank,
is_main_process, master_only, barrier, get_local_group,
is_distributed, get_default_group)
__all__ = [
'all_gather_object', 'all_reduce', 'all_gather', 'all_reduce_dict',
'collect_results', 'collect_results_cpu', 'collect_results_gpu', 'gather',
'broadcast', 'gather_object', 'sync_random_seed', 'broadcast_object_list',
'get_dist_info', 'init_dist', 'init_local_group', 'get_backend',
'get_world_size', 'get_rank', 'get_local_size', 'get_local_group',
'get_local_rank', 'is_main_process', 'master_only', 'barrier',
'is_distributed', 'get_default_group'
]
This diff is collapsed.
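The collapsed file implements the communication ops re-exported above. A minimal usage sketch of the public API (illustrative only, not part of the commit; it assumes the script is started by a distributed launcher that sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, and that one GPU is available per process):

import torch
from mmengine.dist import (all_reduce, get_dist_info, init_dist,
                           is_main_process)

init_dist(launcher='pytorch', backend='nccl')
rank, world_size = get_dist_info()
data = torch.tensor([float(rank)], device='cuda')  # toy per-process value
all_reduce(data, 'mean')  # in-place mean across all processes
if is_main_process():
    print(f'averaged value over {world_size} processes: {data.item()}')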
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import os
import subprocess
from typing import Callable, Optional, Tuple
import torch
import torch.multiprocessing as mp
from torch import distributed as dist
_LOCAL_PROCESS_GROUP = None
def is_distributed() -> bool:
"""Return True if distributed environment has been initialized."""
return dist.is_available() and dist.is_initialized()
def get_local_group() -> Optional[dist.ProcessGroup]:
"""Return local process group."""
if not is_distributed():
return None
if _LOCAL_PROCESS_GROUP is None:
raise RuntimeError('Local process group is not created, please use '
'`init_local_group` to setup local process group.')
return _LOCAL_PROCESS_GROUP
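# Note: ``dist.distributed_c10d._get_default_group`` used below is a private
# torch API. It is needed because old torch versions such as 1.5.0 do not
# accept ``group=None`` in functions like ``dist.get_backend``, so the helpers
# in this file resolve the default group explicitly before passing it on.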
def get_default_group() -> Optional[dist.ProcessGroup]:
"""Return default process group."""
return dist.distributed_c10d._get_default_group()
def init_dist(launcher, backend='nccl', **kwargs) -> None:
"""Initialize distributed environment.
Args:
        launcher (str): Way to launch multiple processes. Supported launchers
            are 'pytorch', 'mpi' and 'slurm'.
        backend (str): Communication backend. Supported backends are 'nccl',
            'gloo' and 'mpi'. Defaults to 'nccl'.
**kwargs: keyword arguments are passed to ``init_process_group``.
"""
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
_init_dist_pytorch(backend, **kwargs)
elif launcher == 'mpi':
_init_dist_mpi(backend, **kwargs)
elif launcher == 'slurm':
_init_dist_slurm(backend, **kwargs)
else:
raise ValueError(f'Invalid launcher type: {launcher}')
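# Illustrative launch with the PyTorch launcher (the training script name is
# hypothetical):
#   python -m torch.distributed.launch --nproc_per_node=4 train.py
# The launcher sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every
# process, so the script itself only needs to call
#   init_dist('pytorch', backend='nccl')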
def _init_dist_pytorch(backend, **kwargs) -> None:
"""Initialize distributed environment with PyTorch launcher.
Args:
backend (str): Backend of torch.distributed. Supported backends are
'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
**kwargs: keyword arguments are passed to ``init_process_group``.
"""
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend, **kwargs) -> None:
"""Initialize distributed environment with MPI launcher.
Args:
backend (str): Backend of torch.distributed. Supported backends are
'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
**kwargs: keyword arguments are passed to ``init_process_group``.
"""
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend, port=None) -> None:
"""Initialize slurm distributed training environment.
    If argument ``port`` is not specified, then the master port will be taken
    from the environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not
    set, the default port ``29500`` will be used.
Args:
backend (str): Backend of torch.distributed.
port (int, optional): Master port. Defaults to None.
TODO: https://github.com/open-mmlab/mmcv/pull/1682
"""
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(
f'scontrol show hostname {node_list} | head -n1')
# specify master port
if port is not None:
os.environ['MASTER_PORT'] = str(port)
elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable
else:
# 29500 is torch.distributed default port
os.environ['MASTER_PORT'] = '29500'
# use MASTER_ADDR in the environment variable if it already exists
if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
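# Illustrative Slurm launch (partition name and training script are
# hypothetical):
#   srun -p ${PARTITION} --ntasks=8 --ntasks-per-node=4 --gres=gpu:4 \
#       python train.py --launcher slurm
# Slurm sets SLURM_PROCID, SLURM_NTASKS and SLURM_NODELIST for every task, and
# ``_init_dist_slurm`` translates them into RANK, WORLD_SIZE and MASTER_ADDR.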
def init_local_group(node_rank: int, num_gpus_per_node: int):
"""Setup the local process group.
Setup a process group which only includes processes that on the same
machine as the current process.
The code is modified from
https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py
Args:
node_rank (int): Rank of machines used for training.
num_gpus_per_node (int): Number of gpus used for training in a single
machine.
""" # noqa: W501
global _LOCAL_PROCESS_GROUP
assert _LOCAL_PROCESS_GROUP is None
ranks = list(
range(node_rank * num_gpus_per_node,
(node_rank + 1) * num_gpus_per_node))
_LOCAL_PROCESS_GROUP = dist.new_group(ranks)
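# Worked example (illustrative): with ``num_gpus_per_node=4``, the local group
# for ``node_rank=1`` is built from global ranks
#   list(range(1 * 4, 2 * 4)) == [4, 5, 6, 7]
# so for the process with global rank 6, ``get_local_rank()`` returns 2 and
# ``get_local_size()`` returns 4.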
def get_backend(group: Optional[dist.ProcessGroup] = None) -> Optional[str]:
"""Return the backend of the given process group.
    Note:
        Calling ``get_backend`` in a non-distributed environment will return
        None.
    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Otherwise, the calling
            process must be part of :attr:`group`. Defaults to None.
    Returns:
        str or None: Return the backend of the given process group as a
            lowercase string if in a distributed environment, otherwise None.
"""
if is_distributed():
# handle low versions of torch like 1.5.0 which does not support
# passing in None for group argument
if group is None:
group = get_default_group()
return dist.get_backend(group)
else:
return None
def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
"""Return the number of the given process group.
Note:
Calling ``get_world_size`` in non-distributed environment will return
1.
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Defaults to None.
Returns:
int: Return the number of processes of the given process group if in
distributed environment, otherwise 1.
"""
if is_distributed():
# handle low versions of torch like 1.5.0 which does not support
# passing in None for group argument
if group is None:
group = get_default_group()
return dist.get_world_size(group)
else:
return 1
def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
"""Return the rank of the given process group.
Rank is a unique identifier assigned to each process within a distributed
process group. They are always consecutive integers ranging from 0 to
``world_size``.
Note:
Calling ``get_rank`` in non-distributed environment will return 0.
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Defaults to None.
Returns:
int: Return the rank of the process group if in distributed
environment, otherwise 0.
"""
if is_distributed():
# handle low versions of torch like 1.5.0 which does not support
# passing in None for group argument
if group is None:
group = get_default_group()
return dist.get_rank(group)
else:
return 0
def get_local_size() -> int:
"""Return the number of the current node.
Returns:
int: Return the number of processes in the current node if in
distributed environment, otherwise 1.
"""
if not is_distributed():
return 1
if _LOCAL_PROCESS_GROUP is None:
raise RuntimeError('Local process group is not created, please use '
'`init_local_group` to setup local process group.')
return dist.get_world_size(_LOCAL_PROCESS_GROUP)
def get_local_rank() -> int:
"""Return the rank of current process in the current node.
Returns:
int: Return the rank of current process in the current node if in
distributed environment, otherwise 0
"""
if not is_distributed():
return 0
if _LOCAL_PROCESS_GROUP is None:
raise RuntimeError('Local process group is not created, please use '
'`init_local_group` to setup local process group.')
return dist.get_rank(_LOCAL_PROCESS_GROUP)
def get_dist_info(
group: Optional[dist.ProcessGroup] = None) -> Tuple[int, int]:
"""Get distributed information of the given process group.
    Note:
        Calling ``get_dist_info`` in a non-distributed environment will
        return (0, 1).
    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.
    Returns:
        tuple[int, int]: Return a tuple containing the ``rank`` and
            ``world_size``.
"""
world_size = get_world_size(group)
rank = get_rank(group)
return rank, world_size
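# Example (illustrative): unpack in the order the tuple is returned, i.e. rank
# first, then world_size:
#   rank, world_size = get_dist_info()
#   if rank == 0:
#       print(f'running with {world_size} processes')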
def is_main_process(group: Optional[dist.ProcessGroup] = None) -> bool:
"""Whether the current rank of the given process group is equal to 0.
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Defaults to None.
Returns:
bool: Return True if the current rank of the given process group is
equal to 0, otherwise False.
"""
return get_rank(group) == 0
def master_only(func: Callable) -> Callable:
"""Decorate those methods which should be executed in master process.
Args:
func (callable): Function to be decorated.
Returns:
callable: Return decorated function.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_main_process():
return func(*args, **kwargs)
return wrapper
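# Illustrative usage (the decorated function is hypothetical): only rank 0
# executes the body; on every other rank the wrapper skips the call and
# returns None.
#   @master_only
#   def save_checkpoint(model, path):
#       torch.save(model.state_dict(), path)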
def barrier(group: Optional[dist.ProcessGroup] = None) -> None:
"""Synchronize all processes from the given process group.
This collective blocks processes until the whole group enters this
function.
Note:
Calling ``barrier`` in non-distributed environment will do nothing.
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Defaults to None.
"""
if is_distributed():
# handle low versions of torch like 1.5.0 which does not support
# passing in None for group argument
if group is None:
group = get_default_group()
dist.barrier(group)
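# Illustrative usage of ``barrier`` (``download_data`` is hypothetical): let
# the main process prepare a shared resource while the other ranks wait.
#   if is_main_process():
#       download_data(data_root)
#   barrier()  # every rank blocks here until all processes have arrived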
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from unittest.mock import patch
import pytest
import torch
import torch.multiprocessing as mp
import mmengine.dist as dist
from mmengine.dist.dist import sync_random_seed
from mmengine.utils import TORCH_VERSION, digit_version
def _test_all_reduce_non_dist():
data = torch.arange(2, dtype=torch.int64)
expected = torch.arange(2, dtype=torch.int64)
dist.all_reduce(data)
assert torch.allclose(data, expected)
def _test_all_gather_non_dist():
data = torch.arange(2, dtype=torch.int64)
expected = torch.arange(2, dtype=torch.int64)
output = dist.all_gather(data)
assert torch.allclose(output[0], expected)
def _test_gather_non_dist():
data = torch.arange(2, dtype=torch.int64)
expected = torch.arange(2, dtype=torch.int64)
output = dist.gather(data)
assert torch.allclose(output[0], expected)
def _test_broadcast_non_dist():
data = torch.arange(2, dtype=torch.int64)
expected = torch.arange(2, dtype=torch.int64)
dist.broadcast(data)
assert torch.allclose(data, expected)
@patch('numpy.random.randint', return_value=10)
def _test_sync_random_seed_no_dist(mock):
assert sync_random_seed() == 10
def _test_broadcast_object_list_no_dist():
with pytest.raises(AssertionError):
# input should be list of object
dist.broadcast_object_list('foo')
data = ['foo', 12, {1: 2}]
expected = ['foo', 12, {1: 2}]
dist.broadcast_object_list(data)
assert data == expected
def _test_all_reduce_dict_no_dist():
with pytest.raises(AssertionError):
# input should be dict
dist.all_reduce_dict('foo')
data = {
'key1': torch.arange(2, dtype=torch.int64),
'key2': torch.arange(3, dtype=torch.int64)
}
expected = {
'key1': torch.arange(2, dtype=torch.int64),
'key2': torch.arange(3, dtype=torch.int64)
}
dist.all_reduce_dict(data)
for key in data:
assert torch.allclose(data[key], expected[key])
def _test_all_gather_object_no_dist():
data = 'foo'
expected = 'foo'
gather_objects = dist.all_gather_object(data)
assert gather_objects[0] == expected
def _test_gather_object_no_dist():
data = 'foo'
expected = 'foo'
gather_objects = dist.gather_object(data)
assert gather_objects[0] == expected
def _test_collect_results_non_dist():
data = ['foo', {1: 2}]
size = 2
expected = ['foo', {1: 2}]
# test `device=cpu`
output = dist.collect_results(data, size, device='cpu')
assert output == expected
    # test `device=gpu`
    output = dist.collect_results(data, size, device='gpu')
assert output == expected
def init_process(rank, world_size, functions, backend='gloo'):
"""Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29505'
os.environ['RANK'] = str(rank)
dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
device = 'cpu' if backend == 'gloo' else 'cuda'
for func in functions:
func(device)
def main(functions, world_size=2, backend='gloo'):
try:
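        # ``mp.spawn`` prepends the process index to ``args``, so
        # ``init_process`` receives it as its first positional argument, i.e.
        # the ``rank``.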
mp.spawn(
init_process,
args=(world_size, functions, backend),
nprocs=world_size)
except Exception:
pytest.fail(f'{backend} failed')
def _test_all_reduce_dist(device):
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
['sum', 'mean']):
if dist.get_rank() == 0:
data = torch.tensor([1, 2], dtype=tensor_type).to(device)
else:
data = torch.tensor([3, 4], dtype=tensor_type).to(device)
if reduce_op == 'sum':
expected = torch.tensor([4, 6], dtype=tensor_type).to(device)
else:
expected = torch.tensor([2, 3], dtype=tensor_type).to(device)
dist.all_reduce(data, reduce_op)
assert torch.allclose(data, expected)
def _test_all_gather_dist(device):
if dist.get_rank() == 0:
data = torch.tensor([0, 1]).to(device)
else:
data = torch.tensor([1, 2]).to(device)
expected = [
torch.tensor([0, 1]).to(device),
torch.tensor([1, 2]).to(device)
]
output = dist.all_gather(data)
assert torch.allclose(output[dist.get_rank()], expected[dist.get_rank()])
def _test_gather_dist(device):
if dist.get_rank() == 0:
data = torch.tensor([0, 1]).to(device)
else:
data = torch.tensor([1, 2]).to(device)
output = dist.gather(data)
if dist.get_rank() == 0:
expected = [
torch.tensor([0, 1]).to(device),
torch.tensor([1, 2]).to(device)
]
for i in range(2):
assert torch.allclose(output[i], expected[i])
else:
assert output == []
def _test_broadcast_dist(device):
if dist.get_rank() == 0:
data = torch.tensor([0, 1]).to(device)
else:
data = torch.tensor([1, 2]).to(device)
expected = torch.tensor([0, 1]).to(device)
dist.broadcast(data, 0)
assert torch.allclose(data, expected)
def _test_sync_random_seed_dist(device):
with patch.object(
torch, 'tensor',
return_value=torch.tensor(1024).to(device)) as mock_tensor:
output = dist.sync_random_seed()
assert output == 1024
mock_tensor.assert_called()
def _test_broadcast_object_list_dist(device):
if dist.get_rank() == 0:
data = ['foo', 12, {1: 2}]
else:
data = [None, None, None]
expected = ['foo', 12, {1: 2}]
dist.broadcast_object_list(data)
assert data == expected
def _test_all_reduce_dict_dist(device):
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
['sum', 'mean']):
if dist.get_rank() == 0:
data = {
'key1': torch.tensor([0, 1], dtype=tensor_type).to(device),
'key2': torch.tensor([1, 2], dtype=tensor_type).to(device)
}
else:
data = {
'key1': torch.tensor([2, 3], dtype=tensor_type).to(device),
'key2': torch.tensor([3, 4], dtype=tensor_type).to(device)
}
if reduce_op == 'sum':
expected = {
'key1': torch.tensor([2, 4], dtype=tensor_type).to(device),
'key2': torch.tensor([4, 6], dtype=tensor_type).to(device)
}
else:
expected = {
'key1': torch.tensor([1, 2], dtype=tensor_type).to(device),
'key2': torch.tensor([2, 3], dtype=tensor_type).to(device)
}
dist.all_reduce_dict(data, reduce_op)
for key in data:
assert torch.allclose(data[key], expected[key])
    # `torch.cat` in torch1.5 cannot concatenate tensors of different dtypes,
    # so we fall back to converting them all to float.
if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
if dist.get_rank() == 0:
data = {
'key1': torch.tensor([0, 1], dtype=torch.float32).to(device),
'key2': torch.tensor([1, 2], dtype=torch.int32).to(device)
}
else:
data = {
'key1': torch.tensor([2, 3], dtype=torch.float32).to(device),
'key2': torch.tensor([3, 4], dtype=torch.int32).to(device)
}
expected = {
'key1': torch.tensor([2, 4], dtype=torch.float32).to(device),
'key2': torch.tensor([4, 6], dtype=torch.float32).to(device)
}
dist.all_reduce_dict(data, 'sum')
for key in data:
assert torch.allclose(data[key], expected[key])
def _test_all_gather_object_dist(device):
if dist.get_rank() == 0:
data = 'foo'
else:
data = {1: 2}
expected = ['foo', {1: 2}]
output = dist.all_gather_object(data)
assert output == expected
def _test_gather_object_dist(device):
if dist.get_rank() == 0:
data = 'foo'
else:
data = {1: 2}
output = dist.gather_object(data, dst=0)
if dist.get_rank() == 0:
assert output == ['foo', {1: 2}]
else:
assert output is None
def _test_collect_results_dist(device):
if dist.get_rank() == 0:
data = ['foo', {1: 2}]
else:
data = [24, {'a': 'b'}]
size = 4
expected = ['foo', 24, {1: 2}, {'a': 'b'}]
# test `device=cpu`
output = dist.collect_results(data, size, device='cpu')
if dist.get_rank() == 0:
assert output == expected
else:
assert output is None
# test `device=cpu` and `tmpdir is not None`
tmpdir = tempfile.mkdtemp()
# broadcast tmpdir to all ranks to make it consistent
object_list = [tmpdir]
dist.broadcast_object_list(object_list)
output = dist.collect_results(
data, size, device='cpu', tmpdir=object_list[0])
if dist.get_rank() == 0:
assert output == expected
else:
assert output is None
if dist.get_rank() == 0:
# object_list[0] will be removed by `dist.collect_results`
assert not osp.exists(object_list[0])
# test `device=gpu`
output = dist.collect_results(data, size, device='gpu')
if dist.get_rank() == 0:
assert output == expected
else:
assert output is None
def test_non_distributed_env():
_test_all_reduce_non_dist()
_test_all_gather_non_dist()
_test_gather_non_dist()
_test_broadcast_non_dist()
_test_sync_random_seed_no_dist()
_test_broadcast_object_list_no_dist()
_test_all_reduce_dict_no_dist()
_test_all_gather_object_no_dist()
_test_gather_object_no_dist()
_test_collect_results_non_dist()
def test_gloo_backend():
functions_to_test = [
_test_all_reduce_dist,
_test_all_gather_dist,
_test_gather_dist,
_test_broadcast_dist,
_test_sync_random_seed_dist,
_test_broadcast_object_list_dist,
_test_all_reduce_dict_dist,
_test_all_gather_object_dist,
_test_gather_object_dist,
]
main(functions_to_test, backend='gloo')
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason='need 2 GPUs to test nccl')
def test_nccl_backend():
functions_to_test = [
_test_all_reduce_dist,
_test_all_gather_dist,
_test_broadcast_dist,
_test_sync_random_seed_dist,
_test_broadcast_object_list_dist,
_test_all_reduce_dict_dist,
_test_all_gather_object_dist,
_test_collect_results_dist,
]
main(functions_to_test, backend='nccl')
# Copyright (c) OpenMMLab. All rights reserved.
import os
import pytest
import torch
import torch.distributed as torch_dist
import torch.multiprocessing as mp
import mmengine.dist as dist
def _test_get_backend_non_dist():
assert dist.get_backend() is None
def _test_get_world_size_non_dist():
assert dist.get_world_size() == 1
def _test_get_rank_non_dist():
assert dist.get_rank() == 0
def _test_local_size_non_dist():
assert dist.get_local_size() == 1
def _test_local_rank_non_dist():
assert dist.get_local_rank() == 0
def _test_get_dist_info_non_dist():
assert dist.get_dist_info() == (0, 1)
def _test_is_main_process_non_dist():
assert dist.is_main_process()
def _test_master_only_non_dist():
@dist.master_only
def fun():
assert dist.get_rank() == 0
fun()
def _test_barrier_non_dist():
dist.barrier() # nothing is done
def init_process(rank, world_size, functions, backend='gloo'):
"""Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29501'
os.environ['RANK'] = str(rank)
dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
dist.init_local_group(0, world_size)
for func in functions:
func()
def main(functions, world_size=2, backend='gloo'):
try:
mp.spawn(
init_process,
args=(world_size, functions, backend),
nprocs=world_size)
except Exception:
pytest.fail('error')
def _test_get_backend_dist():
assert dist.get_backend() == torch_dist.get_backend()
def _test_get_world_size_dist():
assert dist.get_world_size() == 2
def _test_get_rank_dist():
if torch_dist.get_rank() == 0:
assert dist.get_rank() == 0
else:
assert dist.get_rank() == 1
def _test_local_size_dist():
assert dist.get_local_size() == 2
def _test_local_rank_dist():
    assert (torch_dist.get_rank(dist.get_local_group()) ==
            dist.get_local_rank())
def _test_get_dist_info_dist():
if dist.get_rank() == 0:
assert dist.get_dist_info() == (0, 2)
else:
assert dist.get_dist_info() == (1, 2)
def _test_is_main_process_dist():
if dist.get_rank() == 0:
assert dist.is_main_process()
else:
assert not dist.is_main_process()
def _test_master_only_dist():
@dist.master_only
def fun():
assert dist.get_rank() == 0
fun()
def test_non_distributed_env():
_test_get_backend_non_dist()
_test_get_world_size_non_dist()
_test_get_rank_non_dist()
_test_local_size_non_dist()
_test_local_rank_non_dist()
_test_get_dist_info_non_dist()
_test_is_main_process_non_dist()
_test_master_only_non_dist()
_test_barrier_non_dist()
functions_to_test = [
_test_get_backend_dist,
_test_get_world_size_dist,
_test_get_rank_dist,
_test_local_size_dist,
_test_local_rank_dist,
_test_get_dist_info_dist,
_test_is_main_process_dist,
_test_master_only_dist,
]
def test_gloo_backend():
main(functions_to_test)
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason='need 2 GPUs to test nccl')
def test_nccl_backend():
main(functions_to_test, backend='nccl')