diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py
index 604064ae2b6ff95956d1aafe35e9e12b3e90ec71..1a72f4447e392f956b5cf5ef4babdc918333b902 100644
--- a/mmengine/device/__init__.py
+++ b/mmengine/device/__init__.py
@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .utils import get_max_cuda_memory
+from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
+                    is_mlu_available)

-__all__ = ['get_max_cuda_memory']
+__all__ = [
+    'get_max_cuda_memory', 'get_device', 'is_cuda_available',
+    'is_mlu_available'
+]
diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index b2d5ba2aee2cdc15d22645b310cd38d6f5e265e4..aca4476b0eecfbbabdb588475e52c395c4ec538b 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -25,3 +25,27 @@ def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
                           device=device)
     torch.cuda.reset_peak_memory_stats()
     return int(mem_mb.item())
+
+
+def is_cuda_available() -> bool:
+    """Returns True if cuda devices exist."""
+    return torch.cuda.is_available()
+
+
+def is_mlu_available() -> bool:
+    """Returns True if Cambricon PyTorch and mlu devices exist."""
+    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+
+
+def get_device() -> str:
+    """Returns the currently existing device type.
+
+    Returns:
+        str: cuda | mlu | cpu.
+    """
+    if is_cuda_available():
+        return 'cuda'
+    elif is_mlu_available():
+        return 'mlu'
+    else:
+        return 'cpu'
diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py
index 8cb9bc96e4c586fdba7a21b2ed1aaf83db27b921..76171b908d5b79245c5d0bc48aa69878f9af9e39 100644
--- a/mmengine/dist/utils.py
+++ b/mmengine/dist/utils.py
@@ -10,6 +10,7 @@ import torch.multiprocessing as mp
 from torch import Tensor
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
+from mmengine.device import is_mlu_available


 try:  # for python < 3.10
@@ -76,9 +77,18 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
     """
     # TODO: use local_rank instead of rank % num_gpus
     rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    torch_dist.init_process_group(backend=backend, **kwargs)
+    if is_mlu_available():
+        import torch_mlu  # noqa: F401
+        torch.mlu.set_device(rank)
+        torch_dist.init_process_group(
+            backend='cncl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
+    else:
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        torch_dist.init_process_group(backend=backend, **kwargs)


 def _init_dist_mpi(backend, **kwargs) -> None:
@@ -425,6 +435,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
     backend = get_backend(group)
     if backend == torch_dist.Backend.NCCL:
         return torch.device('cuda', torch.cuda.current_device())
+    elif backend == 'cncl':
+        import torch_mlu  # noqa: F401
+        return torch.device('mlu', torch.mlu.current_device())
     else:
         # GLOO and MPI backends use cpu device by default
         return torch.device('cpu')
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 067dd48b00deca81a68255fba85ce79ab9c36ef5..89d5e64691264407e59cc1e9aa5fa88b86bca7db 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -21,6 +21,7 @@ from torch.utils.data import DataLoader
 import mmengine
 from mmengine.config import Config, ConfigDict
 from mmengine.data import pseudo_collate, worker_init_fn
+from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank,
                            init_dist, master_only, sync_random_seed)
 from mmengine.evaluator import Evaluator
@@ -821,8 +822,7 @@ class Runner:
             return model

         # Set `export CUDA_VISIBLE_DEVICES=-1` to enable CPU training.
-        if torch.cuda.is_available():
-            model = model.cuda()
+        model = model.to(get_device())

         if not self.distributed:
             return model
diff --git a/tests/test_device/test_device.py b/tests/test_device/test_device.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f5c80b303151f83aeb10791bf079f41a14ad575
--- /dev/null
+++ b/tests/test_device/test_device.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.device import get_device, is_cuda_available, is_mlu_available
+
+
+def test_get_device():
+    device = get_device()
+    if is_cuda_available():
+        assert device == 'cuda'
+    elif is_mlu_available():
+        assert device == 'mlu'
+    else:
+        assert device == 'cpu'
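
For reference, a minimal sketch of how downstream code can use the helpers this patch exports, mirroring the `model.to(get_device())` change in runner.py. The `SmallNet` module is a hypothetical stand-in for illustration only and is not part of this patch:

# Illustration only: exercises the new mmengine.device helpers.
import torch
import torch.nn as nn

from mmengine.device import get_device, is_cuda_available, is_mlu_available


class SmallNet(nn.Module):
    """Hypothetical single-layer net used only to demo device placement."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


# get_device() resolves to 'cuda', then 'mlu', then 'cpu', following the
# priority order in mmengine/device/utils.py above, so the same call site
# works on GPU, Cambricon MLU, and CPU-only machines.
device = get_device()
assert (device == 'cuda') == is_cuda_available()
assert device != 'mlu' or is_mlu_available()

model = SmallNet().to(device)
inputs = torch.randn(4, 8, device=device)
print(f'running on {device}, output shape: {tuple(model(inputs).shape)}')

Because get_device() falls back to 'cpu' when neither backend is present, the single .to(get_device()) call replaces the old "if torch.cuda.is_available(): model = model.cuda()" branch without special-casing each device type.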