From d0d71742748c7b15ce63d8a61e96383ea0dab1d9 Mon Sep 17 00:00:00 2001
From: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com>
Date: Thu, 16 Jun 2022 20:28:09 +0800
Subject: [PATCH] [Feature] Support MLU Devices (#288)

* Support MLU devices

* Add unit tests and refine docstrings
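
A minimal usage sketch of the new device helpers (illustrative only; the
``nn.Linear`` module below is just a stand-in and is not part of this patch):

    import torch.nn as nn

    from mmengine.device import (get_device, is_cuda_available,
                                 is_mlu_available)

    # get_device() checks CUDA first, then MLU, and falls back to CPU.
    device = get_device()
    if is_cuda_available():
        assert device == 'cuda'
    elif is_mlu_available():
        assert device == 'mlu'
    else:
        assert device == 'cpu'

    # Move a model to whichever accelerator is present.
    model = nn.Linear(8, 2).to(device)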
---
 mmengine/device/__init__.py      |  8 ++++++--
 mmengine/device/utils.py         | 24 ++++++++++++++++++++++++
 mmengine/dist/utils.py           | 19 ++++++++++++++++---
 mmengine/runner/runner.py        |  4 ++--
 tests/test_device/test_device.py | 12 ++++++++++++
 5 files changed, 60 insertions(+), 7 deletions(-)
 create mode 100644 tests/test_device/test_device.py

diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py
index 604064ae..1a72f444 100644
--- a/mmengine/device/__init__.py
+++ b/mmengine/device/__init__.py
@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .utils import get_max_cuda_memory
+from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
+                    is_mlu_available)
 
-__all__ = ['get_max_cuda_memory']
+__all__ = [
+    'get_max_cuda_memory', 'get_device', 'is_cuda_available',
+    'is_mlu_available'
+]
diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index b2d5ba2a..aca4476b 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -25,3 +25,27 @@ def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
                           device=device)
     torch.cuda.reset_peak_memory_stats()
     return int(mem_mb.item())
+
+
+def is_cuda_available() -> bool:
+    """Returns True if cuda devices exist."""
+    return torch.cuda.is_available()
+
+
+def is_mlu_available() -> bool:
+    """Returns True if Cambricon PyTorch and mlu devices exist."""
+    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+
+
+def get_device() -> str:
+    """Returns the currently existing device type.
+
+    Returns:
+        str: cuda | mlu | cpu.
+    """
+    if is_cuda_available():
+        return 'cuda'
+    elif is_mlu_available():
+        return 'mlu'
+    else:
+        return 'cpu'
diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py
index 8cb9bc96..76171b90 100644
--- a/mmengine/dist/utils.py
+++ b/mmengine/dist/utils.py
@@ -10,6 +10,7 @@ import torch.multiprocessing as mp
 from torch import Tensor
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
+from mmengine.device import is_mlu_available
 
 try:
     # for python < 3.10
@@ -76,9 +77,18 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
     """
     # TODO: use local_rank instead of rank % num_gpus
     rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    torch_dist.init_process_group(backend=backend, **kwargs)
+    if is_mlu_available():
+        import torch_mlu  # noqa: F401
+        torch.mlu.set_device(rank)
+        torch_dist.init_process_group(
+            backend='cncl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
+    else:
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        torch_dist.init_process_group(backend=backend, **kwargs)
 
 
 def _init_dist_mpi(backend, **kwargs) -> None:
@@ -425,6 +435,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
     backend = get_backend(group)
     if backend == torch_dist.Backend.NCCL:
         return torch.device('cuda', torch.cuda.current_device())
+    elif backend == 'cncl':
+        import torch_mlu  # noqa: F401
+        return torch.device('mlu', torch.mlu.current_device())
     else:
         # GLOO and MPI backends use cpu device by default
         return torch.device('cpu')
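
A hedged sketch (not part of the diff) of how the new MLU branch of
``_init_dist_pytorch`` is expected to be driven. It assumes one process per
MLU card started by a PyTorch launcher that exports ``RANK`` and
``WORLD_SIZE``; ``init_dist`` is the existing entry point in
``mmengine.dist``:

    import os

    # Normally set by the launcher (e.g. torch.distributed.run); shown
    # explicitly here for a single-process run.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')

    from mmengine.dist import init_dist

    # With an MLU present, _init_dist_pytorch sets the device via
    # torch.mlu.set_device(rank) and initializes the process group with the
    # 'cncl' backend; otherwise it falls back to the CUDA path with the
    # requested backend (default 'nccl').
    init_dist('pytorch')
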
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 067dd48b..89d5e646 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -21,6 +21,7 @@ from torch.utils.data import DataLoader
 import mmengine
 from mmengine.config import Config, ConfigDict
 from mmengine.data import pseudo_collate, worker_init_fn
+from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                            master_only, sync_random_seed)
 from mmengine.evaluator import Evaluator
@@ -821,8 +822,7 @@ class Runner:
             return model
 
         # Set `export CUDA_VISIBLE_DEVICES=-1` to enable CPU training.
-        if torch.cuda.is_available():
-            model = model.cuda()
+        model = model.to(get_device())
 
         if not self.distributed:
             return model
diff --git a/tests/test_device/test_device.py b/tests/test_device/test_device.py
new file mode 100644
index 00000000..3f5c80b3
--- /dev/null
+++ b/tests/test_device/test_device.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.device import get_device, is_cuda_available, is_mlu_available
+
+
+def test_get_device():
+    device = get_device()
+    if is_cuda_available():
+        assert device == 'cuda'
+    elif is_mlu_available():
+        assert device == 'mlu'
+    else:
+        assert device == 'cpu'
-- 
GitLab