Unverified commit 8177ef2a, authored by luomaoling, committed by GitHub

[Fix] Fix AMP in Ascend and support using NPUJITCompile environment (#994)

* add npu device support

* add npu device support
Parent: 60872c38
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
 from typing import Optional
 
 import torch
...
@@ -39,7 +40,8 @@ def is_npu_available() -> bool:
         # Enable operator support for dynamic shape and
         # binary operator support on the NPU.
-        torch.npu.set_compile_mode(jit_compile=False)
+        npu_jit_compile = bool(os.getenv('NPUJITCompile', False))
+        torch.npu.set_compile_mode(jit_compile=npu_jit_compile)
     except Exception:
         return False
     return hasattr(torch, 'npu') and torch.npu.is_available()
...
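
A caveat worth noting about the new check: bool(os.getenv('NPUJITCompile', False)) is true for any non-empty string, so setting NPUJITCompile=0 or NPUJITCompile=false would still enable JIT compilation. A minimal sketch of that behavior, plus a hypothetical stricter parser (env_flag is illustrative only, not part of this commit):

import os

# The expression used in the commit: True for ANY non-empty string.
os.environ['NPUJITCompile'] = '0'
print(bool(os.getenv('NPUJITCompile', False)))  # True: '0' is a non-empty string

del os.environ['NPUJITCompile']
print(bool(os.getenv('NPUJITCompile', False)))  # False: falls back to the default

# Hypothetical stricter parser (not in the commit): only explicit
# opt-in values enable JIT compilation.
def env_flag(name: str, default: bool = False) -> bool:
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ('1', 'true', 'yes', 'on')
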
@@ -126,6 +126,10 @@ def autocast(device_type: Optional[str] = None,
     elif device_type == 'mlu':
         pass
+    elif device_type == 'npu':
+        pass
     else:
         # Device like MPS does not support fp16 training or testing.
         # If an inappropriate device is set and fp16 is enabled, an error
...
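
With the added 'npu' branch, device_type='npu' is accepted instead of falling through to the else clause that rejects unsupported devices; the branch is a pass-through, leaving mixed precision to the NPU runtime. A hedged usage sketch, assuming an Ascend machine with torch_npu installed and that the context manager is importable as mmengine.runner.autocast (the model and shapes are placeholders):

import torch
from mmengine.runner import autocast

# Placeholder model and input; the 'npu' device only exists in a
# torch_npu / Ascend environment.
model = torch.nn.Linear(4, 2).to('npu')
x = torch.randn(8, 4).to('npu')

# After this commit, device_type='npu' is accepted rather than raising
# in the final else branch of autocast.
with autocast(device_type='npu'):
    y = model(x)
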