From c6a8d72c5e5d46bef61f8b07b196d0834218ab89 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Sat, 5 Mar 2022 22:03:32 +0800
Subject: [PATCH] [Feature] Add distributed module (#59)

* [Feature] Add distributed module
* fix IS_DIST error
* all_reduce_dict does operations in-place
* support 'mean' operation
* provide local process group
* add tmpdir argument for collect_results
* add unit tests
* refactor unit tests
* simplify steps to create multiple processes
* minor fix
* describe the difference between *gather* in mmengine and pytorch
* minor fix
* add unit tests for nccl
* test nccl backend on multiple gpus
* add get_default_group function to handle different torch versions
* minor fix
* [Feature] Add distributed module
* fix IS_DIST error
* all_reduce_dict does operations in-place
* support 'mean' operation
* provide local process group
* add tmpdir argument for collect_results
* add unit tests
* refactor unit tests
* simplify steps to create multiple processes
* minor fix
* describe the difference between *gather* in mmengine and pytorch
* minor fix
* add unit tests for nccl
* test nccl backend on multiple gpus
* add get_default_group function to handle different torch versions
* minor fix
* minor fix
* handle torch1.5
* handle torch1.5
* minor fix
* fix typo
* refactor unit tests
* nccl does not support gather and gather_object
* fix gather
* fix collect_results_cpu
* fix collect_results and refactor unit tests
* fix collect_results unit tests
* handle torch.cat in torch1.5
* refine docstring
* refine docstring
* fix comments
* fix comments
---
 .gitignore                    |    1 -
 docs/en/api.rst               |    5 +
 docs/zh_cn/api.rst            |    5 +
 mmengine/dist/__init__.py     |   19 +
 mmengine/dist/dist.py         | 1023 +++++++++++++++++++++++++++++++++
 mmengine/dist/utils.py        |  335 +++++++++++
 tests/test_dist/test_dist.py  |  376 ++++++++++++
 tests/test_dist/test_utils.py |  152 +++++
 8 files changed, 1915 insertions(+), 1 deletion(-)
 create mode 100644 mmengine/dist/__init__.py
 create mode 100644 mmengine/dist/dist.py
 create mode 100644 mmengine/dist/utils.py
 create mode 100644 tests/test_dist/test_dist.py
 create mode 100644 tests/test_dist/test_utils.py

diff --git a/.gitignore b/.gitignore
index 00379477..5e8df67f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,6 @@ __pycache__/
 .Python
 build/
 develop-eggs/
-dist/
 downloads/
 eggs/
 .eggs/
diff --git a/docs/en/api.rst b/docs/en/api.rst
index fee9eea1..a6f3c0ac 100644
--- a/docs/en/api.rst
+++ b/docs/en/api.rst
@@ -7,3 +7,8 @@ Data
 --------
 .. automodule:: mmengine.data
    :members:
+
+Distributed
+-----------
+.. automodule:: mmengine.dist
+   :members:
diff --git a/docs/zh_cn/api.rst b/docs/zh_cn/api.rst
index fee9eea1..a6f3c0ac 100644
--- a/docs/zh_cn/api.rst
+++ b/docs/zh_cn/api.rst
@@ -7,3 +7,8 @@ Data
 --------
 .. automodule:: mmengine.data
    :members:
+
+Distributed
+-----------
+.. automodule:: mmengine.dist
+   :members:
diff --git a/mmengine/dist/__init__.py b/mmengine/dist/__init__.py
new file mode 100644
index 00000000..d4dfe710
--- /dev/null
+++ b/mmengine/dist/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict, + collect_results, gather, broadcast, gather_object, + sync_random_seed, broadcast_object_list, + collect_results_cpu, collect_results_gpu) +from .utils import (get_dist_info, init_dist, init_local_group, get_backend, + get_world_size, get_rank, get_local_size, get_local_rank, + is_main_process, master_only, barrier, get_local_group, + is_distributed, get_default_group) + +__all__ = [ + 'all_gather_object', 'all_reduce', 'all_gather', 'all_reduce_dict', + 'collect_results', 'collect_results_cpu', 'collect_results_gpu', 'gather', + 'broadcast', 'gather_object', 'sync_random_seed', 'broadcast_object_list', + 'get_dist_info', 'init_dist', 'init_local_group', 'get_backend', + 'get_world_size', 'get_rank', 'get_local_size', 'get_local_group', + 'get_local_rank', 'is_main_process', 'master_only', 'barrier', + 'is_distributed', 'get_default_group' +] diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py new file mode 100644 index 00000000..6569d901 --- /dev/null +++ b/mmengine/dist/dist.py @@ -0,0 +1,1023 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, List, Optional, Tuple, Dict +import shutil +import pickle +import numpy as np +import tempfile +import torch +import os.path as osp +from torch import Tensor +from torch import distributed as dist + +import mmengine +from .utils import (get_world_size, get_rank, get_backend, get_dist_info, + get_default_group) +from mmengine.utils import digit_version, TORCH_VERSION + + +def _get_reduce_op(name: str) -> dist.ReduceOp: + op_mappings = { + 'sum': dist.ReduceOp.SUM, + 'product': dist.ReduceOp.PRODUCT, + 'min': dist.ReduceOp.MIN, + 'max': dist.ReduceOp.MAX, + 'band': dist.ReduceOp.BAND, + 'bor': dist.ReduceOp.BOR, + 'bxor': dist.ReduceOp.BXOR, + } + + if name.lower() not in op_mappings: + raise ValueError( + f'reduce op should be one of {op_mappings.keys()}, bug got {name}') + + return op_mappings[name.lower()] + + +def all_reduce(data: Tensor, + op: str = 'sum', + group: Optional[dist.ProcessGroup] = None) -> None: + """Reduces the tensor data across all machines in such a way that all get + the final result. + + After the call ``data`` is going to be bitwise identical in all + processes. + + Note: + Calling ``all_reduce`` in non-distributed environment does nothing. + + Args: + data (Tensor): Input and output of the collective. The function + operates in-place. + op (str): Operation to reduce data. Defaults to 'sum'. Optional values + are 'sum', 'mean' and 'produce', 'min', 'max', 'band', 'bor' and + 'bxor'. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> data = torch.arange(2, dtype=torch.int64) + >>> dist.all_reduce(data) + >>> data + tensor([0, 1]) + + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> data + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> dist.all_reduce(data, op=torch.dist.ReduceOp.SUM) + >>> data + tensor([4, 6]) # Rank 0 + tensor([4, 6]) # Rank 1 + """ + world_size = get_world_size(group) + if world_size > 1: + if group is None: + group = get_default_group() + + # pytorch does not support 'mean' operation so we fall back to support + # it with 'sum' operation. 
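# A worked sketch of the 'mean' fallback implemented just below; the two rank
# values and the world size of 2 are made up for illustration. Summing every
# rank's tensor and dividing in place by the world size reproduces the mean.
import torch
rank0_data, rank1_data = torch.tensor([1., 2.]), torch.tensor([3., 4.])
summed = rank0_data + rank1_data   # what all_reduce(data, 'sum') leaves on every rank
summed.div_(2)                     # in-place division by world_size, as in the branch below
assert torch.allclose(summed, torch.tensor([2., 3.]))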
+ if op.lower() == 'mean': + dist.all_reduce(data, _get_reduce_op('sum'), group) + data.div_(world_size) + else: + dist.all_reduce(data, _get_reduce_op(op), group) + + +def all_gather(data: Tensor, + group: Optional[dist.ProcessGroup] = None) -> List[Tensor]: + """Gather data from the whole group in a list. + + Note: + Calling ``all_gather`` in non-distributed environment does nothing + and just returns a list containing :attr:`data` itself. + + Note: + Unlike PyTorch ``torch.distributed.all_gather``, :meth:`all_gather` in + MMEngine does not pass in an empty list ``gather_list`` and returns + the ``gather_list`` directly, which is more convenient. The difference + between their interfaces is as below: + + - MMEngine: all_gather(data, group) -> gather_list + - PyTorch: all_gather(gather_list, data, group) -> None + + Args: + data (Tensor): Tensor to be gathered. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + list[Tensor]: Return a list containing data from the whole group if + in distributed environment, otherwise a list only containing + :attr:`data` itself. + + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> data = torch.arange(2, dtype=torch.int64) + >>> data + tensor([0, 1]) + >>> output = dist.all_gather(data) + >>> output + [tensor([0, 1])] + + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> data + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> output = dist.all_gather(data) + >>> output + [tensor([1, 2]), tensor([3, 4])] # Rank 0 + [tensor([1, 2]), tensor([3, 4])] # Rank 1 + """ + world_size = get_world_size(group) + if world_size == 1: + return [data] + + if group is None: + group = get_default_group() + + gather_list = [torch.empty_like(data) for _ in range(world_size)] + dist.all_gather(gather_list, data, group) + return gather_list + + +def gather( + data: Tensor, + dst: int = 0, + group: Optional[dist.ProcessGroup] = None) -> List[Optional[Tensor]]: + """Gather data from the whole group to ``dst`` process. + + Note: + Calling ``gather`` in non-distributed environment dose nothing + and just returns a list containing :attr:`data` itself. + + Note: + ``NCCL`` backend does not support ``gather``. + + Note: + Unlike PyTorch ``torch.distributed.gather``, :meth:`gather` in + MMEngine does not pass in an empty list ``gather_list`` and returns + the ``gather_list`` directly, which is more convenient. The difference + between their interfaces is as below: + + - MMEngine: gather(data, dst, group) -> gather_list + - PyTorch: gather(data, gather_list, dst, group) -> None + + Args: + data (Tensor): Tensor to be gathered. CUDA tensor is not supported. + dst (int): Destination rank. Defaults to 0. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + list[Tensor]: ``dst`` process will get a list of tensor gathering from + the whole group. Other process will get a empty list. If in + non-distributed environment, just return a list containing + :attr:`data` itself. 
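# A sketch of the interface difference called out in the notes above, assuming
# an already initialized 2-process gloo group (so it is not runnable on its
# own): PyTorch fills a caller-allocated list, the MMEngine wrapper allocates
# and returns it.
import torch
import torch.distributed as torch_dist
import mmengine.dist as dist

data = torch.tensor([1, 2]) + 2 * dist.get_rank()

# PyTorch style: the caller pre-allocates gather_list.
gather_list = [torch.empty_like(data) for _ in range(torch_dist.get_world_size())]
torch_dist.all_gather(gather_list, data)

# MMEngine style: the populated list is simply returned.
gather_list = dist.all_gather(data)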
+ + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> data = torch.arange(2, dtype=torch.int64) + >>> data + tensor([0, 1]) + >>> output = dist.gather(data) + >>> output + [tensor([0, 1])] + + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> data + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> output = dist.gather(data) + >>> output + [tensor([1, 2]), tensor([3, 4])] # Rank 0 + [] # Rank 1 + """ + world_size = get_world_size(group) + if world_size == 1: + return [data] + + if group is None: + group = get_default_group() + + if get_rank(group) == dst: + gather_list = [torch.empty_like(data) for _ in range(world_size)] + else: + gather_list = [] + + dist.gather(data, gather_list, dst, group) + return gather_list + + +def broadcast(data: Tensor, + src: int = 0, + group: Optional[dist.ProcessGroup] = None) -> None: + """Broadcast the data from ``src`` process to the whole group. + + ``data`` must have the same number of elements in all processes + participating in the collective. + + Note: + Calling ``broadcast`` in non-distributed environment does nothing. + + Args: + data (Tensor): Data to be sent if ``src`` is the rank of current + process, and data to be used to save received data otherwise. + src (int): Source rank. Defaults to 0. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> data = torch.arange(2, dtype=torch.int64) + >>> data + tensor([0, 1]) + >>> dist.broadcast(data) + >>> data + tensor([0, 1]) + + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> data = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> data + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> dist.broadcast(data) + >>> data + tensor([1, 2]) # Rank 0 + tensor([1, 2]) # Rank 1 + """ + if get_world_size(group) > 1: + if group is None: + group = get_default_group() + + dist.broadcast(data, src, group) + + +def sync_random_seed(group: Optional[dist.ProcessGroup] = None) -> int: + """Synchronize a random seed to all processes. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + int: Random seed. + + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> seed = dist.sync_random_seed() + >>> seed # which a random number + 587791752 + + >>> distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> seed = dist.sync_random_seed() + >>> seed + 587791752 # Rank 0 + 587791752 # Rank 1 + """ + seed = np.random.randint(2**31) + if get_world_size(group) == 1: + return seed + + if group is None: + group = get_default_group() + + if get_rank(group) == 0: + random_num = torch.tensor(seed, dtype=torch.int32) + else: + random_num = torch.tensor(0, dtype=torch.int32) + + dist.broadcast(random_num, src=0, group=group) + + return random_num.item() + + +def _object_to_tensor(obj: Any) -> Tuple[Tensor, Tensor]: + """Serialize picklable python object to tensor.""" + byte_storage = torch.ByteStorage.from_buffer(pickle.dumps(obj)) + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor + # and specifying dtype. Otherwise, it will cause 100X slowdown. 
+ # See: https://github.com/pytorch/pytorch/issues/65696 + byte_tensor = torch.ByteTensor(byte_storage) + local_size = torch.LongTensor([byte_tensor.numel()]) + return byte_tensor, local_size + + +def _tensor_to_object(tensor: Tensor, tensor_size: int) -> Any: + """Deserialize tensor to picklable python object.""" + buf = tensor.cpu().numpy().tobytes()[:tensor_size] + return pickle.loads(buf) + + +def _broadcast_object_list(object_list: List[Any], + src: int = 0, + group: Optional[dist.ProcessGroup] = None) -> None: + """Broadcast picklable objects in ``object_list`` to the whole group. + + Similar to :func:`broadcast`, but Python objects can be passed in. Note + that all objects in ``object_list`` must be picklable in order to be + broadcasted. + """ + if dist.distributed_c10d._rank_not_in_group(group): + return + + my_rank = get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + tensor_list, size_list = zip( + *[_object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) + + # Current device selection. + # To preserve backwards compatibility, ``device`` is ``None`` by default. + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In + # the case it is not ``None`` we move the size and object tensors to be + # broadcasted to this device. + group_backend = get_backend(group) + is_nccl_backend = group_backend == dist.Backend.NCCL + current_device = torch.device('cpu') + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in + # docstring. We cannot simply use my_rank since rank == device is + # not necessarily true. + current_device = torch.device('cuda', torch.cuda.current_device()) + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + dist.broadcast(object_sizes_tensor, src=src, group=group) + + # Concatenate and broadcast serialized object tensors + if my_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.empty( + torch.sum(object_sizes_tensor).int().item(), + dtype=torch.uint8, + ) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + dist.broadcast(object_tensor, src=src, group=group) + # Deserialize objects using their stored sizes. + offset = 0 + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.uint8) + if obj_view.device != torch.device('cpu'): + obj_view = obj_view.cpu() + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size) + + +def broadcast_object_list(data: List[Any], + src: int = 0, + group: Optional[dist.ProcessGroup] = None) -> None: + """Broadcasts picklable objects in ``object_list`` to the whole group. + Similar to :func:`broadcast`, but Python objects can be passed in. Note + that all objects in ``object_list`` must be picklable in order to be + broadcasted. + + Note: + Calling ``broadcast_object_list`` in non-distributed environment does + nothing. + + Args: + data (List[Any]): List of input objects to broadcast. + Each object must be picklable. Only objects on the ``src`` rank + will be broadcast, but each rank must provide lists of equal sizes. + src (int): Source rank from which to broadcast ``object_list``. + group: (ProcessGroup, optional): The process group to work on. 
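# A local round-trip sketch for the two serialization helpers above; the
# sample object is arbitrary. Any picklable object becomes a uint8 byte tensor
# plus its length, and is recovered from exactly that many bytes.
from mmengine.dist.dist import _object_to_tensor, _tensor_to_object

obj = {'foo': [1, 2, 3], 'bar': 'baz'}
byte_tensor, local_size = _object_to_tensor(obj)
assert _tensor_to_object(byte_tensor, local_size.item()) == obj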
If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before broadcasting. Default is ``None``. + + Note: + For NCCL-based process groups, internal tensor representations of + objects must be moved to the GPU device before communication starts. + In this case, the used device is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is correctly set so that each rank has an individual + GPU, via ``torch.cuda.set_device()``. + + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> data = ['foo', 12, {1: 2}] + >>> dist.broadcast_object_list(data) + >>> data + ['foo', 12, {1: 2}] + + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> data = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> data = [None, None, None] + >>> dist.broadcast_object_list(data) + >>> data + ["foo", 12, {1: 2}] # Rank 0 + ["foo", 12, {1: 2}] # Rank 1 + """ + assert isinstance(data, list) + + if get_world_size(group) > 1: + if group is None: + group = get_default_group() + + if digit_version(TORCH_VERSION) >= digit_version('1.8.0'): + dist.broadcast_object_list(data, src, group) + else: + _broadcast_object_list(data, src, group) + + +def all_reduce_dict(data: Dict[str, Tensor], + op: str = 'sum', + group: Optional[dist.ProcessGroup] = None) -> None: + """Reduces the dict across all machines in such a way that all get the + final result. + + The code is modified from https://github.com/Megvii- + BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py. + + Args: + data (dict[str, Tensor]): Data to be reduced. + op (str): Operation to reduce data. Defaults to 'sum'. Optional values + are 'sum', 'mean' and 'produce', 'min', 'max', 'band', 'bor' and + 'bxor'. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> data = { + 'key1': torch.arange(2, dtype=torch.int64), + 'key2': torch.arange(3, dtype=torch.int64) + } + >>> dist.all_reduce_dict(data) + >>> data + {'key1': tensor([0, 1]), 'key2': tensor([0, 1, 2])} + + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> data = { + 'key1': torch.arange(2, dtype=torch.int64), + 'key2': torch.arange(3, dtype=torch.int64) + } + >>> dist.all_reduce_dict(data) + >>> data + {'key1': tensor([0, 2]), 'key2': tensor([0, 2, 4])} # Rank 0 + {'key1': tensor([0, 2]), 'key2': tensor([0, 2, 4])} # Rank 1 + """ + assert isinstance(data, dict) + + world_size = get_world_size(group) + if world_size > 1: + + if group is None: + group = get_default_group() + + # ensure keys are consistent across processes + keys = sorted(data.keys()) + tensor_shapes = [data[k].shape for k in keys] + tensor_sizes = [data[k].numel() for k in keys] + + if digit_version(TORCH_VERSION) == digit_version('1.5.0'): + # `torch.cat` in torch1.5 can not concatenate different types so + # we fallback to convert them all to float type. 
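# A local sketch of the flatten/split bookkeeping performed below; the dict
# contents are made up. Concatenating the flattened values, then splitting by
# numel and reshaping, reconstructs every tensor exactly, which is what lets a
# single all_reduce cover the whole dict.
import torch
sample = {'key1': torch.arange(2), 'key2': torch.arange(6).reshape(2, 3)}
sample_keys = sorted(sample.keys())
shapes = [sample[k].shape for k in sample_keys]
sizes = [sample[k].numel() for k in sample_keys]
flat = torch.cat([sample[k].flatten() for k in sample_keys])
restored = [x.reshape(shape) for x, shape in zip(torch.split(flat, sizes), shapes)]
assert all(torch.equal(r, sample[k]) for r, k in zip(restored, sample_keys))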
+ flatten_tensor = torch.cat( + [data[k].flatten().float() for k in keys]) + else: + flatten_tensor = torch.cat([data[k].flatten() for k in keys]) + + all_reduce(flatten_tensor, op=op, group=group) + + split_tensors = [ + x.reshape(shape) for x, shape in zip( + torch.split(flatten_tensor, tensor_sizes), tensor_shapes) + ] + + for k, v in zip(keys, split_tensors): + data[k] = v + + +def _all_gather_object(object_list: List[Any], + obj: Any, + group: Optional[dist.ProcessGroup] = None) -> None: + """Gather picklable objects from the whole group into a list. + + Similar to :func:`all_gather`, but Python objects can be passed in. + Note that the object must be picklable in order to be gathered. + + Args: + object_list (list[Any]): Output list. It should be correctly sized as + the size of the group for this collective and will contain the + output. + object (Any): Pickable Python object to be broadcast from current + process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + None. If the calling rank is part of this group, the output of the + collective will be populated into the input ``object_list``. If the + calling rank is not part of the group, the passed in ``object_list`` + will be unmodified. + """ + if dist.distributed_c10d._rank_not_in_group(group): + return + + input_tensor, local_size = _object_to_tensor(obj) + group_backend = get_backend(group) + current_device = torch.device('cpu') + is_nccl_backend = group_backend == dist.Backend.NCCL + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device('cuda', torch.cuda.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) + # Gather all local sizes. This is so that we can find the max size, and + # index until the correct size when deserializing the tensors. + group_size = get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes + dist.all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i:max_object_size * (i + 1)] + for i in range(group_size) + ] + dist.all_gather(output_tensors, input_tensor, group=group) + # Deserialize outputs back to object. + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + if tensor.device != torch.device('cpu'): + tensor = tensor.cpu() + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size) + + +def all_gather_object(data: Any, + group: Optional[dist.ProcessGroup] = None) -> List[Any]: + """Gather picklable objects from the whole group into a list. Similar to + :func:`all_gather`, but Python objects can be passed in. Note that the + object must be picklable in order to be gathered. 
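# A local sketch of the size bookkeeping used in _all_gather_object above; the
# two byte strings stand in for the serialized objects of two ranks. Every
# payload is padded to the maximum length before the collective and sliced
# back to its own recorded size afterwards.
import torch
parts = [b'short', b'a much longer payload']
sizes = [len(p) for p in parts]
max_size = max(sizes)
padded = [torch.zeros(max_size, dtype=torch.uint8) for _ in parts]
for buf, part in zip(padded, parts):
    buf[:len(part)] = torch.tensor(list(part), dtype=torch.uint8)
recovered = [bytes(buf[:size].tolist()) for buf, size in zip(padded, sizes)]
assert recovered == parts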
+ + Note: + Calling ``all_gather_object`` in non-distributed environment does + nothing and just returns a list containing :attr:`data` itself. + + Note: + Unlike PyTorch ``torch.distributed.all_gather_object``, + :meth:`all_gather_object` in MMEngine does not pass in an empty list + ``gather_list`` and returns the ``gather_list`` directly, which is + more convenient. The difference between their interfaces is as below: + + - MMEngine: all_gather_object(data, group) -> gather_list + - PyTorch: all_gather_object(gather_list, data, group) -> None + + Args: + data (Any): Pickable Python object to be broadcast from current + process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + list[Tensor]: Return a list containing data from the whole group if + in distributed environment, otherwise a list only containing + :attr:`data` itself. + + Note: + For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication starts. + In this case, the used device is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is correctly set so that each rank has an individual + GPU, via ``torch.cuda.set_device()``. + + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> data = ['foo', 12, {1: 2}] # any picklable object + >>> gather_objects = dist.all_gather_object(data[dist.get_rank()]) + >>> output + ['foo'] + + >>> # distributed environment + >>> # We have 3 process groups, 3 ranks. + >>> output = dist.all_gather_object(data[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] # Rank 0 + ['foo', 12, {1: 2}] # Rank 1 + ['foo', 12, {1: 2}] # Rank 2 + """ + world_size = get_world_size(group) + if world_size == 1: + return [data] + + if group is None: + group = get_default_group() + + gather_list = [None] * world_size + + if digit_version(TORCH_VERSION) >= digit_version('1.8.0'): + dist.all_gather_object(gather_list, data, group) + else: + _all_gather_object(gather_list, data, group) + + return gather_list + + +def _validate_output_list_for_rank(my_rank: int, dst: int, + gather_list: Optional[list]) -> None: + """Validate whether ``gather_list`` is None in non-dst ranks.""" + if dst == my_rank: + if not gather_list: + raise ValueError( + 'Argument ``gather_list`` must be specified on destination ' + 'rank.') + elif gather_list: + raise ValueError('Argument ``gather_list`` must NOT be specified ' + 'on non-destination ranks.') + + +def _gather_object(obj: Any, + object_gather_list=None, + dst: int = 0, + group: Optional[dist.ProcessGroup] = None) -> None: + """Gathers picklable objects from the whole group in a single process. + + Similar to :func:`gather`, but Python objects can be passed in. Note that + the object must be picklable in order to be gathered. + + Args: + obj (Any): Input object. Must be picklable. + object_gather_list (list[Any], optional): Output list. On the ``dst`` + rank, it should be correctly sized as the size of the group for + this collective and will contain the output. Must be ``None`` on + non-dst ranks. Defaults to None. + dst (int): Destination rank. Defaults to 0. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + """ + if dist.distributed_c10d._rank_not_in_group(group): + return + + # Ensure object_gather_list is specified appopriately. 
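# A quick sketch of the rule enforced by _validate_output_list_for_rank above;
# the rank numbers are hypothetical. Only the destination rank may pass a
# gather list, every other rank must pass None.
from mmengine.dist.dist import _validate_output_list_for_rank

_validate_output_list_for_rank(my_rank=0, dst=0, gather_list=[None, None])  # ok
_validate_output_list_for_rank(my_rank=1, dst=0, gather_list=None)          # ok
# Passing a list on a non-destination rank raises ValueError:
# _validate_output_list_for_rank(my_rank=1, dst=0, gather_list=[None, None])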
+ my_rank = get_rank() + _validate_output_list_for_rank(my_rank, dst, object_gather_list) + input_tensor, local_size = _object_to_tensor(obj) + group_backend = get_backend(group) + current_device = torch.device('cpu') + is_nccl_backend = group_backend == dist.Backend.NCCL + if is_nccl_backend: + current_device = torch.device('cuda', torch.cuda.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) + # Gather all local sizes. This is so that we can find the max size, and + # index until the correct size when deserializing the tensors. + group_size = get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. + dist.all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + # Avoid populating output tensors if the result won't be gathered on this + # rank. + if my_rank == dst: + coalesced_output_tensor = torch.empty( + max_object_size * group_size, + dtype=torch.uint8, + device=current_device) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i:max_object_size * + (i + 1)] for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. + dist.gather( + input_tensor, + gather_list=output_tensors if my_rank == dst else None, + dst=dst, + group=group, + ) + if my_rank != dst: + return + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size) + + +def gather_object( + data: Any, + dst: int = 0, + group: Optional[dist.ProcessGroup] = None) -> Optional[List[Any]]: + """Gathers picklable objects from the whole group in a single process. + Similar to :func:`gather`, but Python objects can be passed in. Note that + the object must be picklable in order to be gathered. + + Note: + ``NCCL backend`` dost not support ``gather_object``. + + Note: + Unlike PyTorch ``torch.distributed.gather_object``, + :meth:`gather_object` in MMEngine does not pass in an empty list + ``gather_list`` and returns the ``gather_list`` directly, which is + more convenient. The difference between their interfaces is as below: + + - MMEngine: gather_object(data, dst, group) -> gather_list + - PyTorch: gather_object(data, gather_list, data, group) -> None + + Args: + obj (Any): Input object. Must be picklable. + dst (int): Destination rank. Defaults to 0. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + list[Any]. On the ``dst`` rank, return ``gather_list`` which contains + the output of the collective. + + Examples: + >>> import torch + >>> import mmengine.dist as dist + + >>> # non-distributed environment + >>> data = ['foo', 12, {1: 2}] # any picklable object + >>> gather_objects = dist.gather_object(data[dist.get_rank()]) + >>> output + ['foo'] + + >>> # distributed environment + >>> # We have 3 process groups, 3 ranks. 
+ >>> dist.gather_object(gather_objects[dist.get_rank()], dst=0) + >>> output + ['foo', 12, {1: 2}] # Rank 0 + None # Rank 1 + None # Rank 2 + """ + world_size = get_world_size(group) + if world_size == 1: + return [data] + + if group is None: + group = get_default_group() + + gather_list = [None] * world_size if get_rank(group) == dst else None + + if digit_version(TORCH_VERSION) >= digit_version('1.8.0'): + dist.gather_object(data, gather_list, dst, group) + else: + _gather_object(data, gather_list, dst, group) + + return gather_list + + +def collect_results(results: list, + size: int, + device: str = 'cpu', + tmpdir: Optional[str] = None) -> Optional[list]: + """Collected results in distributed environments. + + Args: + results (list[object]): Result list containing result parts to be + collected. Each item of ``result_part`` should be a picklable + object. + size (int): Size of the results, commonly equal to length of + the results. + device (str): Device name. Optional values are 'cpu' and 'gpu'. + tmpdir (str | None): Temporal directory for collected results to + store. If set to None, it will create a temporal directory for it. + ``tmpdir`` should be None when device is 'gpu'. Defaults to None. + + Returns: + list or None: The collected results. + + Examples: + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> import mmengine.dist as dist + >>> if dist.get_rank() == 0: + data = ['foo', {1: 2}] + else: + data = [24, {'a': 'b'}] + >>> size = 4 + >>> output = dist.collect_results(data, size, device='cpu') + >>> output + ['foo', 24, {1: 2}, {'a': 'b'}] # rank 0 + None # rank 1 + """ + if device not in ['gpu', 'cpu']: + raise NotImplementedError( + f"device must be 'cpu' or 'gpu', but got {device}") + + if device == 'gpu': + assert tmpdir is None, 'tmpdir should be None when device is "gpu"' + return collect_results_gpu(results, size) + else: + return collect_results_cpu(results, size, tmpdir) + + +def collect_results_cpu(result_part: list, + size: int, + tmpdir: Optional[str] = None) -> Optional[list]: + """Collect results under cpu mode. + + On cpu mode, this function will save the results on different gpus to + ``tmpdir`` and collect them by the rank 0 worker. + + Args: + result_part (list): Result list containing result parts + to be collected. Each item of ``result_part`` should be a picklable + object. + size (int): Size of the results, commonly equal to length of + the results. + tmpdir (str | None): Temporal directory for collected results to + store. If set to None, it will create a random temporal directory + for it. Defaults to None. + + Returns: + list or None: The collected results. + + Examples: + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. 
+ >>> import mmengine.dist as dist + >>> if dist.get_rank() == 0: + data = ['foo', {1: 2}] + else: + data = [24, {'a': 'b'}] + >>> size = 4 + >>> output = dist.collect_results_cpu(data, size) + >>> output + ['foo', 24, {1: 2}, {'a': 'b'}] # rank 0 + None # rank 1 + """ + rank, world_size = get_dist_info() + if world_size == 1: + return result_part[:size] + + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + mmengine.mkdir_or_exist('.dist_test') + tmpdir = tempfile.mkdtemp(dir='.dist_test') + tmpdir = torch.tensor( + bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + mmengine.mkdir_or_exist(tmpdir) + + # dump the part result to the dir + with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f: # type: ignore + pickle.dump(result_part, f, protocol=2) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + path = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore + with open(path, 'rb') as f: + part_list.append(pickle.load(f)) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) # type: ignore + return ordered_results + + +def collect_results_gpu(result_part: list, size: int) -> Optional[list]: + """Collect results under gpu mode. + + On gpu mode, this function will encode results to gpu tensors and use gpu + communication for results collection. + + Args: + result_part (list[object]): Result list containing result parts + to be collected. Each item of ``result_part`` should be a picklable + object. + size (int): Size of the results, commonly equal to length of + the results. + + Returns: + list or None: The collected results. + + Examples: + >>> # distributed environment + >>> # We have 2 process groups, 2 ranks. + >>> import mmengine.dist as dist + >>> if dist.get_rank() == 0: + data = ['foo', {1: 2}] + else: + data = [24, {'a': 'b'}] + >>> size = 4 + >>> output = dist.collect_results_gpu(data, size) + >>> output + ['foo', 24, {1: 2}, {'a': 'b'}] # rank 0 + None # rank 1 + """ + rank, world_size = get_dist_info() + if world_size == 1: + return result_part[:size] + + # dump result part to tensor with pickle + part_tensor = torch.tensor( + bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part. 
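# Two local sketches of the bookkeeping inside collect_results_cpu above; the
# directory name and the per-rank parts are made up. First, the tmpdir name is
# encoded into a fixed-size uint8 tensor padded with spaces (ASCII 32) so it
# can be broadcast, then decoded with rstrip on every rank. Second, the
# per-rank result parts are interleaved back into dataloader order with
# zip(*part_list) and truncated to the real dataset size.
import torch

MAX_LEN = 512
tmpdir = '.dist_test/tmp1a2b3c'
dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8)
encoded = torch.tensor(bytearray(tmpdir.encode()), dtype=torch.uint8)
dir_tensor[:len(encoded)] = encoded
assert dir_tensor.numpy().tobytes().decode().rstrip() == tmpdir

part_list = [['s0', 's2', 's4'], ['s1', 's3', 's3']]  # rank 0 / rank 1, last item padded
ordered = []
for res in zip(*part_list):
    ordered.extend(list(res))
assert ordered[:5] == ['s0', 's1', 's2', 's3', 's4']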
Note that NCCL does not support gather so use + # all_gather + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_list.append( + pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results + else: + return None diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py new file mode 100644 index 00000000..68e55e42 --- /dev/null +++ b/mmengine/dist/utils.py @@ -0,0 +1,335 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import os +import subprocess +from typing import Callable, Optional, Tuple + +import torch +import torch.multiprocessing as mp +from torch import distributed as dist + +_LOCAL_PROCESS_GROUP = None + + +def is_distributed() -> bool: + """Return True if distributed environment has been initialized.""" + return dist.is_available() and dist.is_initialized() + + +def get_local_group() -> Optional[dist.ProcessGroup]: + """Return local process group.""" + if not is_distributed(): + return None + + if _LOCAL_PROCESS_GROUP is None: + raise RuntimeError('Local process group is not created, please use ' + '`init_local_group` to setup local process group.') + + return _LOCAL_PROCESS_GROUP + + +def get_default_group() -> Optional[dist.ProcessGroup]: + """Return default process group.""" + + return dist.distributed_c10d._get_default_group() + + +def init_dist(launcher, backend='nccl', **kwargs) -> None: + """Initialize distributed environment. + + Args: + launcher (str): Way to launcher multi processes. Supported launchers + are 'pytorch', 'mpi' and 'slurm'. + backend (str): Communication Backends. Supported backends are 'nccl', + 'gloo' and 'mpi'. Defaults to 'nccl'. + **kwargs: keyword arguments are passed to ``init_process_group``. + """ + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'mpi': + _init_dist_mpi(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs) -> None: + """Initialize distributed environment with PyTorch launcher. + + Args: + backend (str): Backend of torch.distributed. Supported backends are + 'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'. + **kwargs: keyword arguments are passed to ``init_process_group``. + """ + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_mpi(backend, **kwargs) -> None: + """Initialize distributed environment with MPI launcher. + + Args: + backend (str): Backend of torch.distributed. Supported backends are + 'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'. + **kwargs: keyword arguments are passed to ``init_process_group``. + """ + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None) -> None: + """Initialize slurm distributed training environment. 
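# A minimal usage sketch for init_dist above, assuming the process was started
# by torch.distributed.launch / torchrun so that RANK, WORLD_SIZE, MASTER_ADDR
# and MASTER_PORT are already present in the environment (under SLURM the
# 'slurm' launcher derives them itself).
import mmengine.dist as dist

dist.init_dist(launcher='pytorch', backend='nccl')
rank, world_size = dist.get_dist_info()
print(f'rank {rank}/{world_size} initialized, backend={dist.get_backend()}')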
+ + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + + TODO: https://github.com/open-mmlab/mmcv/pull/1682 + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + # use MASTER_ADDR in the environment variable if it already exists + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def init_local_group(node_rank: int, num_gpus_per_node: int): + """Setup the local process group. + + Setup a process group which only includes processes that on the same + machine as the current process. + + The code is modified from + https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py + + Args: + node_rank (int): Rank of machines used for training. + num_gpus_per_node (int): Number of gpus used for training in a single + machine. + """ # noqa: W501 + global _LOCAL_PROCESS_GROUP + assert _LOCAL_PROCESS_GROUP is None + + ranks = list( + range(node_rank * num_gpus_per_node, + (node_rank + 1) * num_gpus_per_node)) + _LOCAL_PROCESS_GROUP = dist.new_group(ranks) + + +def get_backend(group: Optional[dist.ProcessGroup] = None) -> Optional[str]: + """Return the backend of the given process group. + + Note: + Calling ``get_backend`` in non-distributed environment will return + None. + + Args: + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific + group is specified, the calling process must be part of + :attr:`group`. Defaults to None. + + Returns: + str or None: Return the backend of the given process group as a lower + case string if in distributed environment, otherwise None. + """ + if is_distributed(): + # handle low versions of torch like 1.5.0 which does not support + # passing in None for group argument + if group is None: + group = get_default_group() + return dist.get_backend(group) + else: + return None + + +def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int: + """Return the number of the given process group. + + Note: + Calling ``get_world_size`` in non-distributed environment will return + 1. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + int: Return the number of processes of the given process group if in + distributed environment, otherwise 1. 
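# A small sketch of the rank arithmetic in init_local_group above, assuming a
# hypothetical 2-node setup with 4 GPUs per node: each node_rank selects a
# contiguous slice of global ranks for its local process group.
num_gpus_per_node = 4
for node_rank in range(2):
    ranks = list(range(node_rank * num_gpus_per_node,
                       (node_rank + 1) * num_gpus_per_node))
    print(node_rank, ranks)  # 0 -> [0, 1, 2, 3]; 1 -> [4, 5, 6, 7]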
+ """ + if is_distributed(): + # handle low versions of torch like 1.5.0 which does not support + # passing in None for group argument + if group is None: + group = get_default_group() + return dist.get_world_size(group) + else: + return 1 + + +def get_rank(group: Optional[dist.ProcessGroup] = None) -> int: + """Return the rank of the given process group. + + Rank is a unique identifier assigned to each process within a distributed + process group. They are always consecutive integers ranging from 0 to + ``world_size``. + + Note: + Calling ``get_rank`` in non-distributed environment will return 0. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + int: Return the rank of the process group if in distributed + environment, otherwise 0. + """ + + if is_distributed(): + # handle low versions of torch like 1.5.0 which does not support + # passing in None for group argument + if group is None: + group = get_default_group() + return dist.get_rank(group) + else: + return 0 + + +def get_local_size() -> int: + """Return the number of the current node. + + Returns: + int: Return the number of processes in the current node if in + distributed environment, otherwise 1. + """ + if not is_distributed(): + return 1 + + if _LOCAL_PROCESS_GROUP is None: + raise RuntimeError('Local process group is not created, please use ' + '`init_local_group` to setup local process group.') + + return dist.get_world_size(_LOCAL_PROCESS_GROUP) + + +def get_local_rank() -> int: + """Return the rank of current process in the current node. + + Returns: + int: Return the rank of current process in the current node if in + distributed environment, otherwise 0 + """ + if not is_distributed(): + return 0 + + if _LOCAL_PROCESS_GROUP is None: + raise RuntimeError('Local process group is not created, please use ' + '`init_local_group` to setup local process group.') + + return dist.get_rank(_LOCAL_PROCESS_GROUP) + + +def get_dist_info( + group: Optional[dist.ProcessGroup] = None) -> Tuple[int, int]: + """Get distributed information of the given process group. + + Note: + Calling ``get_dist_info`` in non-distributed environment will return + (0, 1). + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + tuple[int, int]: Return a tuple containing the ``world_size`` and + ``rank``. + """ + world_size = get_world_size(group) + rank = get_rank(group) + return rank, world_size + + +def is_main_process(group: Optional[dist.ProcessGroup] = None) -> bool: + """Whether the current rank of the given process group is equal to 0. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + + Returns: + bool: Return True if the current rank of the given process group is + equal to 0, otherwise False. + """ + return get_rank(group) == 0 + + +def master_only(func: Callable) -> Callable: + """Decorate those methods which should be executed in master process. + + Args: + func (callable): Function to be decorated. + + Returns: + callable: Return decorated function. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_main_process(): + return func(*args, **kwargs) + + return wrapper + + +def barrier(group: Optional[dist.ProcessGroup] = None) -> None: + """Synchronize all processes from the given process group. 
+ + This collective blocks processes until the whole group enters this + function. + + Note: + Calling ``barrier`` in non-distributed environment will do nothing. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Defaults to None. + """ + if is_distributed(): + # handle low versions of torch like 1.5.0 which does not support + # passing in None for group argument + if group is None: + group = get_default_group() + dist.barrier(group) diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py new file mode 100644 index 00000000..78a55c54 --- /dev/null +++ b/tests/test_dist/test_dist.py @@ -0,0 +1,376 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import tempfile +from unittest.mock import patch + +import pytest +import torch +import torch.multiprocessing as mp + +import mmengine.dist as dist +from mmengine.dist.dist import sync_random_seed +from mmengine.utils import TORCH_VERSION, digit_version + + +def _test_all_reduce_non_dist(): + data = torch.arange(2, dtype=torch.int64) + expected = torch.arange(2, dtype=torch.int64) + dist.all_reduce(data) + assert torch.allclose(data, expected) + + +def _test_all_gather_non_dist(): + data = torch.arange(2, dtype=torch.int64) + expected = torch.arange(2, dtype=torch.int64) + output = dist.all_gather(data) + assert torch.allclose(output[0], expected) + + +def _test_gather_non_dist(): + data = torch.arange(2, dtype=torch.int64) + expected = torch.arange(2, dtype=torch.int64) + output = dist.gather(data) + assert torch.allclose(output[0], expected) + + +def _test_broadcast_non_dist(): + data = torch.arange(2, dtype=torch.int64) + expected = torch.arange(2, dtype=torch.int64) + dist.broadcast(data) + assert torch.allclose(data, expected) + + +@patch('numpy.random.randint', return_value=10) +def _test_sync_random_seed_no_dist(mock): + assert sync_random_seed() == 10 + + +def _test_broadcast_object_list_no_dist(): + with pytest.raises(AssertionError): + # input should be list of object + dist.broadcast_object_list('foo') + + data = ['foo', 12, {1: 2}] + expected = ['foo', 12, {1: 2}] + dist.broadcast_object_list(data) + assert data == expected + + +def _test_all_reduce_dict_no_dist(): + with pytest.raises(AssertionError): + # input should be dict + dist.all_reduce_dict('foo') + + data = { + 'key1': torch.arange(2, dtype=torch.int64), + 'key2': torch.arange(3, dtype=torch.int64) + } + expected = { + 'key1': torch.arange(2, dtype=torch.int64), + 'key2': torch.arange(3, dtype=torch.int64) + } + dist.all_reduce_dict(data) + for key in data: + assert torch.allclose(data[key], expected[key]) + + +def _test_all_gather_object_no_dist(): + data = 'foo' + expected = 'foo' + gather_objects = dist.all_gather_object(data) + assert gather_objects[0] == expected + + +def _test_gather_object_no_dist(): + data = 'foo' + expected = 'foo' + gather_objects = dist.gather_object(data) + assert gather_objects[0] == expected + + +def _test_collect_results_non_dist(): + data = ['foo', {1: 2}] + size = 2 + expected = ['foo', {1: 2}] + + # test `device=cpu` + output = dist.collect_results(data, size, device='cpu') + assert output == expected + + # test `device=gpu` + output = dist.collect_results(data, size, device='cpu') + assert output == expected + + +def init_process(rank, world_size, functions, backend='gloo'): + """Initialize the distributed environment.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29505' + 
os.environ['RANK'] = str(rank) + dist.init_dist('pytorch', backend, rank=rank, world_size=world_size) + + device = 'cpu' if backend == 'gloo' else 'cuda' + + for func in functions: + func(device) + + +def main(functions, world_size=2, backend='gloo'): + try: + mp.spawn( + init_process, + args=(world_size, functions, backend), + nprocs=world_size) + except Exception: + pytest.fail(f'{backend} failed') + + +def _test_all_reduce_dist(device): + for tensor_type, reduce_op in zip([torch.int64, torch.float32], + ['sum', 'mean']): + if dist.get_rank() == 0: + data = torch.tensor([1, 2], dtype=tensor_type).to(device) + else: + data = torch.tensor([3, 4], dtype=tensor_type).to(device) + + if reduce_op == 'sum': + expected = torch.tensor([4, 6], dtype=tensor_type).to(device) + else: + expected = torch.tensor([2, 3], dtype=tensor_type).to(device) + + dist.all_reduce(data, reduce_op) + assert torch.allclose(data, expected) + + +def _test_all_gather_dist(device): + if dist.get_rank() == 0: + data = torch.tensor([0, 1]).to(device) + else: + data = torch.tensor([1, 2]).to(device) + + expected = [ + torch.tensor([0, 1]).to(device), + torch.tensor([1, 2]).to(device) + ] + + output = dist.all_gather(data) + assert torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]) + + +def _test_gather_dist(device): + if dist.get_rank() == 0: + data = torch.tensor([0, 1]).to(device) + else: + data = torch.tensor([1, 2]).to(device) + + output = dist.gather(data) + + if dist.get_rank() == 0: + expected = [ + torch.tensor([0, 1]).to(device), + torch.tensor([1, 2]).to(device) + ] + for i in range(2): + assert torch.allclose(output[i], expected[i]) + else: + assert output == [] + + +def _test_broadcast_dist(device): + if dist.get_rank() == 0: + data = torch.tensor([0, 1]).to(device) + else: + data = torch.tensor([1, 2]).to(device) + + expected = torch.tensor([0, 1]).to(device) + dist.broadcast(data, 0) + assert torch.allclose(data, expected) + + +def _test_sync_random_seed_dist(device): + with patch.object( + torch, 'tensor', + return_value=torch.tensor(1024).to(device)) as mock_tensor: + output = dist.sync_random_seed() + assert output == 1024 + mock_tensor.assert_called() + + +def _test_broadcast_object_list_dist(device): + if dist.get_rank() == 0: + data = ['foo', 12, {1: 2}] + else: + data = [None, None, None] + + expected = ['foo', 12, {1: 2}] + + dist.broadcast_object_list(data) + + assert data == expected + + +def _test_all_reduce_dict_dist(device): + for tensor_type, reduce_op in zip([torch.int64, torch.float32], + ['sum', 'mean']): + if dist.get_rank() == 0: + data = { + 'key1': torch.tensor([0, 1], dtype=tensor_type).to(device), + 'key2': torch.tensor([1, 2], dtype=tensor_type).to(device) + } + else: + data = { + 'key1': torch.tensor([2, 3], dtype=tensor_type).to(device), + 'key2': torch.tensor([3, 4], dtype=tensor_type).to(device) + } + + if reduce_op == 'sum': + expected = { + 'key1': torch.tensor([2, 4], dtype=tensor_type).to(device), + 'key2': torch.tensor([4, 6], dtype=tensor_type).to(device) + } + else: + expected = { + 'key1': torch.tensor([1, 2], dtype=tensor_type).to(device), + 'key2': torch.tensor([2, 3], dtype=tensor_type).to(device) + } + + dist.all_reduce_dict(data, reduce_op) + + for key in data: + assert torch.allclose(data[key], expected[key]) + + # `torch.cat` in torch1.5 can not concatenate different types so we + # fallback to convert them all to float type. 
+ if digit_version(TORCH_VERSION) == digit_version('1.5.0'): + if dist.get_rank() == 0: + data = { + 'key1': torch.tensor([0, 1], dtype=torch.float32).to(device), + 'key2': torch.tensor([1, 2], dtype=torch.int32).to(device) + } + else: + data = { + 'key1': torch.tensor([2, 3], dtype=torch.float32).to(device), + 'key2': torch.tensor([3, 4], dtype=torch.int32).to(device) + } + + expected = { + 'key1': torch.tensor([2, 4], dtype=torch.float32).to(device), + 'key2': torch.tensor([4, 6], dtype=torch.float32).to(device) + } + + dist.all_reduce_dict(data, 'sum') + + for key in data: + assert torch.allclose(data[key], expected[key]) + + +def _test_all_gather_object_dist(device): + if dist.get_rank() == 0: + data = 'foo' + else: + data = {1: 2} + + expected = ['foo', {1: 2}] + output = dist.all_gather_object(data) + + assert output == expected + + +def _test_gather_object_dist(device): + if dist.get_rank() == 0: + data = 'foo' + else: + data = {1: 2} + + output = dist.gather_object(data, dst=0) + + if dist.get_rank() == 0: + assert output == ['foo', {1: 2}] + else: + assert output is None + + +def _test_collect_results_dist(device): + if dist.get_rank() == 0: + data = ['foo', {1: 2}] + else: + data = [24, {'a': 'b'}] + + size = 4 + + expected = ['foo', 24, {1: 2}, {'a': 'b'}] + + # test `device=cpu` + output = dist.collect_results(data, size, device='cpu') + if dist.get_rank() == 0: + assert output == expected + else: + assert output is None + + # test `device=cpu` and `tmpdir is not None` + tmpdir = tempfile.mkdtemp() + # broadcast tmpdir to all ranks to make it consistent + object_list = [tmpdir] + dist.broadcast_object_list(object_list) + output = dist.collect_results( + data, size, device='cpu', tmpdir=object_list[0]) + if dist.get_rank() == 0: + assert output == expected + else: + assert output is None + + if dist.get_rank() == 0: + # object_list[0] will be removed by `dist.collect_results` + assert not osp.exists(object_list[0]) + + # test `device=gpu` + output = dist.collect_results(data, size, device='gpu') + if dist.get_rank() == 0: + assert output == expected + else: + assert output is None + + +def test_non_distributed_env(): + _test_all_reduce_non_dist() + _test_all_gather_non_dist() + _test_gather_non_dist() + _test_broadcast_non_dist() + _test_sync_random_seed_no_dist() + _test_broadcast_object_list_no_dist() + _test_all_reduce_dict_no_dist() + _test_all_gather_object_no_dist() + _test_gather_object_no_dist() + _test_collect_results_non_dist() + + +def test_gloo_backend(): + functions_to_test = [ + _test_all_reduce_dist, + _test_all_gather_dist, + _test_gather_dist, + _test_broadcast_dist, + _test_sync_random_seed_dist, + _test_broadcast_object_list_dist, + _test_all_reduce_dict_dist, + _test_all_gather_object_dist, + _test_gather_object_dist, + ] + main(functions_to_test, backend='gloo') + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl') +def test_nccl_backend(): + functions_to_test = [ + _test_all_reduce_dist, + _test_all_gather_dist, + _test_broadcast_dist, + _test_sync_random_seed_dist, + _test_broadcast_object_list_dist, + _test_all_reduce_dict_dist, + _test_all_gather_object_dist, + _test_collect_results_dist, + ] + main(functions_to_test, backend='nccl') diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py new file mode 100644 index 00000000..e099c879 --- /dev/null +++ b/tests/test_dist/test_utils.py @@ -0,0 +1,152 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
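# A minimal sketch of the mp.spawn pattern that main() and init_process use in
# these test files; the worker below is a stand-in, not one of the real test
# functions. spawn passes the process index as the first positional argument
# and unpacks ``args`` after it, which is how each child learns its rank.
import torch.multiprocessing as mp

def _toy_worker(rank, world_size, message):
    print(f'[{rank}/{world_size}] {message}')

if __name__ == '__main__':
    mp.spawn(_toy_worker, args=(2, 'hello'), nprocs=2)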
+import os + +import pytest +import torch +import torch.distributed as torch_dist +import torch.multiprocessing as mp + +import mmengine.dist as dist + + +def _test_get_backend_non_dist(): + assert dist.get_backend() is None + + +def _test_get_world_size_non_dist(): + assert dist.get_world_size() == 1 + + +def _test_get_rank_non_dist(): + assert dist.get_rank() == 0 + + +def _test_local_size_non_dist(): + assert dist.get_local_size() == 1 + + +def _test_local_rank_non_dist(): + assert dist.get_local_rank() == 0 + + +def _test_get_dist_info_non_dist(): + assert dist.get_dist_info() == (0, 1) + + +def _test_is_main_process_non_dist(): + assert dist.is_main_process() + + +def _test_master_only_non_dist(): + + @dist.master_only + def fun(): + assert dist.get_rank() == 0 + + fun() + + +def _test_barrier_non_dist(): + dist.barrier() # nothing is done + + +def init_process(rank, world_size, functions, backend='gloo'): + """Initialize the distributed environment.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29501' + os.environ['RANK'] = str(rank) + dist.init_dist('pytorch', backend, rank=rank, world_size=world_size) + dist.init_local_group(0, world_size) + + for func in functions: + func() + + +def main(functions, world_size=2, backend='gloo'): + try: + mp.spawn( + init_process, + args=(world_size, functions, backend), + nprocs=world_size) + except Exception: + pytest.fail('error') + + +def _test_get_backend_dist(): + assert dist.get_backend() == torch_dist.get_backend() + + +def _test_get_world_size_dist(): + assert dist.get_world_size() == 2 + + +def _test_get_rank_dist(): + if torch_dist.get_rank() == 0: + assert dist.get_rank() == 0 + else: + assert dist.get_rank() == 1 + + +def _test_local_size_dist(): + assert dist.get_local_size() == 2 + + +def _test_local_rank_dist(): + torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank() + + +def _test_get_dist_info_dist(): + if dist.get_rank() == 0: + assert dist.get_dist_info() == (0, 2) + else: + assert dist.get_dist_info() == (1, 2) + + +def _test_is_main_process_dist(): + if dist.get_rank() == 0: + assert dist.is_main_process() + else: + assert not dist.is_main_process() + + +def _test_master_only_dist(): + + @dist.master_only + def fun(): + assert dist.get_rank() == 0 + + fun() + + +def test_non_distributed_env(): + _test_get_backend_non_dist() + _test_get_world_size_non_dist() + _test_get_rank_non_dist() + _test_local_size_non_dist() + _test_local_rank_non_dist() + _test_get_dist_info_non_dist() + _test_is_main_process_non_dist() + _test_master_only_non_dist() + _test_barrier_non_dist() + + +functions_to_test = [ + _test_get_backend_dist, + _test_get_world_size_dist, + _test_get_rank_dist, + _test_local_size_dist, + _test_local_rank_dist, + _test_get_dist_info_dist, + _test_is_main_process_dist, + _test_master_only_dist, +] + + +def test_gloo_backend(): + main(functions_to_test) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl') +def test_nccl_backend(): + main(functions_to_test, backend='nccl') -- GitLab
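# An end-to-end usage sketch for the new module; the script layout is
# hypothetical and it assumes a launcher (e.g. torchrun) that exports RANK,
# WORLD_SIZE, MASTER_ADDR and MASTER_PORT, plus one GPU per process.
import torch
import mmengine.dist as dist

def main():
    dist.init_dist(launcher='pytorch', backend='nccl')
    rank, world_size = dist.get_dist_info()

    # Average a per-rank scalar across all processes.
    loss = torch.tensor(float(rank), device='cuda')
    dist.all_reduce(loss, op='mean')

    # Collect per-rank partial results on rank 0.
    results = [f'sample_{rank}']
    collected = dist.collect_results(results, size=world_size, device='gpu')
    if dist.is_main_process():
        print(loss.item(), collected)

if __name__ == '__main__':
    main()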