Unverified commit d0d71742, authored by Jiazhen Wang, committed by GitHub

[Feature] Support MLU Devices (#288)

* support mlu

* add ut and refine docstring
parent e1ed5669
 # Copyright (c) OpenMMLab. All rights reserved.
-from .utils import get_max_cuda_memory
+from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
+                    is_mlu_available)

-__all__ = ['get_max_cuda_memory']
+__all__ = [
+    'get_max_cuda_memory', 'get_device', 'is_cuda_available',
+    'is_mlu_available'
+]
@@ -25,3 +25,27 @@ def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
                           device=device)
     torch.cuda.reset_peak_memory_stats()
     return int(mem_mb.item())
+
+
+def is_cuda_available() -> bool:
+    """Returns True if cuda devices exist."""
+    return torch.cuda.is_available()
+
+
+def is_mlu_available() -> bool:
+    """Returns True if Cambricon PyTorch and mlu devices exist."""
+    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+
+
+def get_device() -> str:
+    """Returns the currently existing device type.
+
+    Returns:
+        str: cuda | mlu | cpu.
+    """
+    if is_cuda_available():
+        return 'cuda'
+    elif is_mlu_available():
+        return 'mlu'
+    else:
+        return 'cpu'
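An illustrative usage sketch of the new helpers (not part of this commit; the
variable names are made up, and placing tensors on 'mlu' assumes Cambricon
PyTorch is installed so that torch recognizes the device):

    import torch
    from mmengine.device import get_device, is_mlu_available

    device = get_device()              # 'cuda', 'mlu', or 'cpu' on this host
    x = torch.randn(4, 4).to(device)   # the same call works on GPU, MLU, and CPU
    if is_mlu_available():
        print('tensors will be placed on a Cambricon MLU card')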
@@ -10,6 +10,7 @@ import torch.multiprocessing as mp
 from torch import Tensor
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup

+from mmengine.device import is_mlu_available
 try:
     # for python < 3.10
@@ -76,9 +77,18 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
     """
     # TODO: use local_rank instead of rank % num_gpus
     rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    torch_dist.init_process_group(backend=backend, **kwargs)
+    if is_mlu_available():
+        import torch_mlu  # noqa: F401
+        torch.mlu.set_device(rank)
+        torch_dist.init_process_group(
+            backend='cncl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
+    else:
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        torch_dist.init_process_group(backend=backend, **kwargs)


 def _init_dist_mpi(backend, **kwargs) -> None:
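A hedged sketch of how this branch is reached (not from the commit, and it assumes
init_dist keeps the mmcv-style signature init_dist(launcher, backend='nccl', **kwargs)):
a launcher such as torchrun sets RANK and WORLD_SIZE in the environment, init_dist('pytorch')
dispatches to _init_dist_pytorch above, and on an MLU host the requested backend is
replaced by 'cncl':

    import os
    from mmengine.dist import init_dist

    def setup_dist() -> None:
        # RANK / WORLD_SIZE come from the launcher, e.g.
        #   torchrun --nproc_per_node=4 train.py
        assert 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
        # On MLU machines _init_dist_pytorch ignores backend='nccl' and
        # initializes the process group with the Cambricon 'cncl' backend.
        init_dist('pytorch', backend='nccl')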
@@ -425,6 +435,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
     backend = get_backend(group)
     if backend == torch_dist.Backend.NCCL:
         return torch.device('cuda', torch.cuda.current_device())
+    elif backend == 'cncl':
+        import torch_mlu  # noqa: F401
+        return torch.device('mlu', torch.mlu.current_device())
     else:
         # GLOO and MPI backends use cpu device by default
         return torch.device('cpu')
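To illustrate what get_comm_device is for (a sketch, not code from this commit):
collective wrappers can move data onto the backend's preferred device before calling
into torch.distributed, so one helper covers NCCL (cuda), CNCL (mlu), and GLOO/MPI (cpu).
The function all_reduce_sum below is hypothetical, and the import assumes
get_comm_device is re-exported from mmengine.dist:

    import torch
    from torch import distributed as torch_dist
    from mmengine.dist import get_comm_device  # assumed re-export

    def all_reduce_sum(tensor: torch.Tensor, group=None) -> torch.Tensor:
        """Sum-reduce a tensor across ranks on whatever device the backend expects."""
        comm_device = get_comm_device(group)      # cuda, mlu, or cpu
        data = tensor.to(comm_device)
        torch_dist.all_reduce(data, group=group)  # assumes the process group is initialized
        return data.to(tensor.device)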
@@ -21,6 +21,7 @@ from torch.utils.data import DataLoader
 import mmengine
 from mmengine.config import Config, ConfigDict
 from mmengine.data import pseudo_collate, worker_init_fn
+from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                            master_only, sync_random_seed)
 from mmengine.evaluator import Evaluator
@@ -821,8 +822,7 @@ class Runner:
             return model

         # Set `export CUDA_VISIBLE_DEVICES=-1` to enable CPU training.
-        if torch.cuda.is_available():
-            model = model.cuda()
+        model = model.to(get_device())

         if not self.distributed:
             return model
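The effect of the one-line change above, sketched with a toy module (illustrative only;
moving to 'mlu' again relies on Cambricon PyTorch being importable):

    import torch.nn as nn
    from mmengine.device import get_device

    model = nn.Linear(8, 2).to(get_device())  # cuda, mlu, or cpu, with no branching

Note that `export CUDA_VISIBLE_DEVICES=-1` still forces CPU training on a GPU-only
machine: torch.cuda.is_available() then returns False, so get_device() falls back to 'cpu'.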
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import get_device, is_cuda_available, is_mlu_available


def test_get_device():
    device = get_device()
    if is_cuda_available():
        assert device == 'cuda'
    elif is_mlu_available():
        assert device == 'mlu'
    else:
        assert device == 'cpu'