Unverified commit 59b0ccfe authored by Mashiro, committed by GitHub

[Fix] Fix pytorch version compatibility of autocast (#339)

* fix unit test of autocast

* fix compatibility of unit test of OptimizerWrapper

* clean code

* fix as comment

* fix docstring
parent 5ac3c233
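As context for the diff below: the previous signature was ``autocast(enabled: bool = True, **kwargs)``, and this commit makes ``device_type``, ``dtype`` and ``cache_enabled`` explicit keyword arguments so they can be validated against the installed PyTorch version. A minimal usage sketch of the new interface (not part of the commit; it assumes PyTorch >= 1.10 and that ``autocast`` is importable from ``mmengine.runner``):

import torch
import torch.nn as nn

from mmengine.runner import autocast

layer = nn.Conv2d(1, 1, 1)
data = torch.randn(1, 1, 1, 1)

# All autocast options are now explicit keyword arguments.
with autocast(device_type='cpu', dtype=torch.bfloat16):
    out = layer(data)  # runs in bfloat16 under cpu autocast
with autocast(device_type='cpu', enabled=False):
    out = layer(data)  # autocast disabled, stays in float32

The updated implementation and its unit test follow.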
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from contextlib import contextmanager
from typing import Optional

import torch

from mmengine import print_log
from mmengine.utils import TORCH_VERSION, digit_version


@contextmanager
def autocast(device_type: Optional[str] = None,
             dtype: Optional[torch.dtype] = None,
             enabled: bool = True,
             cache_enabled: Optional[bool] = None):
    """A wrapper of ``torch.autocast`` and ``torch.cuda.amp.autocast``.

    Pytorch 1.5.0 provides ``torch.cuda.amp.autocast`` for running in
    mixed precision, and it was updated to ``torch.autocast`` in 1.10.0.
    Both interfaces have different arguments, and ``torch.autocast``
    additionally supports running on the cpu.
@@ -49,9 +55,13 @@ def autocast(enabled: bool = True, **kwargs):
        >>> pass

    Args:
        device_type (str, optional): Whether to use 'cuda' or 'cpu' device.
            Defaults to None, in which case it is inferred from the current
            environment.
        enabled (bool): Whether autocasting should be enabled in the region.
            Defaults to True.
        dtype (torch.dtype, optional): Whether to use ``torch.float16`` or
            ``torch.bfloat16``.
        cache_enabled (bool, optional): Whether the weight cache inside
            autocast should be enabled.
    """
    # If `enabled` is False, an empty (no-op) context is entered and all
    # calculations are performed in fp32.
@@ -63,9 +73,17 @@ def autocast(enabled: bool = True, **kwargs):
            digit_version('1.10.0')):
        # If pytorch version is between 1.5.0 and 1.10.0, the default value of
        # dtype for `torch.cuda.amp.autocast` is torch.float16.
        assert device_type == 'cuda' or device_type is None, (
            'Pytorch versions between 1.5.0 and 1.10.0 only support running '
            'automatic mixed precision training with cuda')
        if dtype is not None or cache_enabled is not None:
            print_log(
                f'{dtype} and {cache_enabled} will not work for '
                '`autocast` since your Pytorch version: '
                f'{TORCH_VERSION} < 1.10.0',
                logger='current',
                level=logging.WARNING)

        if torch.cuda.is_available():
            with torch.cuda.amp.autocast(enabled=enabled):
                yield
@@ -77,24 +95,23 @@ def autocast(enabled: bool = True, **kwargs):
                    'If pytorch version is between 1.5.0 and 1.10, '
                    '`autocast` is only available in gpu mode')
    else:
        if torch.cuda.is_available():
            device_type = 'cuda' if device_type is None else device_type
        else:
            device_type = 'cpu' if device_type is None else device_type

        if digit_version(TORCH_VERSION) < digit_version('1.11.0'):
            # torch.autocast only supports `dtype=torch.bfloat16` in
            # pytorch 1.10.
            if dtype is not None and dtype != torch.bfloat16:
                print_log(
                    f'{dtype} must be `torch.bfloat16` with Pytorch '
                    f'version: {TORCH_VERSION}',
                    logger='current',
                    level=logging.WARNING)
            dtype = torch.bfloat16

        with torch.autocast(
                device_type=device_type,
                enabled=enabled,
                dtype=dtype,
                cache_enabled=cache_enabled):
            yield
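To summarize the branching above: which backend ``autocast`` falls back to depends only on the installed PyTorch version. A compact sketch (an illustration, not code from the commit) using the same ``digit_version``/``TORCH_VERSION`` helpers:

from mmengine.utils import TORCH_VERSION, digit_version

if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
    # Only torch.cuda.amp.autocast exists: cuda only, fp16;
    # dtype and cache_enabled are ignored with a warning.
    backend = 'torch.cuda.amp.autocast'
elif digit_version(TORCH_VERSION) < digit_version('1.11.0'):
    # torch.autocast exists, but per the comment above only
    # dtype=torch.bfloat16 is supported, so dtype is forced.
    backend = 'torch.autocast (dtype forced to torch.bfloat16)'
else:
    # Full torch.autocast interface: device_type, dtype, cache_enabled.
    backend = 'torch.autocast'
print(f'{TORCH_VERSION}: {backend}')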
@@ -42,16 +42,16 @@ class TestAmp(unittest.TestCase):
        else:
            devices = ['cpu', 'cuda']
        for device in devices:
            with autocast(device_type=device):
                # torch.autocast support cpu and cuda mode.
                layer = nn.Conv2d(1, 1, 1).to(device)
                res = layer(torch.randn(1, 1, 1, 1).to(device))
                self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
                with autocast(enabled=False, device_type=device):
                    res = layer(torch.randn(1, 1, 1, 1).to(device))
                    self.assertEqual(res.dtype, torch.float32)
            # Test with fp32_enabled
            with autocast(enabled=False, device_type=device):
                layer = nn.Conv2d(1, 1, 1).to(device)
                res = layer(torch.randn(1, 1, 1, 1).to(device))
                self.assertEqual(res.dtype, torch.float32)
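For reference, roughly what the updated test asserts, written as a standalone script (a sketch assuming PyTorch >= 1.11, where both the cpu and cuda paths go through ``torch.autocast``; ``autocast`` is again assumed to be importable from ``mmengine.runner``):

import torch
import torch.nn as nn

from mmengine.runner import autocast

devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
for device in devices:
    layer = nn.Conv2d(1, 1, 1).to(device)
    x = torch.randn(1, 1, 1, 1).to(device)
    with autocast(device_type=device):
        # cpu autocast defaults to bfloat16, cuda autocast to float16.
        assert layer(x).dtype in (torch.bfloat16, torch.float16)
    with autocast(device_type=device, enabled=False):
        # With autocast disabled the computation stays in float32.
        assert layer(x).dtype == torch.float32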