diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py index 3dccb075cc74692ce3aa002b87688a2f2e324cea..14d3dec40f8eb0e93deeae4a952e3794e541f022 100644 --- a/tests/test_dist/test_dist.py +++ b/tests/test_dist/test_dist.py @@ -6,6 +6,7 @@ from unittest.mock import patch import pytest import torch +import torch.distributed as torch_dist import torch.multiprocessing as mp import mmengine.dist as dist @@ -108,9 +109,16 @@ def init_process(rank, world_size, functions, backend='gloo'): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29505' os.environ['RANK'] = str(rank) - dist.init_dist('pytorch', backend, rank=rank, world_size=world_size) - device = 'cpu' if backend == 'gloo' else 'cuda' + if backend == 'nccl': + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + device = 'cuda' + else: + device = 'cpu' + + torch_dist.init_process_group( + backend=backend, rank=rank, world_size=world_size) for func in functions: func(device) diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py index e099c8792cd8bb8cb2f1416eaaec919790ed9c13..b4b74d4526d325c9620793a3ff916575d8b6e963 100644 --- a/tests/test_dist/test_utils.py +++ b/tests/test_dist/test_utils.py @@ -55,7 +55,13 @@ def init_process(rank, world_size, functions, backend='gloo'): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29501' os.environ['RANK'] = str(rank) - dist.init_dist('pytorch', backend, rank=rank, world_size=world_size) + + if backend == 'nccl': + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + + torch_dist.init_process_group( + backend=backend, rank=rank, world_size=world_size) dist.init_local_group(0, world_size) for func in functions: