Skip to content
Snippets Groups Projects
Unverified Commit 0ca54eb7 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Fix] Fix unit tests when gpu is not available (#163)

parent 25014af3
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ from unittest.mock import patch ...@@ -6,6 +6,7 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
import torch.distributed as torch_dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import mmengine.dist as dist import mmengine.dist as dist
...@@ -108,9 +109,16 @@ def init_process(rank, world_size, functions, backend='gloo'): ...@@ -108,9 +109,16 @@ def init_process(rank, world_size, functions, backend='gloo'):
os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29505' os.environ['MASTER_PORT'] = '29505'
os.environ['RANK'] = str(rank) os.environ['RANK'] = str(rank)
dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
device = 'cpu' if backend == 'gloo' else 'cuda' if backend == 'nccl':
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
device = 'cuda'
else:
device = 'cpu'
torch_dist.init_process_group(
backend=backend, rank=rank, world_size=world_size)
for func in functions: for func in functions:
func(device) func(device)
......
...@@ -55,7 +55,13 @@ def init_process(rank, world_size, functions, backend='gloo'): ...@@ -55,7 +55,13 @@ def init_process(rank, world_size, functions, backend='gloo'):
os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29501' os.environ['MASTER_PORT'] = '29501'
os.environ['RANK'] = str(rank) os.environ['RANK'] = str(rank)
dist.init_dist('pytorch', backend, rank=rank, world_size=world_size)
if backend == 'nccl':
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
torch_dist.init_process_group(
backend=backend, rank=rank, world_size=world_size)
dist.init_local_group(0, world_size) dist.init_local_group(0, world_size)
for func in functions: for func in functions:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment