Skip to content
Snippets Groups Projects
Unverified Commit 2994195b authored by Alex Yang's avatar Alex Yang Committed by GitHub
Browse files

[Feat] Support training on MPS (#331)

* [Feat] Support mps

* fix docstring
parent e877862d
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .utils import (get_device, get_max_cuda_memory, is_cuda_available, from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
is_mlu_available) is_mlu_available, is_mps_available)
__all__ = [ __all__ = [
'get_max_cuda_memory', 'get_device', 'is_cuda_available', 'get_max_cuda_memory', 'get_device', 'is_cuda_available',
'is_mlu_available' 'is_mlu_available', 'is_mps_available'
] ]
...@@ -37,15 +37,25 @@ def is_mlu_available() -> bool: ...@@ -37,15 +37,25 @@ def is_mlu_available() -> bool:
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
def is_mps_available() -> bool:
"""Return True if mps devices exist.
It's specialized for mac m1 chips and require torch version 1.12 or higher.
"""
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
def get_device() -> str: def get_device() -> str:
"""Returns the currently existing device type. """Returns the currently existing device type.
Returns: Returns:
str: cuda | mlu | cpu. str: cuda | mlu | mps | cpu.
""" """
if is_cuda_available(): if is_cuda_available():
return 'cuda' return 'cuda'
elif is_mlu_available(): elif is_mlu_available():
return 'mlu' return 'mlu'
elif is_mps_available():
return 'mps'
else: else:
return 'cpu' return 'cpu'
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import get_device, is_cuda_available, is_mlu_available from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
is_mps_available)
def test_get_device(): def test_get_device():
...@@ -8,5 +9,7 @@ def test_get_device(): ...@@ -8,5 +9,7 @@ def test_get_device():
assert device == 'cuda' assert device == 'cuda'
elif is_mlu_available(): elif is_mlu_available():
assert device == 'mlu' assert device == 'mlu'
elif is_mps_available():
assert device == 'mps'
else: else:
assert device == 'cpu' assert device == 'cpu'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment