diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py index 6278ac1b58a0747637396f3e0b17deadc8538b9c..00fff3a9d396aa4659738dab24eec7d0a9b3f557 100644 --- a/mmengine/runner/amp.py +++ b/mmengine/runner/amp.py @@ -77,7 +77,20 @@ def autocast(enabled: bool = True, **kwargs): 'If pytorch versions is between 1.5.0 and 1.10, ' '`autocast` is only available in gpu mode') - elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'): + elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >= + digit_version('1.10.0')): + if torch.cuda.is_available(): + kwargs.setdefault('device_type', 'cuda') + else: + kwargs.setdefault('device_type', 'cpu') + # torch.autocast only support `dtype=torch.bfloat16` in + # pytorch 1.10 + kwargs.setdefault('dtype', torch.bfloat16) + + with torch.autocast(enabled=enabled, **kwargs): + yield + + elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'): if torch.cuda.is_available(): kwargs.setdefault('device_type', 'cuda') else: