diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py index 179d003836a7a9c7a03f653106a66ca5452cddf5..17f707d3bb8cdc099b876082cb14bba6afecc625 100644 --- a/mmengine/dataset/base_dataset.py +++ b/mmengine/dataset/base_dataset.py @@ -229,7 +229,7 @@ class BaseDataset(Dataset): self.test_mode = test_mode self.max_refetch = max_refetch self.data_list: List[dict] = [] - self.date_bytes: np.ndarray + self.data_bytes: np.ndarray # Set meta information. self._metainfo = self._get_meta_info(copy.deepcopy(metainfo)) @@ -259,7 +259,7 @@ class BaseDataset(Dataset): start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() end_addr = self.data_address[idx].item() bytes = memoryview( - self.date_bytes[start_addr:end_addr]) # type: ignore + self.data_bytes[start_addr:end_addr]) # type: ignore data_info = pickle.loads(bytes) # type: ignore else: data_info = self.data_list[idx] @@ -302,7 +302,7 @@ class BaseDataset(Dataset): # serialize data_list if self.serialize_data: - self.date_bytes, self.data_address = self._serialize_data() + self.data_bytes, self.data_address = self._serialize_data() self._fully_initialized = True @@ -575,7 +575,7 @@ class BaseDataset(Dataset): # Get subset of data from serialized data or data information sequence # according to `self.serialize_data`. if self.serialize_data: - self.date_bytes, self.data_address = \ + self.data_bytes, self.data_address = \ self._get_serialized_subset(indices) else: self.data_list = self._get_unserialized_subset(indices) @@ -626,9 +626,9 @@ class BaseDataset(Dataset): sub_dataset = self._copy_without_annotation() # Get subset of dataset with serialize and unserialized data. if self.serialize_data: - date_bytes, data_address = \ + data_bytes, data_address = \ self._get_serialized_subset(indices) - sub_dataset.date_bytes = date_bytes.copy() + sub_dataset.data_bytes = data_bytes.copy() sub_dataset.data_address = data_address.copy() else: data_list = self._get_unserialized_subset(indices) @@ -650,7 +650,7 @@ class BaseDataset(Dataset): Tuple[np.ndarray, np.ndarray]: subset of serialized data information. """ - sub_date_bytes: Union[List, np.ndarray] + sub_data_bytes: Union[List, np.ndarray] sub_data_address: Union[List, np.ndarray] if isinstance(indices, int): if indices >= 0: @@ -661,7 +661,7 @@ class BaseDataset(Dataset): if indices > 0 else 0 # Slicing operation of `np.ndarray` does not trigger a memory # copy. - sub_date_bytes = self.date_bytes[:end_addr] + sub_data_bytes = self.data_bytes[:end_addr] # Since the buffer size of first few data information is not # changed, sub_data_address = self.data_address[:indices] @@ -671,11 +671,11 @@ class BaseDataset(Dataset): # Return the last few data information. ignored_bytes_size = self.data_address[indices - 1] start_addr = self.data_address[indices - 1].item() - sub_date_bytes = self.date_bytes[start_addr:] + sub_data_bytes = self.data_bytes[start_addr:] sub_data_address = self.data_address[indices:] sub_data_address = sub_data_address - ignored_bytes_size elif isinstance(indices, Sequence): - sub_date_bytes = [] + sub_data_bytes = [] sub_data_address = [] for idx in indices: assert len(self) > idx >= -len(self) @@ -683,20 +683,20 @@ class BaseDataset(Dataset): self.data_address[idx - 1].item() end_addr = self.data_address[idx].item() # Get data information by address. - sub_date_bytes.append(self.date_bytes[start_addr:end_addr]) + sub_data_bytes.append(self.data_bytes[start_addr:end_addr]) # Get data information size. 
sub_data_address.append(end_addr - start_addr) # Handle indices is an empty list. - if sub_date_bytes: - sub_date_bytes = np.concatenate(sub_date_bytes) + if sub_data_bytes: + sub_data_bytes = np.concatenate(sub_data_bytes) sub_data_address = np.cumsum(sub_data_address) else: - sub_date_bytes = np.array([]) + sub_data_bytes = np.array([]) sub_data_address = np.array([]) else: raise TypeError('indices should be a int or sequence of int, ' f'but got {type(indices)}') - return sub_date_bytes, sub_data_address # type: ignore + return sub_data_bytes, sub_data_address # type: ignore def _get_unserialized_subset(self, indices: Union[Sequence[int], int]) -> list: @@ -795,7 +795,7 @@ class BaseDataset(Dataset): def _copy_without_annotation(self, memo=dict()) -> 'BaseDataset': """Deepcopy for all attributes other than ``data_list``, - ``data_address`` and ``date_bytes``. + ``data_address`` and ``data_bytes``. Args: memo: Memory dict which used to reconstruct complex object @@ -806,7 +806,7 @@ class BaseDataset(Dataset): memo[id(self)] = other for key, value in self.__dict__.items(): - if key in ['data_list', 'data_address', 'date_bytes']: + if key in ['data_list', 'data_address', 'data_bytes']: continue super(BaseDataset, other).__setattr__(key, copy.deepcopy(value, memo)) diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py index 14d3dec40f8eb0e93deeae4a952e3794e541f022..6d3cb23a7cf4bb7d8b987660af920d8f1963c405 100644 --- a/tests/test_dist/test_dist.py +++ b/tests/test_dist/test_dist.py @@ -2,382 +2,460 @@ import os import os.path as osp import tempfile +import unittest +from unittest import TestCase from unittest.mock import patch -import pytest import torch import torch.distributed as torch_dist -import torch.multiprocessing as mp import mmengine.dist as dist from mmengine.dist.dist import sync_random_seed +from mmengine.testing._internal import MultiProcessTestCase from mmengine.utils import TORCH_VERSION, digit_version -def _test_all_reduce_non_dist(): - data = torch.arange(2, dtype=torch.int64) - expected = torch.arange(2, dtype=torch.int64) - dist.all_reduce(data) - assert torch.allclose(data, expected) +class TestDist(TestCase): + """Test dist module in non-distributed environment.""" + def test_all_reduce(self): + data = torch.arange(2, dtype=torch.int64) + expected = torch.arange(2, dtype=torch.int64) + dist.all_reduce(data) + self.assertTrue(torch.allclose(data, expected)) -def _test_all_gather_non_dist(): - data = torch.arange(2, dtype=torch.int64) - expected = torch.arange(2, dtype=torch.int64) - output = dist.all_gather(data) - assert torch.allclose(output[0], expected) + def test_all_gather(self): + data = torch.arange(2, dtype=torch.int64) + expected = torch.arange(2, dtype=torch.int64) + output = dist.all_gather(data) + self.assertTrue(torch.allclose(output[0], expected)) + def test_gather(self): + data = torch.arange(2, dtype=torch.int64) + expected = torch.arange(2, dtype=torch.int64) + output = dist.gather(data) + self.assertTrue(torch.allclose(output[0], expected)) -def _test_gather_non_dist(): - data = torch.arange(2, dtype=torch.int64) - expected = torch.arange(2, dtype=torch.int64) - output = dist.gather(data) - assert torch.allclose(output[0], expected) + def test_broadcast(self): + data = torch.arange(2, dtype=torch.int64) + expected = torch.arange(2, dtype=torch.int64) + dist.broadcast(data) + self.assertTrue(torch.allclose(data, expected)) + @patch('numpy.random.randint', return_value=10) + def test_sync_random_seed(self, mock): + 
self.assertEqual(sync_random_seed(), 10) -def _test_broadcast_non_dist(): - data = torch.arange(2, dtype=torch.int64) - expected = torch.arange(2, dtype=torch.int64) - dist.broadcast(data) - assert torch.allclose(data, expected) - - -@patch('numpy.random.randint', return_value=10) -def _test_sync_random_seed_no_dist(mock): - assert sync_random_seed() == 10 - - -def _test_broadcast_object_list_no_dist(): - with pytest.raises(AssertionError): - # input should be list of object - dist.broadcast_object_list('foo') - - data = ['foo', 12, {1: 2}] - expected = ['foo', 12, {1: 2}] - dist.broadcast_object_list(data) - assert data == expected - - -def _test_all_reduce_dict_no_dist(): - with pytest.raises(AssertionError): - # input should be dict - dist.all_reduce_dict('foo') - - data = { - 'key1': torch.arange(2, dtype=torch.int64), - 'key2': torch.arange(3, dtype=torch.int64) - } - expected = { - 'key1': torch.arange(2, dtype=torch.int64), - 'key2': torch.arange(3, dtype=torch.int64) - } - dist.all_reduce_dict(data) - for key in data: - assert torch.allclose(data[key], expected[key]) - - -def _test_all_gather_object_no_dist(): - data = 'foo' - expected = 'foo' - gather_objects = dist.all_gather_object(data) - assert gather_objects[0] == expected - - -def _test_gather_object_no_dist(): - data = 'foo' - expected = 'foo' - gather_objects = dist.gather_object(data) - assert gather_objects[0] == expected - - -def _test_collect_results_non_dist(): - data = ['foo', {1: 2}] - size = 2 - expected = ['foo', {1: 2}] - - # test `device=cpu` - output = dist.collect_results(data, size, device='cpu') - assert output == expected - - # test `device=gpu` - output = dist.collect_results(data, size, device='cpu') - assert output == expected + def test_broadcast_object_list(self): + with self.assertRaises(AssertionError): + # input should be list of object + dist.broadcast_object_list('foo') + data = ['foo', 12, {1: 2}] + expected = ['foo', 12, {1: 2}] + dist.broadcast_object_list(data) + self.assertEqual(data, expected) + + def test_all_reduce_dict(self): + with self.assertRaises(AssertionError): + # input should be dict + dist.all_reduce_dict('foo') + + data = { + 'key1': torch.arange(2, dtype=torch.int64), + 'key2': torch.arange(3, dtype=torch.int64) + } + expected = { + 'key1': torch.arange(2, dtype=torch.int64), + 'key2': torch.arange(3, dtype=torch.int64) + } + dist.all_reduce_dict(data) + for key in data: + self.assertTrue(torch.allclose(data[key], expected[key])) -def init_process(rank, world_size, functions, backend='gloo'): - """Initialize the distributed environment.""" - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29505' - os.environ['RANK'] = str(rank) + def test_all_gather_object(self): + data = 'foo' + expected = 'foo' + gather_objects = dist.all_gather_object(data) + self.assertEqual(gather_objects[0], expected) - if backend == 'nccl': - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(rank % num_gpus) - device = 'cuda' - else: - device = 'cpu' + def test_gather_object(self): + data = 'foo' + expected = 'foo' + gather_objects = dist.gather_object(data) + self.assertEqual(gather_objects[0], expected) - torch_dist.init_process_group( - backend=backend, rank=rank, world_size=world_size) + def test_collect_results(self): + data = ['foo', {1: 2}] + size = 2 + expected = ['foo', {1: 2}] + + # test `device=cpu` + output = dist.collect_results(data, size, device='cpu') + self.assertEqual(output, expected) + + # test `device=gpu` + output = dist.collect_results(data, size, 
device='gpu') + self.assertEqual(output, expected) + + +class TestDistWithGLOOBackend(MultiProcessTestCase): + + def _init_dist_env(self, rank, world_size): + """Initialize the distributed environment.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29505' + os.environ['RANK'] = str(rank) + torch_dist.init_process_group( + backend='gloo', rank=rank, world_size=world_size) + + def setUp(self): + super().setUp() + self._spawn_processes() + + def test_all_reduce(self): + self._init_dist_env(self.rank, self.world_size) + for tensor_type, reduce_op in zip([torch.int64, torch.float32], + ['sum', 'mean']): + if dist.get_rank() == 0: + data = torch.tensor([1, 2], dtype=tensor_type) + else: + data = torch.tensor([3, 4], dtype=tensor_type) + + if reduce_op == 'sum': + expected = torch.tensor([4, 6], dtype=tensor_type) + else: + expected = torch.tensor([2, 3], dtype=tensor_type) + + dist.all_reduce(data, reduce_op) + self.assertTrue(torch.allclose(data, expected)) + + def test_all_gather(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + data = torch.tensor([0, 1]) + else: + data = torch.tensor([1, 2]) - for func in functions: - func(device) + expected = [torch.tensor([0, 1]), torch.tensor([1, 2])] + output = dist.all_gather(data) + self.assertTrue( + torch.allclose(output[dist.get_rank()], expected[dist.get_rank()])) -def main(functions, world_size=2, backend='gloo'): - try: - mp.spawn( - init_process, - args=(world_size, functions, backend), - nprocs=world_size) - except Exception: - pytest.fail(f'{backend} failed') + def test_gather(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + data = torch.tensor([0, 1]) + else: + data = torch.tensor([1, 2]) + output = dist.gather(data) -def _test_all_reduce_dist(device): - for tensor_type, reduce_op in zip([torch.int64, torch.float32], - ['sum', 'mean']): if dist.get_rank() == 0: - data = torch.tensor([1, 2], dtype=tensor_type).to(device) + expected = [torch.tensor([0, 1]), torch.tensor([1, 2])] + for i in range(2): + assert torch.allclose(output[i], expected[i]) else: - data = torch.tensor([3, 4], dtype=tensor_type).to(device) + assert output == [] - if reduce_op == 'sum': - expected = torch.tensor([4, 6], dtype=tensor_type).to(device) + def test_broadcast_dist(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + data = torch.tensor([0, 1]) else: - expected = torch.tensor([2, 3], dtype=tensor_type).to(device) + data = torch.tensor([1, 2]) - dist.all_reduce(data, reduce_op) + expected = torch.tensor([0, 1]) + dist.broadcast(data, 0) assert torch.allclose(data, expected) + def test_sync_random_seed(self): + self._init_dist_env(self.rank, self.world_size) + with patch.object( + torch, 'tensor', + return_value=torch.tensor(1024)) as mock_tensor: + output = dist.sync_random_seed() + assert output == 1024 + mock_tensor.assert_called() + + def test_broadcast_object_list(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + data = ['foo', 12, {1: 2}] + else: + data = [None, None, None] + + expected = ['foo', 12, {1: 2}] + dist.broadcast_object_list(data) + self.assertEqual(data, expected) + + def test_all_reduce_dict(self): + self._init_dist_env(self.rank, self.world_size) + for tensor_type, reduce_op in zip([torch.int64, torch.float32], + ['sum', 'mean']): + if dist.get_rank() == 0: + data = { + 'key1': torch.tensor([0, 1], dtype=tensor_type), + 'key2': torch.tensor([1, 2], dtype=tensor_type), + } + 
else: + data = { + 'key1': torch.tensor([2, 3], dtype=tensor_type), + 'key2': torch.tensor([3, 4], dtype=tensor_type), + } + + if reduce_op == 'sum': + expected = { + 'key1': torch.tensor([2, 4], dtype=tensor_type), + 'key2': torch.tensor([4, 6], dtype=tensor_type), + } + else: + expected = { + 'key1': torch.tensor([1, 2], dtype=tensor_type), + 'key2': torch.tensor([2, 3], dtype=tensor_type), + } + + dist.all_reduce_dict(data, reduce_op) + + for key in data: + assert torch.allclose(data[key], expected[key]) + + # `torch.cat` in torch1.5 can not concatenate different types so we + # fallback to convert them all to float type. + if digit_version(TORCH_VERSION) == digit_version('1.5.0'): + if dist.get_rank() == 0: + data = { + 'key1': torch.tensor([0, 1], dtype=torch.float32), + 'key2': torch.tensor([1, 2], dtype=torch.int32) + } + else: + data = { + 'key1': torch.tensor([2, 3], dtype=torch.float32), + 'key2': torch.tensor([3, 4], dtype=torch.int32), + } -def _test_all_gather_dist(device): - if dist.get_rank() == 0: - data = torch.tensor([0, 1]).to(device) - else: - data = torch.tensor([1, 2]).to(device) - - expected = [ - torch.tensor([0, 1]).to(device), - torch.tensor([1, 2]).to(device) - ] + expected = { + 'key1': torch.tensor([2, 4], dtype=torch.float32), + 'key2': torch.tensor([4, 6], dtype=torch.float32), + } - output = dist.all_gather(data) - assert torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]) + dist.all_reduce_dict(data, 'sum') + for key in data: + assert torch.allclose(data[key], expected[key]) -def _test_gather_dist(device): - if dist.get_rank() == 0: - data = torch.tensor([0, 1]).to(device) - else: - data = torch.tensor([1, 2]).to(device) + def test_all_gather_object(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + data = 'foo' + else: + data = {1: 2} - output = dist.gather(data) + expected = ['foo', {1: 2}] + output = dist.all_gather_object(data) - if dist.get_rank() == 0: - expected = [ - torch.tensor([0, 1]).to(device), - torch.tensor([1, 2]).to(device) - ] - for i in range(2): - assert torch.allclose(output[i], expected[i]) - else: - assert output == [] + self.assertEqual(output, expected) + def test_gather_object(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + data = 'foo' + else: + data = {1: 2} -def _test_broadcast_dist(device): - if dist.get_rank() == 0: - data = torch.tensor([0, 1]).to(device) - else: - data = torch.tensor([1, 2]).to(device) + output = dist.gather_object(data, dst=0) - expected = torch.tensor([0, 1]).to(device) - dist.broadcast(data, 0) - assert torch.allclose(data, expected) + if dist.get_rank() == 0: + self.assertEqual(output, ['foo', {1: 2}]) + else: + self.assertIsNone(output) -def _test_sync_random_seed_dist(device): - with patch.object( - torch, 'tensor', return_value=torch.tensor(1024)) as mock_tensor: - output = dist.sync_random_seed() - assert output == 1024 - mock_tensor.assert_called() +@unittest.skipIf( + torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl') +class TestDistWithNCCLBackend(MultiProcessTestCase): + def _init_dist_env(self, rank, world_size): + """Initialize the distributed environment.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29505' + os.environ['RANK'] = str(rank) -def _test_broadcast_object_list_dist(device): - if dist.get_rank() == 0: - data = ['foo', 12, {1: 2}] - else: - data = [None, None, None] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + 
torch_dist.init_process_group( + backend='nccl', rank=rank, world_size=world_size) + + def setUp(self): + super().setUp() + self._spawn_processes() + + def test_all_reduce(self): + self._init_dist_env(self.rank, self.world_size) + for tensor_type, reduce_op in zip([torch.int64, torch.float32], + ['sum', 'mean']): + if dist.get_rank() == 0: + data = torch.tensor([1, 2], dtype=tensor_type).cuda() + else: + data = torch.tensor([3, 4], dtype=tensor_type).cuda() + + if reduce_op == 'sum': + expected = torch.tensor([4, 6], dtype=tensor_type).cuda() + else: + expected = torch.tensor([2, 3], dtype=tensor_type).cuda() + + dist.all_reduce(data, reduce_op) + self.assertTrue(torch.allclose(data, expected)) + + def test_all_gather(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + data = torch.tensor([0, 1]).cuda() + else: + data = torch.tensor([1, 2]).cuda() - expected = ['foo', 12, {1: 2}] + expected = [torch.tensor([0, 1]).cuda(), torch.tensor([1, 2]).cuda()] - dist.broadcast_object_list(data) + output = dist.all_gather(data) + self.assertTrue( + torch.allclose(output[dist.get_rank()], expected[dist.get_rank()])) - assert data == expected + def test_broadcast_dist(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + data = torch.tensor([0, 1]).cuda() + else: + data = torch.tensor([1, 2]).cuda() + expected = torch.tensor([0, 1]).cuda() + dist.broadcast(data, 0) + assert torch.allclose(data, expected) -def _test_all_reduce_dict_dist(device): - for tensor_type, reduce_op in zip([torch.int64, torch.float32], - ['sum', 'mean']): + def test_sync_random_seed(self): + self._init_dist_env(self.rank, self.world_size) + with patch.object( + torch, 'tensor', + return_value=torch.tensor(1024)) as mock_tensor: + output = dist.sync_random_seed() + assert output == 1024 + mock_tensor.assert_called() + + def test_broadcast_object_list(self): + self._init_dist_env(self.rank, self.world_size) if dist.get_rank() == 0: - data = { - 'key1': torch.tensor([0, 1], dtype=tensor_type).to(device), - 'key2': torch.tensor([1, 2], dtype=tensor_type).to(device) - } + data = ['foo', 12, {1: 2}] else: - data = { - 'key1': torch.tensor([2, 3], dtype=tensor_type).to(device), - 'key2': torch.tensor([3, 4], dtype=tensor_type).to(device) - } + data = [None, None, None] + + expected = ['foo', 12, {1: 2}] + dist.broadcast_object_list(data) + self.assertEqual(data, expected) + + def test_all_reduce_dict(self): + self._init_dist_env(self.rank, self.world_size) + for tensor_type, reduce_op in zip([torch.int64, torch.float32], + ['sum', 'mean']): + if dist.get_rank() == 0: + data = { + 'key1': torch.tensor([0, 1], dtype=tensor_type).cuda(), + 'key2': torch.tensor([1, 2], dtype=tensor_type).cuda(), + } + else: + data = { + 'key1': torch.tensor([2, 3], dtype=tensor_type).cuda(), + 'key2': torch.tensor([3, 4], dtype=tensor_type).cuda(), + } + + if reduce_op == 'sum': + expected = { + 'key1': torch.tensor([2, 4], dtype=tensor_type).cuda(), + 'key2': torch.tensor([4, 6], dtype=tensor_type).cuda(), + } + else: + expected = { + 'key1': torch.tensor([1, 2], dtype=tensor_type).cuda(), + 'key2': torch.tensor([2, 3], dtype=tensor_type).cuda(), + } + + dist.all_reduce_dict(data, reduce_op) + + for key in data: + assert torch.allclose(data[key], expected[key]) + + # `torch.cat` in torch1.5 can not concatenate different types so we + # fallback to convert them all to float type. 
+ if digit_version(TORCH_VERSION) == digit_version('1.5.0'): + if dist.get_rank() == 0: + data = { + 'key1': torch.tensor([0, 1], dtype=torch.float32).cuda(), + 'key2': torch.tensor([1, 2], dtype=torch.int32).cuda(), + } + else: + data = { + 'key1': torch.tensor([2, 3], dtype=torch.float32).cuda(), + 'key2': torch.tensor([3, 4], dtype=torch.int32).cuda(), + } - if reduce_op == 'sum': - expected = { - 'key1': torch.tensor([2, 4], dtype=tensor_type).to(device), - 'key2': torch.tensor([4, 6], dtype=tensor_type).to(device) - } - else: expected = { - 'key1': torch.tensor([1, 2], dtype=tensor_type).to(device), - 'key2': torch.tensor([2, 3], dtype=tensor_type).to(device) + 'key1': torch.tensor([2, 4], dtype=torch.float32).cuda(), + 'key2': torch.tensor([4, 6], dtype=torch.float32).cuda(), } - dist.all_reduce_dict(data, reduce_op) + dist.all_reduce_dict(data, 'sum') - for key in data: - assert torch.allclose(data[key], expected[key]) + for key in data: + assert torch.allclose(data[key], expected[key]) - # `torch.cat` in torch1.5 can not concatenate different types so we - # fallback to convert them all to float type. - if digit_version(TORCH_VERSION) == digit_version('1.5.0'): + def test_all_gather_object(self): + self._init_dist_env(self.rank, self.world_size) if dist.get_rank() == 0: - data = { - 'key1': torch.tensor([0, 1], dtype=torch.float32).to(device), - 'key2': torch.tensor([1, 2], dtype=torch.int32).to(device) - } + data = 'foo' else: - data = { - 'key1': torch.tensor([2, 3], dtype=torch.float32).to(device), - 'key2': torch.tensor([3, 4], dtype=torch.int32).to(device) - } - - expected = { - 'key1': torch.tensor([2, 4], dtype=torch.float32).to(device), - 'key2': torch.tensor([4, 6], dtype=torch.float32).to(device) - } - - dist.all_reduce_dict(data, 'sum') - - for key in data: - assert torch.allclose(data[key], expected[key]) + data = {1: 2} + expected = ['foo', {1: 2}] + output = dist.all_gather_object(data) -def _test_all_gather_object_dist(device): - if dist.get_rank() == 0: - data = 'foo' - else: - data = {1: 2} - - expected = ['foo', {1: 2}] - output = dist.all_gather_object(data) - - assert output == expected + self.assertEqual(output, expected) + def test_collect_results(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + data = ['foo', {1: 2}] + else: + data = [24, {'a': 'b'}] -def _test_gather_object_dist(device): - if dist.get_rank() == 0: - data = 'foo' - else: - data = {1: 2} + size = 4 - output = dist.gather_object(data, dst=0) + expected = ['foo', 24, {1: 2}, {'a': 'b'}] - if dist.get_rank() == 0: - assert output == ['foo', {1: 2}] - else: - assert output is None + # test `device=cpu` + output = dist.collect_results(data, size, device='cpu') + if dist.get_rank() == 0: + self.assertEqual(output, expected) + else: + self.assertIsNone(output) + + # test `device=cpu` and `tmpdir is not None` + tmpdir = tempfile.mkdtemp() + # broadcast tmpdir to all ranks to make it consistent + object_list = [tmpdir] + dist.broadcast_object_list(object_list) + output = dist.collect_results( + data, size, device='cpu', tmpdir=object_list[0]) + if dist.get_rank() == 0: + self.assertEqual(output, expected) + else: + self.assertIsNone(output) + if dist.get_rank() == 0: + # object_list[0] will be removed by `dist.collect_results` + self.assertFalse(osp.exists(object_list[0])) -def _test_collect_results_dist(device): - if dist.get_rank() == 0: - data = ['foo', {1: 2}] - else: - data = [24, {'a': 'b'}] - - size = 4 - - expected = ['foo', 24, {1: 2}, {'a': 'b'}] - - # 
test `device=cpu` - output = dist.collect_results(data, size, device='cpu') - if dist.get_rank() == 0: - assert output == expected - else: - assert output is None - - # test `device=cpu` and `tmpdir is not None` - tmpdir = tempfile.mkdtemp() - # broadcast tmpdir to all ranks to make it consistent - object_list = [tmpdir] - dist.broadcast_object_list(object_list) - output = dist.collect_results( - data, size, device='cpu', tmpdir=object_list[0]) - if dist.get_rank() == 0: - assert output == expected - else: - assert output is None - - if dist.get_rank() == 0: - # object_list[0] will be removed by `dist.collect_results` - assert not osp.exists(object_list[0]) - - # test `device=gpu` - output = dist.collect_results(data, size, device='gpu') - if dist.get_rank() == 0: - assert output == expected - else: - assert output is None - - -def test_non_distributed_env(): - _test_all_reduce_non_dist() - _test_all_gather_non_dist() - _test_gather_non_dist() - _test_broadcast_non_dist() - _test_sync_random_seed_no_dist() - _test_broadcast_object_list_no_dist() - _test_all_reduce_dict_no_dist() - _test_all_gather_object_no_dist() - _test_gather_object_no_dist() - _test_collect_results_non_dist() - - -def test_gloo_backend(): - functions_to_test = [ - _test_all_reduce_dist, - _test_all_gather_dist, - _test_gather_dist, - _test_broadcast_dist, - _test_sync_random_seed_dist, - _test_broadcast_object_list_dist, - _test_all_reduce_dict_dist, - _test_all_gather_object_dist, - _test_gather_object_dist, - ] - main(functions_to_test, backend='gloo') - - -@pytest.mark.skipif( - torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl') -def test_nccl_backend(): - functions_to_test = [ - _test_all_reduce_dist, - _test_all_gather_dist, - _test_broadcast_dist, - _test_sync_random_seed_dist, - _test_broadcast_object_list_dist, - _test_all_reduce_dict_dist, - _test_all_gather_object_dist, - _test_collect_results_dist, - ] - main(functions_to_test, backend='nccl') + # test `device=gpu` + output = dist.collect_results(data, size, device='gpu') + if dist.get_rank() == 0: + self.assertEqual(output, expected) + else: + self.assertIsNone(output) diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py index b4b74d4526d325c9620793a3ff916575d8b6e963..124551989a1d160e21ae4ec71c79d0c844500b78 100644 --- a/tests/test_dist/test_utils.py +++ b/tests/test_dist/test_utils.py @@ -1,158 +1,177 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os +import unittest +from unittest import TestCase -import pytest import torch import torch.distributed as torch_dist -import torch.multiprocessing as mp import mmengine.dist as dist +from mmengine.testing._internal import MultiProcessTestCase -def _test_get_backend_non_dist(): - assert dist.get_backend() is None +class TestUtils(TestCase): + def test_get_backend(self): + self.assertIsNone(dist.get_backend()) -def _test_get_world_size_non_dist(): - assert dist.get_world_size() == 1 + def test_get_world_size(self): + self.assertEqual(dist.get_world_size(), 1) + def test_get_rank(self): + self.assertEqual(dist.get_rank(), 0) -def _test_get_rank_non_dist(): - assert dist.get_rank() == 0 + def test_local_size(self): + self.assertEqual(dist.get_local_size(), 1) + def test_local_rank(self): + self.assertEqual(dist.get_local_rank(), 0) -def _test_local_size_non_dist(): - assert dist.get_local_size() == 1 + def test_get_dist_info(self): + self.assertEqual(dist.get_dist_info(), (0, 1)) + def test_is_main_process(self): + self.assertTrue(dist.is_main_process()) -def _test_local_rank_non_dist(): - assert dist.get_local_rank() == 0 + def test_master_only(self): + @dist.master_only + def fun(): + assert dist.get_rank() == 0 -def _test_get_dist_info_non_dist(): - assert dist.get_dist_info() == (0, 1) + fun() + def test_barrier(self): + dist.barrier() # nothing is done -def _test_is_main_process_non_dist(): - assert dist.is_main_process() +class TestUtilsWithGLOOBackend(MultiProcessTestCase): -def _test_master_only_non_dist(): + def _init_dist_env(self, rank, world_size): + """Initialize the distributed environment.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29505' + os.environ['RANK'] = str(rank) - @dist.master_only - def fun(): - assert dist.get_rank() == 0 + torch_dist.init_process_group( + backend='gloo', rank=rank, world_size=world_size) + dist.init_local_group(0, world_size) - fun() + def setUp(self): + super().setUp() + self._spawn_processes() + def test_get_backend(self): + self._init_dist_env(self.rank, self.world_size) + self.assertEqual(dist.get_backend(), torch_dist.get_backend()) -def _test_barrier_non_dist(): - dist.barrier() # nothing is done + def test_get_world_size(self): + self._init_dist_env(self.rank, self.world_size) + self.assertEqual(dist.get_world_size(), 2) + def test_get_rank(self): + self._init_dist_env(self.rank, self.world_size) + if torch_dist.get_rank() == 0: + self.assertEqual(dist.get_rank(), 0) + else: + self.assertEqual(dist.get_rank(), 1) -def init_process(rank, world_size, functions, backend='gloo'): - """Initialize the distributed environment.""" - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29501' - os.environ['RANK'] = str(rank) + def test_local_size(self): + self._init_dist_env(self.rank, self.world_size) + self.assertEqual(dist.get_local_size(), 2) - if backend == 'nccl': - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(rank % num_gpus) - - torch_dist.init_process_group( - backend=backend, rank=rank, world_size=world_size) - dist.init_local_group(0, world_size) - - for func in functions: - func() - - -def main(functions, world_size=2, backend='gloo'): - try: - mp.spawn( - init_process, - args=(world_size, functions, backend), - nprocs=world_size) - except Exception: - pytest.fail('error') - - -def _test_get_backend_dist(): - assert dist.get_backend() == torch_dist.get_backend() - - -def _test_get_world_size_dist(): - assert dist.get_world_size() == 2 - - -def _test_get_rank_dist(): 
- if torch_dist.get_rank() == 0: - assert dist.get_rank() == 0 - else: - assert dist.get_rank() == 1 - - -def _test_local_size_dist(): - assert dist.get_local_size() == 2 + def test_local_rank(self): + self._init_dist_env(self.rank, self.world_size) + self.assertEqual( + torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank()) + def test_get_dist_info(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + self.assertEqual(dist.get_dist_info(), (0, 2)) + else: + self.assertEqual(dist.get_dist_info(), (1, 2)) -def _test_local_rank_dist(): - torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank() + def test_is_main_process(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + self.assertTrue(dist.is_main_process()) + else: + self.assertFalse(dist.is_main_process()) + def test_master_only(self): + self._init_dist_env(self.rank, self.world_size) -def _test_get_dist_info_dist(): - if dist.get_rank() == 0: - assert dist.get_dist_info() == (0, 2) - else: - assert dist.get_dist_info() == (1, 2) + @dist.master_only + def fun(): + assert dist.get_rank() == 0 + fun() -def _test_is_main_process_dist(): - if dist.get_rank() == 0: - assert dist.is_main_process() - else: - assert not dist.is_main_process() +@unittest.skipIf( + torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl') +class TestUtilsWithNCCLBackend(MultiProcessTestCase): -def _test_master_only_dist(): - - @dist.master_only - def fun(): - assert dist.get_rank() == 0 - - fun() - - -def test_non_distributed_env(): - _test_get_backend_non_dist() - _test_get_world_size_non_dist() - _test_get_rank_non_dist() - _test_local_size_non_dist() - _test_local_rank_non_dist() - _test_get_dist_info_non_dist() - _test_is_main_process_non_dist() - _test_master_only_non_dist() - _test_barrier_non_dist() - - -functions_to_test = [ - _test_get_backend_dist, - _test_get_world_size_dist, - _test_get_rank_dist, - _test_local_size_dist, - _test_local_rank_dist, - _test_get_dist_info_dist, - _test_is_main_process_dist, - _test_master_only_dist, -] - - -def test_gloo_backend(): - main(functions_to_test) - + def _init_dist_env(self, rank, world_size): + """Initialize the distributed environment.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29505' + os.environ['RANK'] = str(rank) -@pytest.mark.skipif( - torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl') -def test_nccl_backend(): - main(functions_to_test, backend='nccl') + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + torch_dist.init_process_group( + backend='nccl', rank=rank, world_size=world_size) + dist.init_local_group(0, world_size) + + def setUp(self): + super().setUp() + self._spawn_processes() + + def test_get_backend(self): + self._init_dist_env(self.rank, self.world_size) + self.assertEqual(dist.get_backend(), torch_dist.get_backend()) + + def test_get_world_size(self): + self._init_dist_env(self.rank, self.world_size) + self.assertEqual(dist.get_world_size(), 2) + + def test_get_rank(self): + self._init_dist_env(self.rank, self.world_size) + if torch_dist.get_rank() == 0: + self.assertEqual(dist.get_rank(), 0) + else: + self.assertEqual(dist.get_rank(), 1) + + def test_local_size(self): + self._init_dist_env(self.rank, self.world_size) + self.assertEqual(dist.get_local_size(), 2) + + def test_local_rank(self): + self._init_dist_env(self.rank, self.world_size) + self.assertEqual( + torch_dist.get_rank(dist.get_local_group()), 
dist.get_local_rank()) + + def test_get_dist_info(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + self.assertEqual(dist.get_dist_info(), (0, 2)) + else: + self.assertEqual(dist.get_dist_info(), (1, 2)) + + def test_is_main_process(self): + self._init_dist_env(self.rank, self.world_size) + if dist.get_rank() == 0: + self.assertTrue(dist.is_main_process()) + else: + self.assertFalse(dist.is_main_process()) + + def test_master_only(self): + self._init_dist_env(self.rank, self.world_size) + + @dist.master_only + def fun(): + assert dist.get_rank() == 0 + + fun()
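
Note on the test refactor above: the manual `mp.spawn`/`init_process` helpers are replaced by `MultiProcessTestCase` subclasses. A minimal sketch of that pattern follows; it relies only on the hooks that already appear in the diff (`setUp`, `_spawn_processes`, `self.rank`, `self.world_size`), while the class name, test body, and port number are illustrative assumptions rather than part of the patch.

# Minimal sketch of the MultiProcessTestCase-based test pattern (assumptions:
# the class/test names and the port are illustrative; world_size is 2, as the
# migrated tests expect).
import os

import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase


class ExampleGlooTest(MultiProcessTestCase):

    def setUp(self):
        super().setUp()
        # Spawn one worker per rank; every test method below then runs
        # independently inside each spawned process.
        self._spawn_processes()

    def _init_dist_env(self, rank, world_size):
        # Each process builds the process group itself because the test body
        # executes per rank, not once in the parent process.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29510'  # any free port
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)

    def test_all_reduce_sum(self):
        self._init_dist_env(self.rank, self.world_size)
        data = torch.tensor([dist.get_rank(), 1])
        dist.all_reduce(data, 'sum')
        # With world_size == 2, ranks 0 and 1 contribute [0, 1] and [1, 1].
        self.assertTrue(torch.allclose(data, torch.tensor([1, 2])))

Initializing the process group inside each test method (rather than once per class) matches the diff: every spawned process is fresh, so each test sets up its own distributed environment before exercising the `mmengine.dist` API.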