diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index 3dccb075cc74692ce3aa002b87688a2f2e324cea..14d3dec40f8eb0e93deeae4a952e3794e541f022 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -6,6 +6,7 @@ from unittest.mock import patch
 
 import pytest
 import torch
+import torch.distributed as torch_dist
 import torch.multiprocessing as mp
 
 import mmengine.dist as dist
@@ -108,9 +109,16 @@ def init_process(rank, world_size, functions, backend='gloo'):
     os.environ['MASTER_ADDR'] = '127.0.0.1'
     os.environ['MASTER_PORT'] = '29505'
     os.environ['RANK'] = str(rank)
-    dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
 
-    device = 'cpu' if backend == 'gloo' else 'cuda'
+    if backend == 'nccl':
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        device = 'cuda'
+    else:
+        device = 'cpu'
+
+    torch_dist.init_process_group(
+        backend=backend, rank=rank, world_size=world_size)
 
     for func in functions:
         func(device)
diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py
index e099c8792cd8bb8cb2f1416eaaec919790ed9c13..b4b74d4526d325c9620793a3ff916575d8b6e963 100644
--- a/tests/test_dist/test_utils.py
+++ b/tests/test_dist/test_utils.py
@@ -55,7 +55,13 @@ def init_process(rank, world_size, functions, backend='gloo'):
     os.environ['MASTER_ADDR'] = '127.0.0.1'
     os.environ['MASTER_PORT'] = '29501'
     os.environ['RANK'] = str(rank)
-    dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
+
+    if backend == 'nccl':
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+
+    torch_dist.init_process_group(
+        backend=backend, rank=rank, world_size=world_size)
     dist.init_local_group(0, world_size)
 
     for func in functions: