diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index 00fff3a9d396aa4659738dab24eec7d0a9b3f557..3c91f2105b637b05c93fda9238c43cea39f4ff4a 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -1,16 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from contextlib import contextmanager
+from typing import Optional
 
 import torch
 
+from mmengine import print_log
 from mmengine.utils import TORCH_VERSION, digit_version
 
 
 @contextmanager
-def autocast(enabled: bool = True, **kwargs):
+def autocast(device_type: Optional[str] = None,
+             dtype: Optional[torch.dtype] = None,
+             enabled: bool = True,
+             cache_enabled: Optional[bool] = None):
     """A wrapper of ``torch.autocast`` and ``toch.cuda.amp.autocast``.
 
-    Pytorch 1.6.0 provide ``torch.cuda.amp.autocast`` for running in
+    Pytorch 1.5.0 provide ``torch.cuda.amp.autocast`` for running in
     mixed precision , and update it to ``torch.autocast`` in 1.10.0.
     Both interfaces have different arguments, and ``torch.autocast``
     support running with cpu additionally.
@@ -49,9 +55,13 @@ def autocast(enabled: bool = True, **kwargs):
         >>>     pass
 
     Args:
-        enabled (bool): Whether autocasting should be enabled in the region.
-            Defaults to True.
-        kwargs (dict): Arguments of torch.autocast except for ``enabled``.
+        device_type (str, optional): Whether to use 'cuda' or 'cpu' device.
+        enabled (bool): Whether autocasting should be enabled in the region.
+            Defaults to True.
+        dtype (torch.dtype, optional): Whether to use ``torch.float16`` or
+            ``torch.bfloat16``.
+        cache_enabled (bool, optional): Whether the weight cache inside
+            autocast should be enabled.
     """
     # If `enabled` is True, enable an empty context and all calculations
     # are performed under fp32.
@@ -63,9 +73,17 @@ def autocast(enabled: bool = True, **kwargs):
           digit_version('1.10.0')):
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
-        assert not kwargs, (
-            f'autocast under pytorch {TORCH_VERSION} only accept `enabled` '
-            'arguments.')
+        assert device_type == 'cuda' or device_type is None, (
+            'Pytorch version under 1.10.0 only supports running automatic '
+            'mixed precision training with cuda')
+        if dtype is not None or cache_enabled is not None:
+            print_log(
+                f'{dtype} and {cache_enabled} will not work for '
+                '`autocast` since your Pytorch version: '
+                f'{TORCH_VERSION} < 1.10.0',
+                logger='current',
+                level=logging.WARNING)
+
         if torch.cuda.is_available():
             with torch.cuda.amp.autocast(enabled=enabled):
                 yield
@@ -77,24 +95,23 @@ def autocast(enabled: bool = True, **kwargs):
                     'If pytorch versions is between 1.5.0 and 1.10, '
                     '`autocast` is only available in gpu mode')
 
-    elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >=
-          digit_version('1.10.0')):
-        if torch.cuda.is_available():
-            kwargs.setdefault('device_type', 'cuda')
-        else:
-            kwargs.setdefault('device_type', 'cpu')
-        # torch.autocast only support `dtype=torch.bfloat16` in
-        # pytorch 1.10
-        kwargs.setdefault('dtype', torch.bfloat16)
-
-        with torch.autocast(enabled=enabled, **kwargs):
-            yield
-
-    elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
+    else:
         if torch.cuda.is_available():
-            kwargs.setdefault('device_type', 'cuda')
+            device_type = 'cuda' if device_type is None else device_type
         else:
-            kwargs.setdefault('device_type', 'cpu')
+            device_type = 'cpu' if device_type is None else device_type
 
-        with torch.autocast(enabled=enabled, **kwargs):
+        if digit_version(TORCH_VERSION) < digit_version('1.11.0'):
+            if dtype is not None and dtype != torch.bfloat16:
+                print_log(
+                    f'{dtype} must be `torch.bfloat16` with Pytorch '
+                    f'version: {TORCH_VERSION}',
+                    logger='current',
+                    level=logging.WARNING)
+            dtype = torch.bfloat16
+        with torch.autocast(
+                device_type=device_type,
+                enabled=enabled,
+                dtype=dtype,
+                cache_enabled=cache_enabled):
             yield
diff --git a/tests/test_runner/test_amp.py b/tests/test_runner/test_amp.py
index 8ac0dd55226d699a1d7120ea4e194d1b67e4791a..38801bdab8f24a2fe87eaa6429608e39a5f20935 100644
--- a/tests/test_runner/test_amp.py
+++ b/tests/test_runner/test_amp.py
@@ -42,16 +42,16 @@ class TestAmp(unittest.TestCase):
         else:
             devices = ['cpu', 'cuda']
         for device in devices:
-            with autocast():
+            with autocast(device_type=device):
                 # torch.autocast support cpu and cuda mode.
                 layer = nn.Conv2d(1, 1, 1).to(device)
                 res = layer(torch.randn(1, 1, 1, 1).to(device))
                 self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
-            with autocast(enabled=False):
+            with autocast(enabled=False, device_type=device):
                 res = layer(torch.randn(1, 1, 1, 1).to(device))
                 self.assertEqual(res.dtype, torch.float32)
             # Test with fp32_enabled
-            with autocast(enabled=False):
+            with autocast(enabled=False, device_type=device):
                 layer = nn.Conv2d(1, 1, 1).to(device)
                 res = layer(torch.randn(1, 1, 1, 1).to(device))
                 self.assertEqual(res.dtype, torch.float32)
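For reference, a minimal usage sketch of the wrapper after this change, assuming `autocast` stays importable from `mmengine.runner` as the test module above does; the layer, tensor shapes, and variable names are illustrative only and not part of the diff.

```python
# Hypothetical usage sketch (not part of the diff): exercises the new
# `device_type`/`dtype` arguments introduced by this PR.
import torch
import torch.nn as nn

from mmengine.runner import autocast  # assumed export, mirroring the tests

device = 'cuda' if torch.cuda.is_available() else 'cpu'
layer = nn.Conv2d(1, 1, 1).to(device)
data = torch.randn(1, 1, 1, 1).to(device)

# On Pytorch >= 1.10 the arguments are forwarded to `torch.autocast`;
# on 1.5.0 <= Pytorch < 1.10 the wrapper falls back to
# `torch.cuda.amp.autocast` and `dtype`/`cache_enabled` only emit a warning.
with autocast(device_type=device, dtype=torch.bfloat16):
    out = layer(data)  # runs in reduced precision where supported

with autocast(device_type=device, enabled=False):
    out_fp32 = layer(data)  # autocasting disabled, stays in float32
```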