Unverified commit 50650e0b authored by Zaida Zhou, committed by GitHub

[Enhancement] Refactor the unit tests of dist module with MultiProcessTestCase (#138)

* [Enhancement] Provide MultiProcessTestCase to test distributed-related modules

* remove debugging info

* add timeout property

* [Enhancement] Refactor the unit tests of dist module with MultiProcessTestCase

* minor refinement

* minor fix
parent 2d803678
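The diff below replaces function-based tests driven by mp.spawn with unittest classes built on mmengine.testing._internal.MultiProcessTestCase: setUp spawns the worker processes, and every test method is then re-run inside each worker, which reads its own self.rank and the shared self.world_size. A condensed sketch of that pattern follows; the class name, the example test method, and the timeout value are illustrative assumptions rather than code from this commit, while setUp/_spawn_processes and the environment bootstrap are taken from the diff.

    import os

    import torch.distributed as torch_dist

    import mmengine.dist as dist
    from mmengine.testing._internal import MultiProcessTestCase


    class ExampleDistTestCase(MultiProcessTestCase):
        """Hypothetical subclass illustrating the pattern used in this commit."""

        def setUp(self):
            super().setUp()
            # Spawn `self.world_size` worker processes; every test method of
            # this class is then executed once inside each worker.
            self._spawn_processes()

        @property
        def timeout(self):
            # The commit message mentions an overridable timeout property on
            # MultiProcessTestCase; the name, the unit (seconds) and the value
            # 120 are assumptions made for illustration.
            return 120

        def _init_dist_env(self, rank, world_size):
            # Same environment bootstrap as the tests in the diff below.
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '29505'
            os.environ['RANK'] = str(rank)
            torch_dist.init_process_group(
                backend='gloo', rank=rank, world_size=world_size)
            dist.init_local_group(0, world_size)

        def test_rank_and_world_size(self):
            # Hypothetical check: each worker sees its own rank and the shared
            # world size through mmengine.dist.
            self._init_dist_env(self.rank, self.world_size)
            self.assertEqual(dist.get_world_size(), self.world_size)
            self.assertIn(dist.get_rank(), range(self.world_size))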
Before the refactor, the checks were standalone functions driven by torch.multiprocessing.spawn:

    # Copyright (c) OpenMMLab. All rights reserved.
    import os

    import pytest
    import torch
    import torch.distributed as torch_dist
    import torch.multiprocessing as mp

    import mmengine.dist as dist


    def _test_get_backend_non_dist():
        assert dist.get_backend() is None


    def _test_get_world_size_non_dist():
        assert dist.get_world_size() == 1


    def _test_get_rank_non_dist():
        assert dist.get_rank() == 0


    def _test_local_size_non_dist():
        assert dist.get_local_size() == 1


    def _test_local_rank_non_dist():
        assert dist.get_local_rank() == 0


    def _test_get_dist_info_non_dist():
        assert dist.get_dist_info() == (0, 1)


    def _test_is_main_process_non_dist():
        assert dist.is_main_process()


    def _test_master_only_non_dist():

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()


    def _test_barrier_non_dist():
        dist.barrier()  # nothing is done


    def init_process(rank, world_size, functions, backend='gloo'):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29501'
        os.environ['RANK'] = str(rank)

        if backend == 'nccl':
            num_gpus = torch.cuda.device_count()
            torch.cuda.set_device(rank % num_gpus)

        torch_dist.init_process_group(
            backend=backend, rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

        for func in functions:
            func()


    def main(functions, world_size=2, backend='gloo'):
        try:
            mp.spawn(
                init_process,
                args=(world_size, functions, backend),
                nprocs=world_size)
        except Exception:
            pytest.fail('error')


    def _test_get_backend_dist():
        assert dist.get_backend() == torch_dist.get_backend()


    def _test_get_world_size_dist():
        assert dist.get_world_size() == 2


    def _test_get_rank_dist():
        if torch_dist.get_rank() == 0:
            assert dist.get_rank() == 0
        else:
            assert dist.get_rank() == 1


    def _test_local_size_dist():
        assert dist.get_local_size() == 2


    def _test_local_rank_dist():
        torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank()


    def _test_get_dist_info_dist():
        if dist.get_rank() == 0:
            assert dist.get_dist_info() == (0, 2)
        else:
            assert dist.get_dist_info() == (1, 2)


    def _test_is_main_process_dist():
        if dist.get_rank() == 0:
            assert dist.is_main_process()
        else:
            assert not dist.is_main_process()


    def _test_master_only_dist():

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()


    def test_non_distributed_env():
        _test_get_backend_non_dist()
        _test_get_world_size_non_dist()
        _test_get_rank_non_dist()
        _test_local_size_non_dist()
        _test_local_rank_non_dist()
        _test_get_dist_info_non_dist()
        _test_is_main_process_non_dist()
        _test_master_only_non_dist()
        _test_barrier_non_dist()


    functions_to_test = [
        _test_get_backend_dist,
        _test_get_world_size_dist,
        _test_get_rank_dist,
        _test_local_size_dist,
        _test_local_rank_dist,
        _test_get_dist_info_dist,
        _test_is_main_process_dist,
        _test_master_only_dist,
    ]


    def test_gloo_backend():
        main(functions_to_test)


    @pytest.mark.skipif(
        torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
    def test_nccl_backend():
        main(functions_to_test, backend='nccl')

After the refactor, the same checks become methods on unittest classes: a plain TestCase covers the non-distributed defaults, and one MultiProcessTestCase subclass per backend covers the distributed case:

    # Copyright (c) OpenMMLab. All rights reserved.
    import os
    import unittest
    from unittest import TestCase

    import torch
    import torch.distributed as torch_dist

    import mmengine.dist as dist
    from mmengine.testing._internal import MultiProcessTestCase


    class TestUtils(TestCase):

        def test_get_backend(self):
            self.assertIsNone(dist.get_backend())

        def test_get_world_size(self):
            self.assertEqual(dist.get_world_size(), 1)

        def test_get_rank(self):
            self.assertEqual(dist.get_rank(), 0)

        def test_local_size(self):
            self.assertEqual(dist.get_local_size(), 1)

        def test_local_rank(self):
            self.assertEqual(dist.get_local_rank(), 0)

        def test_get_dist_info(self):
            self.assertEqual(dist.get_dist_info(), (0, 1))

        def test_is_main_process(self):
            self.assertTrue(dist.is_main_process())

        def test_master_only(self):

            @dist.master_only
            def fun():
                assert dist.get_rank() == 0

            fun()

        def test_barrier(self):
            dist.barrier()  # nothing is done


    class TestUtilsWithGLOOBackend(MultiProcessTestCase):

        def _init_dist_env(self, rank, world_size):
            """Initialize the distributed environment."""
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '29505'
            os.environ['RANK'] = str(rank)
            torch_dist.init_process_group(
                backend='gloo', rank=rank, world_size=world_size)
            dist.init_local_group(0, world_size)

        def setUp(self):
            super().setUp()
            self._spawn_processes()

        def test_get_backend(self):
            self._init_dist_env(self.rank, self.world_size)
            self.assertEqual(dist.get_backend(), torch_dist.get_backend())

        def test_get_world_size(self):
            self._init_dist_env(self.rank, self.world_size)
            self.assertEqual(dist.get_world_size(), 2)

        def test_get_rank(self):
            self._init_dist_env(self.rank, self.world_size)
            if torch_dist.get_rank() == 0:
                self.assertEqual(dist.get_rank(), 0)
            else:
                self.assertEqual(dist.get_rank(), 1)

        def test_local_size(self):
            self._init_dist_env(self.rank, self.world_size)
            self.assertEqual(dist.get_local_size(), 2)

        def test_local_rank(self):
            self._init_dist_env(self.rank, self.world_size)
            self.assertEqual(
                torch_dist.get_rank(dist.get_local_group()),
                dist.get_local_rank())

        def test_get_dist_info(self):
            self._init_dist_env(self.rank, self.world_size)
            if dist.get_rank() == 0:
                self.assertEqual(dist.get_dist_info(), (0, 2))
            else:
                self.assertEqual(dist.get_dist_info(), (1, 2))

        def test_is_main_process(self):
            self._init_dist_env(self.rank, self.world_size)
            if dist.get_rank() == 0:
                self.assertTrue(dist.is_main_process())
            else:
                self.assertFalse(dist.is_main_process())

        def test_master_only(self):
            self._init_dist_env(self.rank, self.world_size)

            @dist.master_only
            def fun():
                assert dist.get_rank() == 0

            fun()


    @unittest.skipIf(
        torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
    class TestUtilsWithNCCLBackend(MultiProcessTestCase):

        def _init_dist_env(self, rank, world_size):
            """Initialize the distributed environment."""
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '29505'
            os.environ['RANK'] = str(rank)
            num_gpus = torch.cuda.device_count()
            torch.cuda.set_device(rank % num_gpus)
            torch_dist.init_process_group(
                backend='nccl', rank=rank, world_size=world_size)
            dist.init_local_group(0, world_size)

        def setUp(self):
            super().setUp()
            self._spawn_processes()

        def test_get_backend(self):
            self._init_dist_env(self.rank, self.world_size)
            self.assertEqual(dist.get_backend(), torch_dist.get_backend())

        def test_get_world_size(self):
            self._init_dist_env(self.rank, self.world_size)
            self.assertEqual(dist.get_world_size(), 2)

        def test_get_rank(self):
            self._init_dist_env(self.rank, self.world_size)
            if torch_dist.get_rank() == 0:
                self.assertEqual(dist.get_rank(), 0)
            else:
                self.assertEqual(dist.get_rank(), 1)

        def test_local_size(self):
            self._init_dist_env(self.rank, self.world_size)
            self.assertEqual(dist.get_local_size(), 2)

        def test_local_rank(self):
            self._init_dist_env(self.rank, self.world_size)
            self.assertEqual(
                torch_dist.get_rank(dist.get_local_group()),
                dist.get_local_rank())

        def test_get_dist_info(self):
            self._init_dist_env(self.rank, self.world_size)
            if dist.get_rank() == 0:
                self.assertEqual(dist.get_dist_info(), (0, 2))
            else:
                self.assertEqual(dist.get_dist_info(), (1, 2))

        def test_is_main_process(self):
            self._init_dist_env(self.rank, self.world_size)
            if dist.get_rank() == 0:
                self.assertTrue(dist.is_main_process())
            else:
                self.assertFalse(dist.is_main_process())

        def test_master_only(self):
            self._init_dist_env(self.rank, self.world_size)

            @dist.master_only
            def fun():
                assert dist.get_rank() == 0

            fun()
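As a usage note, the refactored classes remain ordinary unittest test cases, so they are collected by pytest or the unittest runner as usual; the invocation below is illustrative only, since this page does not show the test file's path.

    # Illustrative invocation; substitute the real path of the test file:
    #   pytest <path-to-test-file>::TestUtilsWithGLOOBackend -v
    # Alternatively, a __main__ guard (not part of this commit) lets the module
    # be executed directly with the standard unittest runner:
    import unittest

    if __name__ == '__main__':
        unittest.main()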