diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py index 1a72f4447e392f956b5cf5ef4babdc918333b902..0524e6b9370f150cca311c6390a2532e907a0790 100644 --- a/mmengine/device/__init__.py +++ b/mmengine/device/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .utils import (get_device, get_max_cuda_memory, is_cuda_available, - is_mlu_available) + is_mlu_available, is_mps_available) __all__ = [ 'get_max_cuda_memory', 'get_device', 'is_cuda_available', - 'is_mlu_available' + 'is_mlu_available', 'is_mps_available' ] diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py index aca4476b0eecfbbabdb588475e52c395c4ec538b..c1819cbaaeba6f18b05ce89bfdb4bd9ecae9807f 100644 --- a/mmengine/device/utils.py +++ b/mmengine/device/utils.py @@ -37,15 +37,25 @@ def is_mlu_available() -> bool: return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() +def is_mps_available() -> bool: + """Return True if mps devices exist. + + It's specialized for mac m1 chips and require torch version 1.12 or higher. + """ + return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() + + def get_device() -> str: """Returns the currently existing device type. Returns: - str: cuda | mlu | cpu. + str: cuda | mlu | mps | cpu. """ if is_cuda_available(): return 'cuda' elif is_mlu_available(): return 'mlu' + elif is_mps_available(): + return 'mps' else: return 'cpu' diff --git a/tests/test_device/test_device.py b/tests/test_device/test_device.py index 3f5c80b303151f83aeb10791bf079f41a14ad575..1c41721bf473aa75ecc1497c102113c251612156 100644 --- a/tests/test_device/test_device.py +++ b/tests/test_device/test_device.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmengine.device import get_device, is_cuda_available, is_mlu_available +from mmengine.device import (get_device, is_cuda_available, is_mlu_available, + is_mps_available) def test_get_device(): @@ -8,5 +9,7 @@ def test_get_device(): assert device == 'cuda' elif is_mlu_available(): assert device == 'mlu' + elif is_mps_available(): + assert device == 'mps' else: assert device == 'cpu'