Unverified commit 49613414, authored by CokeDong, committed by GitHub

[Feature] Support dipu device (#1127)

parent 74512162
mmengine/device/__init__.py
 # Copyright (c) OpenMMLab. All rights reserved.
 from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_mlu_available, is_mps_available, is_npu_available,
-                    is_npu_support_full_precision)
+                    is_dipu_available, is_mlu_available, is_mps_available,
+                    is_npu_available, is_npu_support_full_precision)
 
 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
     'is_mlu_available', 'is_mps_available', 'is_npu_available',
-    'is_npu_support_full_precision'
+    'is_dipu_available', 'is_npu_support_full_precision'
 ]
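For context, a minimal usage sketch of the helper this file now re-exports. It assumes the package resolves to `mmengine.device` (consistent with the OpenMMLab copyright header and the relative `from .utils import ...` above; the full module path is not shown in this extract):

# Sketch only: the `mmengine.device` import path is inferred, not confirmed by this diff.
from mmengine.device import get_device, is_dipu_available

if is_dipu_available():
    # True only if `import torch_dipu` succeeded when mmengine.device was first imported.
    print('DIPU available, get_device() ->', get_device())
else:
    print('DIPU not available, get_device() ->', get_device())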
mmengine/device/utils.py
@@ -16,6 +16,12 @@ try:
 except Exception:
     IS_NPU_AVAILABLE = False
 
+try:
+    import torch_dipu  # noqa: F401
+    IS_DIPU_AVAILABLE = True
+except Exception:
+    IS_DIPU_AVAILABLE = False
+
 
 def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
     """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@@ -63,6 +69,10 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
 
 
+def is_dipu_available() -> bool:
+    return IS_DIPU_AVAILABLE
+
+
 def is_npu_support_full_precision() -> bool:
     """Returns True if npu devices support full precision training."""
     version_of_support_full_precision = 220
@@ -79,6 +89,8 @@ elif is_mlu_available():
     DEVICE = 'mlu'
 elif is_mps_available():
     DEVICE = 'mps'
+elif is_dipu_available():
+    DEVICE = 'dipu'
 
 
 def get_device() -> str:
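The guarded `import torch_dipu` plus the extra `elif` arm means DIPU is probed once at import time and selected only when none of the backends checked earlier in the chain matched. A hedged consumer-side sketch follows; it assumes `torch_dipu` registers `'dipu'` as a device string PyTorch will accept in `.to()` (this diff only shows that `get_device()` can now return the string `'dipu'`):

import torch
from mmengine.device import get_device  # import path inferred, as above

device = get_device()  # e.g. 'cuda', 'mlu', 'mps', 'dipu', or the 'cpu' fallback
x = torch.randn(2, 3)
x = x.to(device)  # assumption: the returned string is a valid torch device type
print(x.device)

Note the broad `except Exception` around the probe: any import-time failure (missing wheel, driver mismatch) quietly downgrades to the next backend rather than crashing module import.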