From 2994195be2f631838f112cc479e88365fe348a0f Mon Sep 17 00:00:00 2001
From: Alex Yang <50511903+imabackstabber@users.noreply.github.com>
Date: Thu, 23 Jun 2022 16:53:19 +0800
Subject: [PATCH] [Feat] Support training on MPS (#331)

* [Feat] Support mps

* fix docstring
---
 mmengine/device/__init__.py      |  4 ++--
 mmengine/device/utils.py         | 12 +++++++++++-
 tests/test_device/test_device.py |  5 ++++-
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py
index 1a72f444..0524e6b9 100644
--- a/mmengine/device/__init__.py
+++ b/mmengine/device/__init__.py
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available)
+                    is_mlu_available, is_mps_available)
 
 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
-    'is_mlu_available'
+    'is_mlu_available', 'is_mps_available'
 ]
diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index aca4476b..c1819cba 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -37,15 +37,25 @@ def is_mlu_available() -> bool:
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
 
 
+def is_mps_available() -> bool:
+    """Return True if mps devices exist.
+
+    It's specialized for mac m1 chips and require torch version 1.12 or higher.
+    """
+    return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
+
+
 def get_device() -> str:
     """Returns the currently existing device type.
 
     Returns:
-        str: cuda | mlu | cpu.
+        str: cuda | mlu | mps | cpu.
     """
     if is_cuda_available():
         return 'cuda'
     elif is_mlu_available():
         return 'mlu'
+    elif is_mps_available():
+        return 'mps'
     else:
         return 'cpu'
diff --git a/tests/test_device/test_device.py b/tests/test_device/test_device.py
index 3f5c80b3..1c41721b 100644
--- a/tests/test_device/test_device.py
+++ b/tests/test_device/test_device.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.device import get_device, is_cuda_available, is_mlu_available
+from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
+                             is_mps_available)
 
 
 def test_get_device():
@@ -8,5 +9,7 @@ def test_get_device():
         assert device == 'cuda'
     elif is_mlu_available():
         assert device == 'mlu'
+    elif is_mps_available():
+        assert device == 'mps'
     else:
         assert device == 'cpu'
-- 
GitLab