From 0ca54eb71b312531cd5862356550fe9a6058f2d5 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Fri, 1 Apr 2022 12:50:15 +0800
Subject: [PATCH] [Fix] Fix unit tests when gpu is not available (#163)

---
 tests/test_dist/test_dist.py  | 12 ++++++++++--
 tests/test_dist/test_utils.py |  8 +++++++-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index 3dccb075..14d3dec4 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -6,6 +6,7 @@ from unittest.mock import patch
 
 import pytest
 import torch
+import torch.distributed as torch_dist
 import torch.multiprocessing as mp
 
 import mmengine.dist as dist
@@ -108,9 +109,16 @@ def init_process(rank, world_size, functions, backend='gloo'):
     os.environ['MASTER_ADDR'] = '127.0.0.1'
     os.environ['MASTER_PORT'] = '29505'
     os.environ['RANK'] = str(rank)
-    dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
 
-    device = 'cpu' if backend == 'gloo' else 'cuda'
+    if backend == 'nccl':
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        device = 'cuda'
+    else:
+        device = 'cpu'
+
+    torch_dist.init_process_group(
+        backend=backend, rank=rank, world_size=world_size)
 
     for func in functions:
         func(device)
diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py
index e099c879..b4b74d45 100644
--- a/tests/test_dist/test_utils.py
+++ b/tests/test_dist/test_utils.py
@@ -55,7 +55,13 @@ def init_process(rank, world_size, functions, backend='gloo'):
     os.environ['MASTER_ADDR'] = '127.0.0.1'
     os.environ['MASTER_PORT'] = '29501'
     os.environ['RANK'] = str(rank)
-    dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
+
+    if backend == 'nccl':
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+
+    torch_dist.init_process_group(
+        backend=backend, rank=rank, world_size=world_size)
     dist.init_local_group(0, world_size)
 
     for func in functions:
-- 
GitLab