Skip to content
Snippets Groups Projects
Unverified Commit abe56651 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Make autocast compatible with mps (#587)

* [Fix] Make autocast compatible with mps

* Enhance unit test

* fix unit test

* clean the code

* fix unit test
parent 6073d9eb
No related branches found
No related tags found
No related merge requests found
...@@ -121,9 +121,18 @@ def autocast(device_type: Optional[str] = None, ...@@ -121,9 +121,18 @@ def autocast(device_type: Optional[str] = None,
assert dtype == torch.bfloat16, ( assert dtype == torch.bfloat16, (
'In CPU autocast, only support `torch.bfloat16` dtype') 'In CPU autocast, only support `torch.bfloat16` dtype')
elif device_type == 'mlu':
pass
else: else:
raise ValueError('User specified autocast device_type must be ' # Device like MPS does not support fp16 training or testing.
F'cuda or cpu, but got {device_type}') # If an inappropriate device is set and fp16 is enabled, an error
# will be thrown.
if enabled is False:
yield
return
else:
raise ValueError('User specified autocast device_type must be '
f'cuda or cpu, but got {device_type}')
with torch.autocast( with torch.autocast(
device_type=device_type, device_type=device_type,
......
...@@ -4,6 +4,8 @@ import unittest ...@@ -4,6 +4,8 @@ import unittest
import torch import torch
import torch.nn as nn import torch.nn as nn
import mmengine
from mmengine.device import get_device
from mmengine.runner import autocast from mmengine.runner import autocast
from mmengine.utils import digit_version from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.dl_utils import TORCH_VERSION
...@@ -56,3 +58,24 @@ class TestAmp(unittest.TestCase): ...@@ -56,3 +58,24 @@ class TestAmp(unittest.TestCase):
layer = nn.Conv2d(1, 1, 1).to(device) layer = nn.Conv2d(1, 1, 1).to(device)
res = layer(torch.randn(1, 1, 1, 1).to(device)) res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertEqual(res.dtype, torch.float32) self.assertEqual(res.dtype, torch.float32)
# Test mps
if digit_version(TORCH_VERSION) >= digit_version('1.12.0'):
mmengine.runner.amp.get_device = lambda: 'mps'
with autocast(enabled=False):
layer = nn.Conv2d(1, 1, 1)
res = layer(torch.randn(1, 1, 1, 1))
self.assertEqual(res.dtype, torch.float32)
with self.assertRaisesRegex(ValueError,
'User specified autocast device_type'):
with autocast(enabled=True):
pass
# Native pytorch does not support mlu, here we simply test autocast
# will call `torch.autocast`, which will be overridden by mlu version
# pytorch
mmengine.runner.amp.get_device = lambda: 'mlu'
with self.assertRaises(RuntimeError):
with autocast(enabled=False):
pass
mmengine.runner.amp.get_device = get_device
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment