Skip to content
Snippets Groups Projects
Unverified Commit 79067e46 authored by wangjiangben-hw's avatar wangjiangben-hw Committed by GitHub
Browse files

[Fix] Add support for Ascend device (#847)

* add npu device support

* add comment for torch.npu.set_compile_mode
parent 925ac870
No related branches found
No related tags found
No related merge requests found
...@@ -36,6 +36,10 @@ def is_npu_available() -> bool: ...@@ -36,6 +36,10 @@ def is_npu_available() -> bool:
"""Returns True if Ascend PyTorch and npu devices exist.""" """Returns True if Ascend PyTorch and npu devices exist."""
try: try:
import torch_npu # noqa: F401 import torch_npu # noqa: F401
# Enable operator support for dynamic shape and
# binary operator support on the NPU.
torch.npu.set_compile_mode(jit_compile=False)
except Exception: except Exception:
return False return False
return hasattr(torch, 'npu') and torch.npu.is_available() return hasattr(torch, 'npu') and torch.npu.is_available()
......
...@@ -184,6 +184,18 @@ class BaseModel(BaseModule): ...@@ -184,6 +184,18 @@ class BaseModel(BaseModule):
Returns: Returns:
nn.Module: The model itself. nn.Module: The model itself.
""" """
# Since Torch has not officially merged
# the npu-related fields, using the _parse_to function
# directly will cause the NPU to not be found.
# Here, the input parameters are processed to avoid errors.
if args and isinstance(args[0], str) and 'npu' in args[0]:
args = tuple(
[list(args)[0].replace('npu', torch.npu.native_device)])
if kwargs and 'npu' in str(kwargs.get('device', '')):
kwargs['device'] = kwargs['device'].replace(
'npu', torch.npu.native_device)
device = torch._C._nn._parse_to(*args, **kwargs)[0] device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None: if device is not None:
self._set_device(torch.device(device)) self._set_device(torch.device(device))
......
...@@ -87,6 +87,18 @@ class BaseDataPreprocessor(nn.Module): ...@@ -87,6 +87,18 @@ class BaseDataPreprocessor(nn.Module):
Returns: Returns:
nn.Module: The model itself. nn.Module: The model itself.
""" """
# Since Torch has not officially merged
# the npu-related fields, using the _parse_to function
# directly will cause the NPU to not be found.
# Here, the input parameters are processed to avoid errors.
if args and isinstance(args[0], str) and 'npu' in args[0]:
args = tuple(
[list(args)[0].replace('npu', torch.npu.native_device)])
if kwargs and 'npu' in str(kwargs.get('device', '')):
kwargs['device'] = kwargs['device'].replace(
'npu', torch.npu.native_device)
device = torch._C._nn._parse_to(*args, **kwargs)[0] device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None: if device is not None:
self._device = torch.device(device) self._device = torch.device(device)
...@@ -101,6 +113,15 @@ class BaseDataPreprocessor(nn.Module): ...@@ -101,6 +113,15 @@ class BaseDataPreprocessor(nn.Module):
self._device = torch.device(torch.cuda.current_device()) self._device = torch.device(torch.cuda.current_device())
return super().cuda() return super().cuda()
def npu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
Returns:
nn.Module: The model itself.
"""
self._device = torch.device(torch.npu.current_device())
return super().npu()
def cpu(self, *args, **kwargs) -> nn.Module: def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device` """Overrides this method to set the :attr:`device`
......
...@@ -507,6 +507,17 @@ class BaseDataElement: ...@@ -507,6 +507,17 @@ class BaseDataElement:
new_data.set_data(data) new_data.set_data(data)
return new_data return new_data
# Tensor-like methods
def npu(self) -> 'BaseDataElement':
"""Convert all tensors to NPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.npu()
data = {k: v}
new_data.set_data(data)
return new_data
# Tensor-like methods # Tensor-like methods
def detach(self) -> 'BaseDataElement': def detach(self) -> 'BaseDataElement':
"""Detach all tensors in data.""" """Detach all tensors in data."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment