Unverified Commit 17dbac18 authored by Zaida Zhou, committed by GitHub

[Enhancement] Handle the device type of inputs in functions (#137)

* [Enhancement] Handle the device type of inputs in functions

* rename and move three functions to dist/utils.py

* minor refinement

* rename dist to torch_dist in utils.py

* update unit tests

* refine unit tests

* add unit tests

* fix unit tests

* replace Sequence with list and tuple

* rename get_backend_device to get_comm_device

* fix unit tests

* fix unit tests

* refactor and add more unit tests

* cast_data_device does not support set type
parent 661e7590
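Taken together, ``get_data_device``, ``get_comm_device`` and ``cast_data_device`` make the collectives device-agnostic: inputs are moved to the backend's communication device before the collective runs, and results are cast back afterwards. A minimal sketch of the pattern (illustrative only; without an initialized process group the communication device falls back to CPU):

import torch
from mmengine.dist import cast_data_device, get_comm_device, get_data_device

data = torch.tensor([0, 1])
input_device = get_data_device(data)   # device(type='cpu')
comm_device = get_comm_device(None)    # cuda for NCCL groups, cpu otherwise
data_on_comm = cast_data_device(data, comm_device)
# ... run a collective on data_on_comm ...
cast_data_device(data_on_comm, input_device, out=data)  # copy the result back

mmengine/dist/__init__.py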
@@ -6,7 +6,8 @@ from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict,
from .utils import (get_dist_info, init_dist, init_local_group, get_backend,
get_world_size, get_rank, get_local_size, get_local_rank,
is_main_process, master_only, barrier, get_local_group,
is_distributed, get_default_group)
is_distributed, get_default_group, get_data_device,
get_comm_device, cast_data_device)
__all__ = [
'all_gather_object', 'all_reduce', 'all_gather', 'all_reduce_dict',
@@ -15,5 +16,6 @@ __all__ = [
'get_dist_info', 'init_dist', 'init_local_group', 'get_backend',
'get_world_size', 'get_rank', 'get_local_size', 'get_local_group',
'get_local_rank', 'is_main_process', 'master_only', 'barrier',
'is_distributed', 'get_default_group'
'is_distributed', 'get_default_group', 'get_data_device',
'get_comm_device', 'cast_data_device'
]
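mmengine/dist/dist.py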
@@ -7,23 +7,25 @@ import tempfile
import torch
import os.path as osp
from torch import Tensor
from torch import distributed as dist
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
import mmengine
from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
get_default_group)
get_default_group, barrier, get_data_device,
get_comm_device, cast_data_device)
from mmengine.utils import digit_version, TORCH_VERSION
def _get_reduce_op(name: str) -> dist.ReduceOp:
def _get_reduce_op(name: str) -> torch_dist.ReduceOp:
op_mappings = {
'sum': dist.ReduceOp.SUM,
'product': dist.ReduceOp.PRODUCT,
'min': dist.ReduceOp.MIN,
'max': dist.ReduceOp.MAX,
'band': dist.ReduceOp.BAND,
'bor': dist.ReduceOp.BOR,
'bxor': dist.ReduceOp.BXOR,
'sum': torch_dist.ReduceOp.SUM,
'product': torch_dist.ReduceOp.PRODUCT,
'min': torch_dist.ReduceOp.MIN,
'max': torch_dist.ReduceOp.MAX,
'band': torch_dist.ReduceOp.BAND,
'bor': torch_dist.ReduceOp.BOR,
'bxor': torch_dist.ReduceOp.BXOR,
}
if name.lower() not in op_mappings:
@@ -35,7 +37,7 @@ def _get_reduce_op(name: str) -> dist.ReduceOp:
def all_reduce(data: Tensor,
op: str = 'sum',
group: Optional[dist.ProcessGroup] = None) -> None:
group: Optional[ProcessGroup] = None) -> None:
"""Reduces the tensor data across all machines in such a way that all get
the final result.
@@ -70,7 +72,7 @@ def all_reduce(data: Tensor,
>>> data
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_reduce(data, op=torch.dist.ReduceOp.SUM)
>>> dist.all_reduce(data, op='sum')
>>> data
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
@@ -80,17 +82,23 @@ def all_reduce(data: Tensor,
if group is None:
group = get_default_group()
input_device = get_data_device(data)
backend_device = get_comm_device(group)
data_on_device = cast_data_device(data, backend_device)
# PyTorch does not support the 'mean' operation, so we emulate it with
# 'sum' followed by an in-place division by the world size.
if op.lower() == 'mean':
dist.all_reduce(data, _get_reduce_op('sum'), group)
data.div_(world_size)
torch_dist.all_reduce(data_on_device, _get_reduce_op('sum'), group)
data_on_device.div_(world_size) # type: ignore
else:
dist.all_reduce(data, _get_reduce_op(op), group)
torch_dist.all_reduce(data_on_device, _get_reduce_op(op), group)
cast_data_device(data_on_device, input_device, out=data)
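With the device round trip above, callers no longer need to move tensors to the backend's device themselves; a CPU tensor is accepted even when the group backend is NCCL. An illustrative doctest, assuming two processes:

>>> data = torch.tensor([1, 2]) if dist.get_rank() == 0 else torch.tensor([3, 4])
>>> dist.all_reduce(data, op='sum')  # reduced on the communication device
>>> data
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
>>> data.device  # result is copied back to the input device
device(type='cpu')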
def all_gather(data: Tensor,
group: Optional[dist.ProcessGroup] = None) -> List[Tensor]:
group: Optional[ProcessGroup] = None) -> List[Tensor]:
"""Gather data from the whole group in a list.
Note:
@@ -146,15 +154,23 @@ def all_gather(data: Tensor,
if group is None:
group = get_default_group()
gather_list = [torch.empty_like(data) for _ in range(world_size)]
dist.all_gather(gather_list, data, group)
return gather_list
input_device = get_data_device(data)
backend_device = get_comm_device(group)
data_on_device = cast_data_device(data, backend_device)
gather_list = [
torch.empty_like(data, device=backend_device)
for _ in range(world_size)
]
torch_dist.all_gather(gather_list, data_on_device, group)
return cast_data_device(gather_list, input_device) # type: ignore
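As with ``all_reduce``, the gathered tensors are cast back to the input device before being returned. A sketch, assuming two processes:

>>> data = torch.tensor([0, 1]) + dist.get_rank()  # CPU tensor on each rank
>>> output = dist.all_gather(data)
>>> [t.device.type for t in output]
['cpu', 'cpu']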
def gather(
data: Tensor,
dst: int = 0,
group: Optional[dist.ProcessGroup] = None) -> List[Optional[Tensor]]:
def gather(data: Tensor,
dst: int = 0,
group: Optional[ProcessGroup] = None) -> List[Optional[Tensor]]:
"""Gather data from the whole group to ``dst`` process.
Note:
@@ -215,18 +231,28 @@ def gather(
if group is None:
group = get_default_group()
input_device = get_data_device(data)
backend_device = get_comm_device(group)
data_on_device = cast_data_device(data, backend_device)
if get_rank(group) == dst:
gather_list = [torch.empty_like(data) for _ in range(world_size)]
gather_list = [
torch.empty_like(data, device=backend_device)
for _ in range(world_size)
]
else:
gather_list = []
dist.gather(data, gather_list, dst, group)
return gather_list
torch_dist.gather(data_on_device, gather_list, dst, group)
if get_rank(group) == dst:
return cast_data_device(gather_list, input_device) # type: ignore
else:
return gather_list
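A sketch of the resulting behaviour, assuming two processes: only the destination rank receives the gathered tensors (cast back to the input device); every other rank gets an empty list:

>>> data = torch.tensor([0, 1]) + dist.get_rank()
>>> dist.gather(data, dst=0)
[tensor([0, 1]), tensor([1, 2])] # Rank 0
[] # Rank 1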
def broadcast(data: Tensor,
src: int = 0,
group: Optional[dist.ProcessGroup] = None) -> None:
group: Optional[ProcessGroup] = None) -> None:
"""Broadcast the data from ``src`` process to the whole group.
``data`` must have the same number of elements in all processes
@@ -269,10 +295,17 @@ def broadcast(data: Tensor,
if group is None:
group = get_default_group()
dist.broadcast(data, src, group)
input_device = get_data_device(data)
backend_device = get_comm_device(group)
data_on_device = cast_data_device(data, backend_device)
torch_dist.broadcast(data_on_device, src, group)
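# the src rank already holds the broadcast value on its input device, so
# only the other ranks need to copy the result back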
if get_rank(group) != src:
cast_data_device(data_on_device, input_device, data)
def sync_random_seed(group: Optional[dist.ProcessGroup] = None) -> int:
def sync_random_seed(group: Optional[ProcessGroup] = None) -> int:
"""Synchronize a random seed to all processes.
Args:
@@ -305,18 +338,14 @@ def sync_random_seed(group: Optional[dist.ProcessGroup] = None) -> int:
if group is None:
group = get_default_group()
group_backend = get_backend(group)
is_nccl_backend = group_backend == dist.Backend.NCCL
current_device = torch.device('cpu')
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
backend_device = get_comm_device(group)
if get_rank(group) == 0:
random_num = torch.tensor(seed, dtype=torch.int32).to(current_device)
random_num = torch.tensor(seed, dtype=torch.int32).to(backend_device)
else:
random_num = torch.tensor(0, dtype=torch.int32).to(current_device)
random_num = torch.tensor(0, dtype=torch.int32).to(backend_device)
dist.broadcast(random_num, src=0, group=group)
torch_dist.broadcast(random_num, src=0, group=group)
return random_num.item()
@@ -340,14 +369,14 @@ def _tensor_to_object(tensor: Tensor, tensor_size: int) -> Any:
def _broadcast_object_list(object_list: List[Any],
src: int = 0,
group: Optional[dist.ProcessGroup] = None) -> None:
group: Optional[ProcessGroup] = None) -> None:
"""Broadcast picklable objects in ``object_list`` to the whole group.
Similar to :func:`broadcast`, but Python objects can be passed in. Note
that all objects in ``object_list`` must be picklable in order to be
broadcasted.
"""
if dist.distributed_c10d._rank_not_in_group(group):
if torch_dist.distributed_c10d._rank_not_in_group(group):
return
my_rank = get_rank()
@@ -366,7 +395,7 @@ def _broadcast_object_list(object_list: List[Any],
# the case it is not ``None`` we move the size and object tensors to be
# broadcasted to this device.
group_backend = get_backend(group)
is_nccl_backend = group_backend == dist.Backend.NCCL
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
current_device = torch.device('cpu')
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in
@@ -376,7 +405,7 @@ def _broadcast_object_list(object_list: List[Any],
object_sizes_tensor = object_sizes_tensor.to(current_device)
# Broadcast object sizes
dist.broadcast(object_sizes_tensor, src=src, group=group)
torch_dist.broadcast(object_sizes_tensor, src=src, group=group)
# Concatenate and broadcast serialized object tensors
if my_rank == src:
@@ -389,7 +418,7 @@ def _broadcast_object_list(object_list: List[Any],
if is_nccl_backend:
object_tensor = object_tensor.to(current_device)
dist.broadcast(object_tensor, src=src, group=group)
torch_dist.broadcast(object_tensor, src=src, group=group)
# Deserialize objects using their stored sizes.
offset = 0
if my_rank != src:
@@ -404,7 +433,7 @@ def _broadcast_object_list(object_list: List[Any],
def broadcast_object_list(data: List[Any],
src: int = 0,
group: Optional[dist.ProcessGroup] = None) -> None:
group: Optional[ProcessGroup] = None) -> None:
"""Broadcasts picklable objects in ``object_list`` to the whole group.
Similar to :func:`broadcast`, but Python objects can be passed in. Note
that all objects in ``object_list`` must be picklable in order to be
@@ -462,14 +491,14 @@ def broadcast_object_list(data: List[Any],
group = get_default_group()
if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
dist.broadcast_object_list(data, src, group)
torch_dist.broadcast_object_list(data, src, group)
else:
_broadcast_object_list(data, src, group)
def all_reduce_dict(data: Dict[str, Tensor],
op: str = 'sum',
group: Optional[dist.ProcessGroup] = None) -> None:
group: Optional[ProcessGroup] = None) -> None:
"""Reduces the dict across all machines in such a way that all get the
final result.
@@ -542,7 +571,7 @@ def all_reduce_dict(data: Dict[str, Tensor],
def _all_gather_object(object_list: List[Any],
obj: Any,
group: Optional[dist.ProcessGroup] = None) -> None:
group: Optional[ProcessGroup] = None) -> None:
"""Gather picklable objects from the whole group into a list.
Similar to :func:`all_gather`, but Python objects can be passed in.
@@ -563,13 +592,13 @@ def _all_gather_object(object_list: List[Any],
calling rank is not part of the group, the passed in ``object_list``
will be unmodified.
"""
if dist.distributed_c10d._rank_not_in_group(group):
if torch_dist.distributed_c10d._rank_not_in_group(group):
return
input_tensor, local_size = _object_to_tensor(obj)
group_backend = get_backend(group)
current_device = torch.device('cpu')
is_nccl_backend = group_backend == dist.Backend.NCCL
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
@@ -586,7 +615,7 @@ def _all_gather_object(object_list: List[Any],
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes
dist.all_gather(object_size_list, local_size, group=group)
torch_dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item())
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
@@ -597,7 +626,7 @@ def _all_gather_object(object_list: List[Any],
coalesced_output_tensor[max_object_size * i:max_object_size * (i + 1)]
for i in range(group_size)
]
dist.all_gather(output_tensors, input_tensor, group=group)
torch_dist.all_gather(output_tensors, input_tensor, group=group)
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
@@ -608,7 +637,7 @@ def _all_gather_object(object_list: List[Any],
def all_gather_object(data: Any,
group: Optional[dist.ProcessGroup] = None) -> List[Any]:
group: Optional[ProcessGroup] = None) -> List[Any]:
"""Gather picklable objects from the whole group into a list. Similar to
:func:`all_gather`, but Python objects can be passed in. Note that the
object must be picklable in order to be gathered.
@@ -673,7 +702,7 @@ def all_gather_object(data: Any,
gather_list = [None] * world_size
if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
dist.all_gather_object(gather_list, data, group)
torch_dist.all_gather_object(gather_list, data, group)
else:
_all_gather_object(gather_list, data, group)
@@ -696,7 +725,7 @@ def _validate_output_list_for_rank(my_rank: int, dst: int,
def _gather_object(obj: Any,
object_gather_list=None,
dst: int = 0,
group: Optional[dist.ProcessGroup] = None) -> None:
group: Optional[ProcessGroup] = None) -> None:
"""Gathers picklable objects from the whole group in a single process.
Similar to :func:`gather`, but Python objects can be passed in. Note that
@@ -712,7 +741,7 @@ def _gather_object(obj: Any,
group: (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Defaults to None.
"""
if dist.distributed_c10d._rank_not_in_group(group):
if torch_dist.distributed_c10d._rank_not_in_group(group):
return
# Ensure object_gather_list is specified appropriately.
@@ -721,7 +750,7 @@ def _gather_object(obj: Any,
input_tensor, local_size = _object_to_tensor(obj)
group_backend = get_backend(group)
current_device = torch.device('cpu')
is_nccl_backend = group_backend == dist.Backend.NCCL
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
@@ -737,7 +766,7 @@ def _gather_object(obj: Any,
# Allgather tensor sizes. An all-gather is needed here despite this being a
# gather, since each rank needs to broadcast a tensor of the same (maximal)
# size.
dist.all_gather(object_size_list, local_size, group=group)
torch_dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item())
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
@@ -754,7 +783,7 @@ def _gather_object(obj: Any,
(i + 1)] for i in range(group_size)
]
# All ranks call gather with equal-sized tensors.
dist.gather(
torch_dist.gather(
input_tensor,
gather_list=output_tensors if my_rank == dst else None,
dst=dst,
@@ -768,10 +797,9 @@ def _gather_object(obj: Any,
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
def gather_object(
data: Any,
dst: int = 0,
group: Optional[dist.ProcessGroup] = None) -> Optional[List[Any]]:
def gather_object(data: Any,
dst: int = 0,
group: Optional[ProcessGroup] = None) -> Optional[List[Any]]:
"""Gathers picklable objects from the whole group in a single process.
Similar to :func:`gather`, but Python objects can be passed in. Note that
the object must be picklable in order to be gathered.
@@ -789,7 +817,7 @@ def gather_object(
- PyTorch: gather_object(data, gather_list, dst, group) -> None
Args:
obj (Any): Input object. Must be picklable.
data (Any): Input object. Must be picklable.
dst (int): Destination rank. Defaults to 0.
group: (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Defaults to None.
@@ -826,7 +854,7 @@ def gather_object(
gather_list = [None] * world_size if get_rank(group) == dst else None
if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
dist.gather_object(data, gather_list, dst, group)
torch_dist.gather_object(data, gather_list, dst, group)
else:
_gather_object(data, gather_list, dst, group)
@@ -921,25 +949,24 @@ def collect_results_cpu(result_part: list,
if tmpdir is None:
MAX_LEN = 512
# 32 is the ASCII code of the space character
dir_tensor = torch.full((MAX_LEN, ),
32,
dtype=torch.uint8,
device='cuda')
dir_tensor = torch.full((MAX_LEN, ), 32, dtype=torch.uint8)
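# dir_tensor now stays on CPU; mmengine's device-aware broadcast below
# casts it to the communication device and back internally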
if rank == 0:
mmengine.mkdir_or_exist('.dist_test')
tmpdir = tempfile.mkdtemp(dir='.dist_test')
tmpdir = torch.tensor(
bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
bytearray(tmpdir.encode()), dtype=torch.uint8)
dir_tensor[:len(tmpdir)] = tmpdir
dist.broadcast(dir_tensor, 0)
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
broadcast(dir_tensor, 0)
tmpdir = dir_tensor.numpy().tobytes().decode().rstrip()
else:
mmengine.mkdir_or_exist(tmpdir)
# dump the part result to the dir
with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f: # type: ignore
pickle.dump(result_part, f, protocol=2)
dist.barrier()
barrier()
# collect all parts
if rank != 0:
return None
@@ -995,29 +1022,11 @@ def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
if world_size == 1:
return result_part[:size]
# dump result part to tensor with pickle
part_tensor = torch.tensor(
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
# gather all result part tensor shape
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
shape_list = [shape_tensor.clone() for _ in range(world_size)]
dist.all_gather(shape_list, shape_tensor)
# padding result part tensor to max length
shape_max = torch.tensor(shape_list).max()
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
part_send[:shape_tensor[0]] = part_tensor
part_recv_list = [
part_tensor.new_zeros(shape_max) for _ in range(world_size)
]
# gather all result part. Note that NCCL does not support gather so use
# all_gather
dist.all_gather(part_recv_list, part_send)
# all_gather_object instead.
part_list = all_gather_object(result_part)
if rank == 0:
part_list = []
for recv, shape in zip(part_recv_list, shape_list):
part_list.append(
pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
# sort the results
ordered_results = []
for res in zip(*part_list):
......
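mmengine/dist/utils.py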
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import os
import numpy as np
import subprocess
from typing import Callable, Optional, Tuple
from typing import Callable, Optional, Tuple, Union
from collections.abc import Iterable, Mapping
import torch
from torch import Tensor
import torch.multiprocessing as mp
from torch import distributed as dist
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
_LOCAL_PROCESS_GROUP = None
def is_distributed() -> bool:
"""Return True if distributed environment has been initialized."""
return dist.is_available() and dist.is_initialized()
return torch_dist.is_available() and torch_dist.is_initialized()
def get_local_group() -> Optional[dist.ProcessGroup]:
def get_local_group() -> Optional[ProcessGroup]:
"""Return local process group."""
if not is_distributed():
return None
@@ -28,10 +31,10 @@ def get_local_group() -> Optional[dist.ProcessGroup]:
return _LOCAL_PROCESS_GROUP
def get_default_group() -> Optional[dist.ProcessGroup]:
def get_default_group() -> Optional[ProcessGroup]:
"""Return default process group."""
return dist.distributed_c10d._get_default_group()
return torch_dist.distributed_c10d._get_default_group()
def init_dist(launcher, backend='nccl', **kwargs) -> None:
@@ -68,7 +71,7 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
torch_dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend, **kwargs) -> None:
@@ -83,7 +86,7 @@ def _init_dist_mpi(backend, **kwargs) -> None:
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
torch_dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend, port=None) -> None:
@@ -120,7 +123,7 @@ def _init_dist_slurm(backend, port=None) -> None:
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
torch_dist.init_process_group(backend=backend)
def init_local_group(node_rank: int, num_gpus_per_node: int):
@@ -143,10 +146,10 @@ def init_local_group(node_rank: int, num_gpus_per_node: int):
ranks = list(
range(node_rank * num_gpus_per_node,
(node_rank + 1) * num_gpus_per_node))
_LOCAL_PROCESS_GROUP = dist.new_group(ranks)
_LOCAL_PROCESS_GROUP = torch_dist.new_group(ranks)
def get_backend(group: Optional[dist.ProcessGroup] = None) -> Optional[str]:
def get_backend(group: Optional[ProcessGroup] = None) -> Optional[str]:
"""Return the backend of the given process group.
Note:
@@ -168,12 +171,12 @@ def get_backend(group: Optional[dist.ProcessGroup] = None) -> Optional[str]:
# passing in None for group argument
if group is None:
group = get_default_group()
return dist.get_backend(group)
return torch_dist.get_backend(group)
else:
return None
def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
def get_world_size(group: Optional[ProcessGroup] = None) -> int:
"""Return the number of the given process group.
Note:
@@ -193,12 +196,12 @@ def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
# passing in None for group argument
if group is None:
group = get_default_group()
return dist.get_world_size(group)
return torch_dist.get_world_size(group)
else:
return 1
def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
def get_rank(group: Optional[ProcessGroup] = None) -> int:
"""Return the rank of the given process group.
Rank is a unique identifier assigned to each process within a distributed
@@ -222,7 +225,7 @@ def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
# passing in None for group argument
if group is None:
group = get_default_group()
return dist.get_rank(group)
return torch_dist.get_rank(group)
else:
return 0
@@ -241,7 +244,7 @@ def get_local_size() -> int:
raise RuntimeError('Local process group is not created, please use '
'`init_local_group` to setup local process group.')
return dist.get_world_size(_LOCAL_PROCESS_GROUP)
return torch_dist.get_world_size(_LOCAL_PROCESS_GROUP)
def get_local_rank() -> int:
@@ -258,11 +261,10 @@ def get_local_rank() -> int:
raise RuntimeError('Local process group is not created, please use '
'`init_local_group` to setup local process group.')
return dist.get_rank(_LOCAL_PROCESS_GROUP)
return torch_dist.get_rank(_LOCAL_PROCESS_GROUP)
def get_dist_info(
group: Optional[dist.ProcessGroup] = None) -> Tuple[int, int]:
def get_dist_info(group: Optional[ProcessGroup] = None) -> Tuple[int, int]:
"""Get distributed information of the given process group.
Note:
@@ -282,7 +284,7 @@ def get_dist_info(
return rank, world_size
def is_main_process(group: Optional[dist.ProcessGroup] = None) -> bool:
def is_main_process(group: Optional[ProcessGroup] = None) -> bool:
"""Whether the current rank of the given process group is equal to 0.
Args:
@@ -314,7 +316,7 @@ def master_only(func: Callable) -> Callable:
return wrapper
def barrier(group: Optional[dist.ProcessGroup] = None) -> None:
def barrier(group: Optional[ProcessGroup] = None) -> None:
"""Synchronize all processes from the given process group.
This collective blocks processes until the whole group enters this
@@ -332,4 +334,166 @@ def barrier(group: Optional[dist.ProcessGroup] = None) -> None:
# passing in None for group argument
if group is None:
group = get_default_group()
dist.barrier(group)
torch_dist.barrier(group)
def get_data_device(data: Union[Tensor, Mapping, Iterable]) -> torch.device:
"""Return the device of ``data``.
If ``data`` is a sequence of Tensor, all items in ``data`` should have the
same device type.
If ``data`` is a dict whose values are Tensor, all values should have the
same device type.
Args:
data (Tensor or Sequence or dict): Inputs whose device will be inferred.
Returns:
torch.device: The device of ``data``.
Examples:
>>> import torch
>>> from mmengine.dist import get_data_device
>>> # data is a Tensor
>>> data = torch.tensor([0, 1])
>>> get_data_device(data)
device(type='cpu')
>>> # data is a list of Tensor
>>> data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
>>> get_data_device(data)
device(type='cpu')
>>> # data is a dict
>>> data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
>>> get_data_device(data)
device(type='cpu')
"""
if isinstance(data, Tensor):
return data.device
elif isinstance(data, Mapping):
pre = None
for v in data.values():
cur = get_data_device(v)
if pre is None:
pre = cur
else:
if cur != pre:
raise ValueError(
'device type in data should be consistent, but got '
f'{cur} and {pre}')
if pre is None:
raise ValueError('data should not be empty.')
return pre
elif isinstance(data, Iterable) and not isinstance(data, str):
pre = None
for item in data:
cur = get_data_device(item)
if pre is None:
pre = cur
else:
if cur != pre:
raise ValueError(
'device type in data should be consistent, but got '
f'{cur} and {pre}')
if pre is None:
raise ValueError('data should not be empty.')
return pre
else:
raise TypeError('data should be a Tensor, sequence of tensor or dict, '
f'but got {data}')
def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
"""Return the device for communication among groups.
Args:
group (ProcessGroup, optional): The process group to work on.
Returns:
torch.device: The device of backend.
"""
backend = get_backend(group)
if backend == torch_dist.Backend.NCCL:
return torch.device('cuda', torch.cuda.current_device())
else:
# GLOO and MPI backends use cpu device by default
return torch.device('cpu')
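An illustrative doctest, assuming an initialized GLOO group (an NCCL group would return the current cuda device instead):

>>> group = dist.get_default_group()
>>> dist.get_comm_device(group)
device(type='cpu')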
def cast_data_device(
data: Union[Tensor, Mapping, Iterable],
device: torch.device,
out: Optional[Union[Tensor, Mapping, Iterable]] = None
) -> Union[Tensor, Mapping, Iterable]:
"""Recursively convert Tensor in ``data`` to ``device``.
If ``data`` is already on ``device``, it will not be cast again.
Args:
data (Tensor or list or dict): Inputs to be casted.
device (torch.device): Destination device type.
out (Tensor or list or dict, optional): If ``out`` is specified, it will
be updated in place with ``data`` cast to ``device``. Defaults to None.
Returns:
Tensor or list or dict: ``data`` cast to ``device``.
"""
if out is not None:
if type(data) != type(out):
raise TypeError(
'out should be the same type with data, but got data is '
f'{type(data)} and out is {type(out)}')
if isinstance(out, set):
raise TypeError('out should not be a set')
if isinstance(data, Tensor):
if get_data_device(data) == device:
data_on_device = data
else:
data_on_device = data.to(device)
if out is not None:
# modify the value of out inplace
out.copy_(data_on_device) # type: ignore
return data_on_device
elif isinstance(data, Mapping):
data_on_device = {}
if out is not None:
data_len = len(data)
out_len = len(out) # type: ignore
if data_len != out_len:
raise ValueError('length of data and out should be same, '
f'but got {data_len} and {out_len}')
for k, v in data.items():
data_on_device[k] = cast_data_device(v, device,
out[k]) # type: ignore
else:
for k, v in data.items():
data_on_device[k] = cast_data_device(v, device)
if len(data_on_device) == 0:
raise ValueError('data should not be empty')
# To ensure the output type is the same as the input type, we wrap the
# output with `type(data)`
return type(data)(data_on_device) # type: ignore
elif isinstance(data, Iterable) and not isinstance(
data, str) and not isinstance(data, np.ndarray):
data_on_device = []
if out is not None:
for v1, v2 in zip(data, out):
data_on_device.append(cast_data_device(v1, device, v2))
else:
for v in data:
data_on_device.append(cast_data_device(v, device))
if len(data_on_device) == 0:
raise ValueError('data should not be empty')
return type(data)(data_on_device) # type: ignore
else:
raise TypeError('data should be a Tensor, list of tensor or dict, '
f'but got {data}')
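The function recurses through containers and preserves their concrete types. A small CPU-only sketch that runs without a GPU:

>>> import torch
>>> from mmengine.dist import cast_data_device
>>> data = {'x': torch.tensor([0, 1]), 'y': [torch.tensor([2, 3])]}
>>> output = cast_data_device(data, torch.device('cpu'))
>>> output['x'].device, output['y'][0].device
(device(type='cpu'), device(type='cpu'))

tests/test_dist/test_dist.py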
@@ -3,6 +3,7 @@ import os
import os.path as osp
import tempfile
import unittest
from itertools import product
from unittest import TestCase
from unittest.mock import patch
@@ -247,6 +248,8 @@ class TestDistWithGLOOBackend(MultiProcessTestCase):
def test_all_gather_object(self):
self._init_dist_env(self.rank, self.world_size)
# data is a picklable Python object
if dist.get_rank() == 0:
data = 'foo'
else:
@@ -257,8 +260,21 @@ class TestDistWithGLOOBackend(MultiProcessTestCase):
self.assertEqual(output, expected)
# data is a list of picklable Python objects
if dist.get_rank() == 0:
data = ['foo', {1: 2}]
else:
data = {2: 3}
expected = [['foo', {1: 2}], {2: 3}]
output = dist.all_gather_object(data)
self.assertEqual(output, expected)
def test_gather_object(self):
self._init_dist_env(self.rank, self.world_size)
# data is a picklable Python object
if dist.get_rank() == 0:
data = 'foo'
else:
@@ -271,6 +287,19 @@ class TestDistWithGLOOBackend(MultiProcessTestCase):
else:
self.assertIsNone(output)
# data is a list of picklable Python objects
if dist.get_rank() == 0:
data = ['foo', {1: 2}]
else:
data = {2: 3}
output = dist.gather_object(data, dst=0)
if dist.get_rank() == 0:
self.assertEqual(output, [['foo', {1: 2}], {2: 3}])
else:
self.assertIsNone(output)
@unittest.skipIf(
torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
@@ -293,44 +322,59 @@ class TestDistWithNCCLBackend(MultiProcessTestCase):
def test_all_reduce(self):
self._init_dist_env(self.rank, self.world_size)
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
['sum', 'mean']):
tensor_types = [torch.int64, torch.float32]
reduce_ops = ['sum', 'mean']
device_types = ['cpu', 'cuda']
for tensor_type, reduce_op, device_type in product(
tensor_types, reduce_ops, device_types):
# 'mean' op does not support torch.int64
if tensor_type == torch.int64 and reduce_op == 'mean':
continue
if dist.get_rank() == 0:
data = torch.tensor([1, 2], dtype=tensor_type).cuda()
data = torch.tensor([1, 2], dtype=tensor_type).to(device_type)
else:
data = torch.tensor([3, 4], dtype=tensor_type).cuda()
data = torch.tensor([3, 4], dtype=tensor_type).to(device_type)
if reduce_op == 'sum':
expected = torch.tensor([4, 6], dtype=tensor_type).cuda()
expected = torch.tensor([4, 6],
dtype=tensor_type).to(device_type)
else:
expected = torch.tensor([2, 3], dtype=tensor_type).cuda()
expected = torch.tensor([2, 3],
dtype=tensor_type).to(device_type)
dist.all_reduce(data, reduce_op)
self.assertTrue(torch.allclose(data, expected))
def test_all_gather(self):
self._init_dist_env(self.rank, self.world_size)
if dist.get_rank() == 0:
data = torch.tensor([0, 1]).cuda()
else:
data = torch.tensor([1, 2]).cuda()
for device_type in ('cpu', 'cuda'):
if dist.get_rank() == 0:
data = torch.tensor([0, 1]).to(device_type)
else:
data = torch.tensor([1, 2]).to(device_type)
expected = [torch.tensor([0, 1]).cuda(), torch.tensor([1, 2]).cuda()]
expected = [
torch.tensor([0, 1]).to(device_type),
torch.tensor([1, 2]).to(device_type)
]
output = dist.all_gather(data)
self.assertTrue(
torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))
output = dist.all_gather(data)
self.assertTrue(
torch.allclose(output[dist.get_rank()],
expected[dist.get_rank()]))
def test_broadcast_dist(self):
self._init_dist_env(self.rank, self.world_size)
if dist.get_rank() == 0:
data = torch.tensor([0, 1]).cuda()
else:
data = torch.tensor([1, 2]).cuda()
for device_type in ('cpu', 'cuda'):
if dist.get_rank() == 0:
data = torch.tensor([0, 1]).to(device_type)
else:
data = torch.tensor([1, 2]).to(device_type)
expected = torch.tensor([0, 1]).cuda()
dist.broadcast(data, 0)
assert torch.allclose(data, expected)
expected = torch.tensor([0, 1]).to(device_type)
dist.broadcast(data, 0)
assert torch.allclose(data, expected)
def test_sync_random_seed(self):
self._init_dist_env(self.rank, self.world_size)
@@ -354,28 +398,43 @@ class TestDistWithNCCLBackend(MultiProcessTestCase):
def test_all_reduce_dict(self):
self._init_dist_env(self.rank, self.world_size)
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
['sum', 'mean']):
tensor_types = [torch.int64, torch.float32]
reduce_ops = ['sum', 'mean']
device_types = ['cpu', 'cuda']
for tensor_type, reduce_op, device_type in product(
tensor_types, reduce_ops, device_types):
# 'mean' op does not support torch.int64
if tensor_type == torch.int64 and reduce_op == 'mean':
continue
if dist.get_rank() == 0:
data = {
'key1': torch.tensor([0, 1], dtype=tensor_type).cuda(),
'key2': torch.tensor([1, 2], dtype=tensor_type).cuda(),
'key1':
torch.tensor([0, 1], dtype=tensor_type).to(device_type),
'key2':
torch.tensor([1, 2], dtype=tensor_type).to(device_type),
}
else:
data = {
'key1': torch.tensor([2, 3], dtype=tensor_type).cuda(),
'key2': torch.tensor([3, 4], dtype=tensor_type).cuda(),
'key1':
torch.tensor([2, 3], dtype=tensor_type).to(device_type),
'key2':
torch.tensor([3, 4], dtype=tensor_type).to(device_type),
}
if reduce_op == 'sum':
expected = {
'key1': torch.tensor([2, 4], dtype=tensor_type).cuda(),
'key2': torch.tensor([4, 6], dtype=tensor_type).cuda(),
'key1':
torch.tensor([2, 4], dtype=tensor_type).to(device_type),
'key2':
torch.tensor([4, 6], dtype=tensor_type).to(device_type),
}
else:
expected = {
'key1': torch.tensor([1, 2], dtype=tensor_type).cuda(),
'key2': torch.tensor([2, 3], dtype=tensor_type).cuda(),
'key1':
torch.tensor([1, 2], dtype=tensor_type).to(device_type),
'key2':
torch.tensor([2, 3], dtype=tensor_type).to(device_type),
}
dist.all_reduce_dict(data, reduce_op)
@@ -385,30 +444,43 @@ class TestDistWithNCCLBackend(MultiProcessTestCase):
# `torch.cat` in torch1.5 cannot concatenate different types, so we
# fall back to converting them all to float type.
if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
if dist.get_rank() == 0:
data = {
'key1': torch.tensor([0, 1], dtype=torch.float32).cuda(),
'key2': torch.tensor([1, 2], dtype=torch.int32).cuda(),
}
else:
data = {
'key1': torch.tensor([2, 3], dtype=torch.float32).cuda(),
'key2': torch.tensor([3, 4], dtype=torch.int32).cuda(),
}
for device_type in ('cpu', 'cuda'):
if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
if dist.get_rank() == 0:
data = {
'key1':
torch.tensor([0, 1],
dtype=torch.float32).to(device_type),
'key2':
torch.tensor([1, 2],
dtype=torch.int32).to(device_type),
}
else:
data = {
'key1':
torch.tensor([2, 3],
dtype=torch.float32).to(device_type),
'key2':
torch.tensor([3, 4],
dtype=torch.int32).to(device_type),
}
expected = {
'key1': torch.tensor([2, 4], dtype=torch.float32).cuda(),
'key2': torch.tensor([4, 6], dtype=torch.float32).cuda(),
}
expected = {
'key1':
torch.tensor([2, 4], dtype=torch.float32).to(device_type),
'key2':
torch.tensor([4, 6], dtype=torch.float32).to(device_type),
}
dist.all_reduce_dict(data, 'sum')
dist.all_reduce_dict(data, 'sum')
for key in data:
assert torch.allclose(data[key], expected[key])
for key in data:
assert torch.allclose(data[key], expected[key])
def test_all_gather_object(self):
self._init_dist_env(self.rank, self.world_size)
# data is a picklable Python object
if dist.get_rank() == 0:
data = 'foo'
else:
@@ -419,8 +491,21 @@ class TestDistWithNCCLBackend(MultiProcessTestCase):
self.assertEqual(output, expected)
# data is a list of picklable Python objects
if dist.get_rank() == 0:
data = ['foo', {1: 2}]
else:
data = {2: 3}
expected = [['foo', {1: 2}], {2: 3}]
output = dist.all_gather_object(data)
self.assertEqual(output, expected)
def test_collect_results(self):
self._init_dist_env(self.rank, self.world_size)
# 1. test `device` and `tmpdir` parameters
if dist.get_rank() == 0:
data = ['foo', {1: 2}]
else:
@@ -430,14 +515,14 @@ class TestDistWithNCCLBackend(MultiProcessTestCase):
expected = ['foo', 24, {1: 2}, {'a': 'b'}]
# test `device=cpu`
# 1.1 test `device=cpu` and `tmpdir` is None
output = dist.collect_results(data, size, device='cpu')
if dist.get_rank() == 0:
self.assertEqual(output, expected)
else:
self.assertIsNone(output)
# test `device=cpu` and `tmpdir is not None`
# 1.2 test `device=cpu` and `tmpdir` is not None
tmpdir = tempfile.mkdtemp()
# broadcast tmpdir to all ranks to make it consistent
object_list = [tmpdir]
@@ -453,7 +538,31 @@ class TestDistWithNCCLBackend(MultiProcessTestCase):
# object_list[0] will be removed by `dist.collect_results`
self.assertFalse(osp.exists(object_list[0]))
# test `device=gpu`
# 1.3 test `device=gpu`
output = dist.collect_results(data, size, device='gpu')
if dist.get_rank() == 0:
self.assertEqual(output, expected)
else:
self.assertIsNone(output)
# 2. test `size` parameter
if dist.get_rank() == 0:
data = ['foo', {1: 2}]
else:
data = [24, {'a': 'b'}]
size = 3
expected = ['foo', 24, {1: 2}]
# 2.1 test `device=cpu` and `tmpdir` is None
output = dist.collect_results(data, size, device='cpu')
if dist.get_rank() == 0:
self.assertEqual(output, expected)
else:
self.assertIsNone(output)
# 2.2 test `device=gpu`
output = dist.collect_results(data, size, device='gpu')
if dist.get_rank() == 0:
self.assertEqual(output, expected)
......
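tests/test_dist/test_utils.py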
@@ -3,6 +3,7 @@ import os
import unittest
from unittest import TestCase
import numpy as np
import torch
import torch.distributed as torch_dist
@@ -44,6 +45,170 @@ class TestUtils(TestCase):
def test_barrier(self):
dist.barrier() # nothing is done
def test_get_data_device(self):
# data is a Tensor
data = torch.tensor([0, 1])
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a list of Tensor
data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a list but not all items are Tensor
data = [torch.tensor([0, 1]), 123]
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is a list containing Tensor and a dict
data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a list containing Tensor and a dict but the dict contains
# invalid type
data = [torch.tensor([0, 1]), {'key': '123'}]
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is an empty list
with self.assertRaises(ValueError):
dist.get_data_device([])
# data is a dict
data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a dict but not all values are Tensor
data = {'key1': torch.tensor([0, 1]), 'key2': 123}
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is a dict and one of the values is a list of Tensor
data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([0, 1])]}
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a dict and one of the values is an invalid type
data = {'key1': torch.tensor([0, 1]), 'key2': ['123']}
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is an empty dict
with self.assertRaises(ValueError):
dist.get_data_device({})
# data is not a valid type
with self.assertRaisesRegex(
TypeError,
'data should be a Tensor, sequence of tensor or dict'):
dist.get_data_device('123')
@unittest.skipIf(
torch.cuda.device_count() == 0, reason='need at least 1 gpu to test')
def test_cast_data_device(self):
expected_device = torch.device('cuda', torch.cuda.current_device())
# data is a Tensor
data = torch.tensor([0, 1])
output = dist.cast_data_device(data, expected_device)
self.assertEqual(output.device, expected_device)
# data is a Tensor and out is also a Tensor
data = torch.tensor([0, 1])
out = torch.tensor([1, 2])
output = dist.cast_data_device(data, expected_device, out=out)
self.assertEqual(output.device, expected_device)
self.assertTrue(torch.allclose(output.cpu(), out))
# data is a list of Tensor
data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
for item in dist.cast_data_device(data, expected_device):
self.assertEqual(item.device, expected_device)
# both data and out are list of tensor
data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
out = [torch.tensor([3, 4]), torch.tensor([5, 6])]
output = dist.cast_data_device(data, expected_device, out=out)
for item1, item2 in zip(output, out):
self.assertEqual(item1.device, expected_device)
self.assertTrue(torch.allclose(item1.cpu(), item2))
# data is a list containing a Tensor and a dict
data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
output = dist.cast_data_device(data, expected_device)
self.assertEqual(output[0].device, expected_device)
self.assertEqual(output[1]['key'].device, expected_device)
# data is a list containing a Tensor and a dict, and so is out
data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
out = [torch.tensor([3, 4]), {'key': torch.tensor([5, 6])}]
output = dist.cast_data_device(data, expected_device, out=out)
self.assertEqual(output[0].device, expected_device)
self.assertTrue(torch.allclose(output[0].cpu(), out[0]))
self.assertEqual(output[1]['key'].device, expected_device)
self.assertTrue(torch.allclose(output[1]['key'].cpu(), out[1]['key']))
# data is an empty list
with self.assertRaisesRegex(ValueError, 'data should not be empty'):
dist.cast_data_device([], expected_device)
# data is a dict
data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
output = dist.cast_data_device(data, expected_device)
for k, v in output.items():
self.assertEqual(v.device, expected_device)
# data is a dict, and so is out
data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
out = {'key1': torch.tensor([3, 4]), 'key2': torch.tensor([5, 6])}
output = dist.cast_data_device(data, expected_device, out=out)
for k, v in output.items():
self.assertEqual(v.device, expected_device)
self.assertTrue(torch.allclose(v.cpu(), out[k]))
# the length of data and out should be the same
data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
out = {'key1': torch.tensor([3, 4])}
with self.assertRaisesRegex(ValueError,
'length of data and out should be same'):
dist.cast_data_device(data, expected_device, out=out)
# data is an empty dict
with self.assertRaisesRegex(ValueError, 'data should not be empty'):
dist.cast_data_device({}, expected_device)
# data is a dict and one of the values is a list
data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([2, 3])]}
out = {'key1': torch.tensor([3, 4]), 'key2': [torch.tensor([5, 6])]}
output = dist.cast_data_device(data, expected_device, out=out)
self.assertEqual(output['key1'].device, expected_device)
self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1']))
self.assertEqual(output['key2'][0].device, expected_device)
self.assertTrue(
torch.allclose(output['key2'][0].cpu(), out['key2'][0]))
# data is not a valid type
with self.assertRaisesRegex(
TypeError, 'data should be a Tensor, list of tensor or dict'):
dist.cast_data_device(123, expected_device)
with self.assertRaisesRegex(
TypeError, 'data should be a Tensor, list of tensor or dict'):
dist.cast_data_device('123', expected_device)
with self.assertRaisesRegex(
TypeError, 'data should be a Tensor, list of tensor or dict'):
dist.cast_data_device(np.array([0, 1]), expected_device)
# data and out are not the same type
data = torch.tensor([0, 1])
out = '123'
with self.assertRaisesRegex(TypeError,
'out should be the same type with data'):
dist.cast_data_device(data, expected_device, out=out)
data = set([0, 1])
out = set([2, 3])
with self.assertRaisesRegex(TypeError, 'out should not be a set'):
dist.cast_data_device(data, expected_device, out=out)
class TestUtilsWithGLOOBackend(MultiProcessTestCase):
@@ -108,6 +273,69 @@ class TestUtilsWithGLOOBackend(MultiProcessTestCase):
fun()
def test_get_data_device(self):
self._init_dist_env(self.rank, self.world_size)
# data is a Tensor
data = torch.tensor([0, 1])
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a list of Tensor
data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a list but not all items are Tensor
data = [torch.tensor([0, 1]), 123]
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is a list containing Tensor and a dict
data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a list containing Tensor and a dict but the dict contains
# invalid type
data = [torch.tensor([0, 1]), {'key': '123'}]
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is an empty list
with self.assertRaises(ValueError):
dist.get_data_device([])
# data is a dict
data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a dict but not all values are Tensor
data = {'key1': torch.tensor([0, 1]), 'key2': 123}
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is a dict and one of the values is a list of Tensor
data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([0, 1])]}
self.assertEqual(dist.get_data_device(data), torch.device('cpu'))
# data is a dict and one of the values is an invalid type
data = {'key1': torch.tensor([0, 1]), 'key2': ['123']}
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is an empty dict
with self.assertRaises(ValueError):
dist.get_data_device({})
# data is not a valid type
with self.assertRaisesRegex(
TypeError,
'data should be a Tensor, sequence of tensor or dict'):
dist.get_data_device('123')
def test_get_comm_device(self):
self._init_dist_env(self.rank, self.world_size)
group = dist.get_default_group()
assert dist.get_comm_device(group) == torch.device('cpu')
@unittest.skipIf(
torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
@@ -175,3 +403,206 @@ class TestUtilsWithNCCLBackend(MultiProcessTestCase):
assert dist.get_rank() == 0
fun()
def test_get_data_device(self):
self._init_dist_env(self.rank, self.world_size)
expected_device = torch.device('cuda', torch.cuda.current_device())
# data is a Tensor
data = torch.tensor([0, 1]).to(expected_device)
self.assertEqual(dist.get_data_device(data), expected_device)
# data is a list of Tensor
data = [
torch.tensor([0, 1]).to(expected_device),
torch.tensor([2, 3]).to(expected_device)
]
self.assertEqual(dist.get_data_device(data), expected_device)
# data is a list but not all items are Tensor
data = [torch.tensor([0, 1]).to(expected_device), 123]
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is a list of Tensor but not all items have the same device type
data = [torch.tensor([0, 1]), torch.tensor([2, 3]).to(expected_device)]
with self.assertRaises(ValueError):
dist.get_data_device(data)
# data is a list containing Tensor and a dict
data = [
torch.tensor([0, 1]).to(expected_device), {
'key': torch.tensor([2, 3]).to(expected_device)
}
]
self.assertEqual(dist.get_data_device(data), expected_device)
# data is a list containing Tensor and a dict but the dict contains
# invalid type
data = [torch.tensor([0, 1]).to(expected_device), {'key': '123'}]
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is an empty list
with self.assertRaises(ValueError):
dist.get_data_device([])
# data is a dict
data = {
'key1': torch.tensor([0, 1]).to(expected_device),
'key2': torch.tensor([0, 1]).to(expected_device)
}
self.assertEqual(dist.get_data_device(data), expected_device)
# data is a dict but not all values are Tensor
data = {'key1': torch.tensor([0, 1]).to(expected_device), 'key2': 123}
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is a dict but not all values have the same device type
data = {
'key1': torch.tensor([0, 1]),
'key2': torch.tensor([0, 1]).to(expected_device)
}
with self.assertRaises(ValueError):
dist.get_data_device(data)
# data is a dict and one of the values is a list of Tensor
data = {
'key1': torch.tensor([0, 1]).to(expected_device),
'key2': [torch.tensor([0, 1]).to(expected_device)]
}
self.assertEqual(dist.get_data_device(data), expected_device)
# data is a dict and one of the values is an invalid type
data = {
'key1': torch.tensor([0, 1]).to(expected_device),
'key2': ['123']
}
with self.assertRaises(TypeError):
dist.get_data_device(data)
# data is an empty dict
with self.assertRaises(ValueError):
dist.get_data_device({})
# data is not a valid type
with self.assertRaisesRegex(
TypeError,
'data should be a Tensor, sequence of tensor or dict'):
dist.get_data_device('123')
def test_get_comm_device(self):
self._init_dist_env(self.rank, self.world_size)
group = dist.get_default_group()
expected = torch.device('cuda', torch.cuda.current_device())
self.assertEqual(dist.get_comm_device(group), expected)
def test_cast_data_device(self):
self._init_dist_env(self.rank, self.world_size)
expected_device = torch.device('cuda', torch.cuda.current_device())
# data is a Tensor
data = torch.tensor([0, 1])
output = dist.cast_data_device(data, expected_device)
self.assertEqual(output.device, expected_device)
# data is a Tensor and out is also a Tensor
data = torch.tensor([0, 1])
out = torch.tensor([1, 2])
output = dist.cast_data_device(data, expected_device, out=out)
self.assertEqual(output.device, expected_device)
self.assertTrue(torch.allclose(output.cpu(), out))
# data is a list of Tensor
data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
for item in dist.cast_data_device(data, expected_device):
self.assertEqual(item.device, expected_device)
# both data and out are list of tensor
data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
out = [torch.tensor([3, 4]), torch.tensor([5, 6])]
output = dist.cast_data_device(data, expected_device, out=out)
for item1, item2 in zip(output, out):
self.assertEqual(item1.device, expected_device)
self.assertTrue(torch.allclose(item1.cpu(), item2))
# data is a list containing a Tensor and a dict
data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
output = dist.cast_data_device(data, expected_device)
self.assertEqual(output[0].device, expected_device)
self.assertEqual(output[1]['key'].device, expected_device)
# data is a list containing a Tensor and a dict, and so is out
data = [torch.tensor([0, 1]), {'key': torch.tensor([2, 3])}]
out = [torch.tensor([3, 4]), {'key': torch.tensor([5, 6])}]
output = dist.cast_data_device(data, expected_device, out=out)
self.assertEqual(output[0].device, expected_device)
self.assertTrue(torch.allclose(output[0].cpu(), out[0]))
self.assertEqual(output[1]['key'].device, expected_device)
self.assertTrue(torch.allclose(output[1]['key'].cpu(), out[1]['key']))
# data is an empty list
with self.assertRaisesRegex(ValueError, 'data should not be empty'):
dist.cast_data_device([], expected_device)
# data is a dict
data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
output = dist.cast_data_device(data, expected_device)
for k, v in output.items():
self.assertEqual(v.device, expected_device)
# data is a dict, and so is out
data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
out = {'key1': torch.tensor([3, 4]), 'key2': torch.tensor([5, 6])}
output = dist.cast_data_device(data, expected_device, out=out)
for k, v in output.items():
self.assertEqual(v.device, expected_device)
self.assertTrue(torch.allclose(v.cpu(), out[k]))
# the length of data and out should be the same
data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([2, 3])}
out = {'key1': torch.tensor([3, 4])}
with self.assertRaisesRegex(ValueError,
'length of data and out should be same'):
dist.cast_data_device(data, expected_device, out=out)
# data is an empty dict
with self.assertRaisesRegex(ValueError, 'data should not be empty'):
dist.cast_data_device({}, expected_device)
# data is a dict and one of the values is a list
data = {'key1': torch.tensor([0, 1]), 'key2': [torch.tensor([2, 3])]}
out = {'key1': torch.tensor([3, 4]), 'key2': [torch.tensor([5, 6])]}
output = dist.cast_data_device(data, expected_device, out=out)
self.assertEqual(output['key1'].device, expected_device)
self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1']))
self.assertEqual(output['key2'][0].device, expected_device)
self.assertTrue(
torch.allclose(output['key2'][0].cpu(), out['key2'][0]))
# data is not a valid type
with self.assertRaisesRegex(
TypeError, 'data should be a Tensor, list of tensor or dict'):
dist.cast_data_device(123, expected_device)
with self.assertRaisesRegex(
TypeError, 'data should be a Tensor, list of tensor or dict'):
dist.cast_data_device('123', expected_device)
with self.assertRaisesRegex(
TypeError, 'data should be a Tensor, list of tensor or dict'):
dist.cast_data_device(np.array([0, 1]), expected_device)
# data and out are not the same type
data = torch.tensor([0, 1])
out = '123'
with self.assertRaisesRegex(TypeError,
'out should be the same type with data'):
dist.cast_data_device(data, expected_device, out=out)
data = set([0, 1])
out = set([2, 3])
with self.assertRaisesRegex(TypeError, 'out should not be a set'):
dist.cast_data_device(data, expected_device, out=out)