Unverified Commit 98c85529 authored by Zaida Zhou, committed by GitHub

[Refactor] Replace torch distributed with mmengine dist module (#196)

* [Fix] Replace torch distributed with mmengine dist module

* minor refinement

* move all_reduce_params to dist.py

* add unit tests

* update unit tests

* fix test_logger.py

* add examples
parent e37f1f90
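For orientation, a minimal sketch of the migration pattern this commit applies (the call site below is hypothetical, not taken from the diff). mmengine.dist helpers fall back to sensible single-process defaults, so the availability/initialization guards required with raw torch.distributed can be dropped:

# Before: every call site guards against an uninitialized process group.
import torch.distributed as torch_dist

if torch_dist.is_available() and torch_dist.is_initialized():
    rank = torch_dist.get_rank()
else:
    rank = 0

# After: mmengine.dist.get_rank() returns 0 in a non-distributed
# environment, so the guard disappears.
from mmengine import dist

rank = dist.get_rank()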
@@ -2,7 +2,7 @@
from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict,
collect_results, gather, broadcast, gather_object,
sync_random_seed, broadcast_object_list,
collect_results_cpu, collect_results_gpu)
collect_results_cpu, collect_results_gpu, all_reduce_params)
from .utils import (get_dist_info, init_dist, init_local_group, get_backend,
get_world_size, get_rank, get_local_size, get_local_rank,
is_main_process, master_only, barrier, get_local_group,
@@ -16,6 +16,6 @@ __all__ = [
'get_dist_info', 'init_dist', 'init_local_group', 'get_backend',
'get_world_size', 'get_rank', 'get_local_size', 'get_local_group',
'get_local_rank', 'is_main_process', 'master_only', 'barrier',
'is_distributed', 'get_default_group', 'get_data_device',
'get_comm_device', 'cast_data_device'
'is_distributed', 'get_default_group', 'all_reduce_params',
'get_data_device', 'get_comm_device', 'cast_data_device'
]
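With the export added above, the helper can be imported directly from the package namespace:

>>> from mmengine.dist import all_reduce_params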
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional, Tuple, Dict
from typing import Any, List, Optional, Tuple, Dict, Generator, Union
from collections import OrderedDict
import shutil
import pickle
import numpy as np
@@ -7,6 +8,8 @@ import tempfile
import torch
import os.path as osp
from torch import Tensor
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
@@ -805,7 +808,7 @@ def gather_object(data: Any,
the object must be picklable in order to be gathered.
Note:
``NCCL backend`` dost not support ``gather_object``.
``NCCL backend`` does not support ``gather_object``.
Note:
Unlike PyTorch ``torch.distributed.gather_object``,
@@ -1036,3 +1039,92 @@ def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
return ordered_results
else:
return None
def _all_reduce_coalesced(tensors: List[torch.Tensor],
bucket_size_mb: int = -1,
op: str = 'sum',
group: Optional[ProcessGroup] = None) -> None:
"""All-reduce a sequence of tensors as a whole.
Args:
tensors (List[torch.Tensor]): A sequence of tensors to be
all-reduced.
bucket_size_mb (int): The limit of each chunk in megabytes
for grouping tensors into chunks. Defaults to -1.
op (str): Operation to reduce data. Defaults to 'sum'. Optional values
are 'sum', 'mean', 'product', 'min', 'max', 'band', 'bor' and
'bxor'.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Defaults to None.
"""
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
else:
buckets = OrderedDict()
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
buckets = buckets.values()
for bucket in buckets:
flat_tensors = _flatten_dense_tensors(bucket)
all_reduce(flat_tensors, op=op, group=group)
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
tensor.copy_(synced)
def all_reduce_params(params: Union[List, Generator[torch.Tensor, None, None]],
coalesce: bool = True,
bucket_size_mb: int = -1,
op: str = 'sum',
group: Optional[ProcessGroup] = None) -> None:
"""All-reduce parameters.
Args:
params (List or Generator[torch.Tensor, None, None]): List of
parameters or buffers of a model.
coalesce (bool, optional): Whether to reduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
Defaults to -1.
op (str): Operation to reduce data. Defaults to 'sum'. Optional values
are 'sum', 'mean', 'product', 'min', 'max', 'band', 'bor' and
'bxor'.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Defaults to None.
Examples:
>>> import torch
>>> import mmengine.dist as dist
>>> # non-distributed environment
>>> data = [torch.arange(2), torch.arange(3)]
>>> dist.all_reduce_params(data)
>>> data
[tensor([0, 1]), tensor([0, 1, 2])]
>>> # distributed environment
>>> # We have 2 process groups, 2 ranks.
>>> if dist.get_rank() == 0:
... data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
... else:
... data = [torch.tensor([2, 3]), torch.tensor([4, 5])]
>>> dist.all_reduce_params(data)
>>> data
[tensor([3, 5]), tensor([7, 9])]
"""
world_size = get_world_size(group)
if world_size == 1:
return
params_data = [param.data for param in params]
if coalesce:
_all_reduce_coalesced(params_data, bucket_size_mb, op=op, group=group)
else:
for tensor in params_data:
all_reduce(tensor, op=op, group=group)
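A short usage sketch of the new helper (the toy model is an assumption for illustration); averaging buffers across ranks is exactly what SyncBuffersHook switches to below:

import torch.nn as nn
from mmengine.dist import all_reduce_params

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
# BatchNorm running statistics live in model.buffers(); with op='mean' the
# summed buffers are divided by the world size, and the call is a no-op
# when the world size is 1.
all_reduce_params(model.buffers(), op='mean')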
# Copyright (c) OpenMMLab. All rights reserved.
# from mmengine.dist import get_dist_info, all_reduce
from collections import OrderedDict
from typing import Generator, List
from unittest.mock import MagicMock, Mock
import torch
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from mmengine import dist
from mmengine.registry import HOOKS
from .hook import Hook
# TODO, replace with import mmengine.dist as dist
dist = Mock()
dist.IS_DIST = MagicMock(return_value=True)
# TODO, replace with mmengine.dist.get_dist_info
get_dist_info = MagicMock(return_value=(0, 1))
# TODO, replace with mmengine.dist.all_reduce
all_reduce = MagicMock()
# TODO, may need to move to dist.utils after implementing dist module
def _allreduce_coalesced(tensors: List[torch.Tensor],
world_size: int,
bucket_size_mb: int = -1) -> None:
"""All-reduce a sequence of tensors as a whole.
Args:
tensors (List[torch.Tensor]): A sequence of tensors to be
all-reduced.
world_size (int): The world size of the process group.
bucket_size_mb (int): The limit of each chunk in megabytes
for grouping tensors into chunks. Defaults to -1.
"""
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
else:
buckets = OrderedDict()
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
buckets = buckets.values()
for bucket in buckets:
flat_tensors = _flatten_dense_tensors(bucket)
all_reduce(flat_tensors)
flat_tensors.div_(world_size)
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
tensor.copy_(synced)
def allreduce_params(params: Generator[torch.Tensor, None, None],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
"""All-reduce parameters.
Args:
params (Generator[torch.Tensor, None, None]): List of parameters or
buffers of a model.
coalesce (bool, optional): Whether to reduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
Defaults to -1.
"""
_, world_size = get_dist_info()
if world_size == 1:
return
params_data = [param.data for param in params]
if coalesce:
_allreduce_coalesced(params_data, world_size, bucket_size_mb)
else:
for tensor in params_data:
all_reduce(tensor.div_(world_size))
@HOOKS.register_module()
class SyncBuffersHook(Hook):
@@ -87,7 +12,7 @@ class SyncBuffersHook(Hook):
priority = 'NORMAL'
def __init__(self) -> None:
self.distributed = dist.IS_DIST
self.distributed = dist.is_distributed()
def after_train_epoch(self, runner) -> None:
"""All-reduce model buffers at the end of each epoch.
@@ -96,4 +21,4 @@ class SyncBuffersHook(Hook):
runner (Runner): The runner of the training process.
"""
if self.distributed:
allreduce_params(runner.model.buffers())
dist.all_reduce_params(runner.model.buffers(), op='mean')
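For reference, the hook is enabled through the runner configuration in the usual MMEngine way (a convention, not part of this diff):

custom_hooks = [dict(type='SyncBuffersHook')]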
@@ -5,9 +5,9 @@ import sys
from logging import Logger, LogRecord
from typing import Optional, Union
import torch.distributed as dist
from termcolor import colored
from mmengine import dist
from mmengine.utils import ManagerMixin
@@ -144,10 +144,8 @@ class MMLogger(Logger, ManagerMixin):
Logger.__init__(self, logger_name)
ManagerMixin.__init__(self, name)
# Get rank in DDP mode.
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
rank = dist.get_rank()
# Config stream_handler. If `rank != 0`, stream_handler can only
# export ERROR logs.
stream_handler = logging.StreamHandler(stream=sys.stdout)
......
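Roughly, and simplifying MMLogger's actual constructor, the rank-dependent handler level described above amounts to:

import logging
import sys
from mmengine import dist

rank = dist.get_rank()  # 0 when torch.distributed is not initialized
stream_handler = logging.StreamHandler(stream=sys.stdout)
# Ranks other than 0 only emit ERROR-level records to stdout.
stream_handler.setLevel(logging.INFO if rank == 0 else logging.ERROR)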
@@ -579,9 +579,6 @@ class Runner:
self._rank, self._world_size = get_dist_info()
timestamp = torch.tensor(time.time(), dtype=torch.float64)
# TODO: handled by broadcast
if self._world_size > 1 and torch.cuda.is_available():
timestamp = timestamp.cuda()
# broadcast timestamp from 0 process to other processes
broadcast(timestamp)
self._timestamp = time.strftime('%Y%m%d_%H%M%S',
......
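The simplified pattern after this change, as a sketch (mmengine.dist.broadcast is expected to move data to a suitable communication device internally, which is why the explicit .cuda() call above could be removed):

import time
import torch
from mmengine.dist import broadcast

timestamp = torch.tensor(time.time(), dtype=torch.float64)
broadcast(timestamp)  # every rank receives rank 0's timestamp
time_str = time.strftime('%Y%m%d_%H%M%S', time.localtime(timestamp.item()))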
@@ -99,6 +99,22 @@ class TestDist(TestCase):
output = dist.collect_results(data, size, device='gpu')
self.assertEqual(output, expected)
def test_all_reduce_params(self):
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
['sum', 'mean']):
data = [
torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
]
data_gen = (item for item in data)
expected = [
torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
]
dist.all_reduce_params(data_gen, op=reduce_op)
for item1, item2 in zip(data, expected):
self.assertTrue(torch.allclose(item1, item2))
class TestDistWithGLOOBackend(MultiProcessTestCase):
@@ -300,6 +316,39 @@ class TestDistWithGLOOBackend(MultiProcessTestCase):
else:
self.assertIsNone(output)
def test_all_reduce_params(self):
self._init_dist_env(self.rank, self.world_size)
tensor_types = [torch.int64, torch.float32]
reduce_ops = ['sum', 'mean']
coalesces = [True, False]
for tensor_type, reduce_op, coalesce in zip(tensor_types, reduce_ops,
coalesces):
if dist.get_rank() == 0:
data = [
torch.tensor([0, 1], dtype=tensor_type) for _ in range(100)
]
else:
data = (
torch.tensor([2, 3], dtype=tensor_type)
for _ in range(100))
data_gen = (item for item in data)
if reduce_op == 'sum':
expected = (
torch.tensor([2, 4], dtype=tensor_type)
for _ in range(100))
else:
expected = (
torch.tensor([1, 2], dtype=tensor_type)
for _ in range(100))
dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)
for item1, item2 in zip(data, expected):
self.assertTrue(torch.allclose(item1, item2))
@unittest.skipIf(
torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
@@ -568,3 +617,37 @@ class TestDistWithNCCLBackend(MultiProcessTestCase):
self.assertEqual(output, expected)
else:
self.assertIsNone(output)
def test_all_reduce_params(self):
self._init_dist_env(self.rank, self.world_size)
tensor_types = [torch.int64, torch.float32]
reduce_ops = ['sum', 'mean']
coalesces = [True, False]
device_types = ['cpu', 'cuda']
for tensor_type, reduce_op, coalesce, device_type in zip(
tensor_types, reduce_ops, coalesces, device_types):
if dist.get_rank() == 0:
data = [
torch.tensor([0, 1], dtype=tensor_type).to(device_type)
for _ in range(100)
]
else:
data = [
torch.tensor([2, 3], dtype=tensor_type).to(device_type)
for _ in range(100)
]
data_gen = (item for item in data)
if reduce_op == 'sum':
expected = (
torch.tensor([2, 4], dtype=tensor_type).to(device_type)
for _ in range(100))
else:
expected = (
torch.tensor([1, 2], dtype=tensor_type).to(device_type)
for _ in range(100))
dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)
for item1, item2 in zip(data, expected):
self.assertTrue(torch.allclose(item1, item2))
@@ -15,9 +15,7 @@ class TestLogger:
stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
@patch('torch.distributed.get_rank', lambda: 0)
@patch('torch.distributed.is_initialized', lambda: True)
@patch('torch.distributed.is_available', lambda: True)
@patch('mmengine.dist.get_rank', lambda: 0)
def test_init_rank0(self, tmp_path):
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
assert logger.name == 'mmengine'
@@ -47,9 +45,7 @@ class TestLogger:
assert logger.instance_name == 'rank0.pkg3'
logging.shutdown()
@patch('torch.distributed.get_rank', lambda: 1)
@patch('torch.distributed.is_initialized', lambda: True)
@patch('torch.distributed.is_available', lambda: True)
@patch('mmengine.dist.get_rank', lambda: 1)
def test_init_rank1(self, tmp_path):
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
tmp_file = tmp_path / 'tmp_file.log'
......