From 2994195be2f631838f112cc479e88365fe348a0f Mon Sep 17 00:00:00 2001 From: Alex Yang <50511903+imabackstabber@users.noreply.github.com> Date: Thu, 23 Jun 2022 16:53:19 +0800 Subject: [PATCH] [Feat] Support training on MPS (#331) * [Feat] Support mps * fix docstring --- mmengine/device/__init__.py | 4 ++-- mmengine/device/utils.py | 12 +++++++++++- tests/test_device/test_device.py | 5 ++++- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py index 1a72f444..0524e6b9 100644 --- a/mmengine/device/__init__.py +++ b/mmengine/device/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .utils import (get_device, get_max_cuda_memory, is_cuda_available, - is_mlu_available) + is_mlu_available, is_mps_available) __all__ = [ 'get_max_cuda_memory', 'get_device', 'is_cuda_available', - 'is_mlu_available' + 'is_mlu_available', 'is_mps_available' ] diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py index aca4476b..c1819cba 100644 --- a/mmengine/device/utils.py +++ b/mmengine/device/utils.py @@ -37,15 +37,25 @@ def is_mlu_available() -> bool: return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() +def is_mps_available() -> bool: + """Return True if mps devices exist. + + It's specialized for mac m1 chips and require torch version 1.12 or higher. + """ + return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() + + def get_device() -> str: """Returns the currently existing device type. Returns: - str: cuda | mlu | cpu. + str: cuda | mlu | mps | cpu. """ if is_cuda_available(): return 'cuda' elif is_mlu_available(): return 'mlu' + elif is_mps_available(): + return 'mps' else: return 'cpu' diff --git a/tests/test_device/test_device.py b/tests/test_device/test_device.py index 3f5c80b3..1c41721b 100644 --- a/tests/test_device/test_device.py +++ b/tests/test_device/test_device.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmengine.device import get_device, is_cuda_available, is_mlu_available +from mmengine.device import (get_device, is_cuda_available, is_mlu_available, + is_mps_available) def test_get_device(): @@ -8,5 +9,7 @@ def test_get_device(): assert device == 'cuda' elif is_mlu_available(): assert device == 'mlu' + elif is_mps_available(): + assert device == 'mps' else: assert device == 'cpu' -- GitLab