From a4f5533db6d41672e90413dd1e36cfbf1e840f59 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 22 Jun 2022 23:12:20 +0800 Subject: [PATCH] fix torch 1.10 amp error (#330) --- mmengine/runner/amp.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py index 6278ac1b..00fff3a9 100644 --- a/mmengine/runner/amp.py +++ b/mmengine/runner/amp.py @@ -77,7 +77,20 @@ def autocast(enabled: bool = True, **kwargs): 'If pytorch versions is between 1.5.0 and 1.10, ' '`autocast` is only available in gpu mode') - elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'): + elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >= + digit_version('1.10.0')): + if torch.cuda.is_available(): + kwargs.setdefault('device_type', 'cuda') + else: + kwargs.setdefault('device_type', 'cpu') + # torch.autocast only support `dtype=torch.bfloat16` in + # pytorch 1.10 + kwargs.setdefault('dtype', torch.bfloat16) + + with torch.autocast(enabled=enabled, **kwargs): + yield + + elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'): if torch.cuda.is_available(): kwargs.setdefault('device_type', 'cuda') else: -- GitLab