From 0ca54eb71b312531cd5862356550fe9a6058f2d5 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Fri, 1 Apr 2022 12:50:15 +0800 Subject: [PATCH] [Fix] Fix unit tests when gpu is not available (#163) --- tests/test_dist/test_dist.py | 12 ++++++++++-- tests/test_dist/test_utils.py | 8 +++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py index 3dccb075..14d3dec4 100644 --- a/tests/test_dist/test_dist.py +++ b/tests/test_dist/test_dist.py @@ -6,6 +6,7 @@ from unittest.mock import patch import pytest import torch +import torch.distributed as torch_dist import torch.multiprocessing as mp import mmengine.dist as dist @@ -108,9 +109,16 @@ def init_process(rank, world_size, functions, backend='gloo'): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29505' os.environ['RANK'] = str(rank) - dist.init_dist('pytorch', backend, rank=rank, world_size=world_size) - device = 'cpu' if backend == 'gloo' else 'cuda' + if backend == 'nccl': + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + device = 'cuda' + else: + device = 'cpu' + + torch_dist.init_process_group( + backend=backend, rank=rank, world_size=world_size) for func in functions: func(device) diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py index e099c879..b4b74d45 100644 --- a/tests/test_dist/test_utils.py +++ b/tests/test_dist/test_utils.py @@ -55,7 +55,13 @@ def init_process(rank, world_size, functions, backend='gloo'): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29501' os.environ['RANK'] = str(rank) - dist.init_dist('pytorch', backend, rank=rank, world_size=world_size) + + if backend == 'nccl': + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + + torch_dist.init_process_group( + backend=backend, rank=rank, world_size=world_size) dist.init_local_group(0, world_size) for func in functions: -- GitLab