From a4f5533db6d41672e90413dd1e36cfbf1e840f59 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Wed, 22 Jun 2022 23:12:20 +0800
Subject: [PATCH] fix torch 1.10 amp error (#330)

---
 mmengine/runner/amp.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index 6278ac1b..00fff3a9 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -77,7 +77,20 @@ def autocast(enabled: bool = True, **kwargs):
                     'If pytorch versions is between 1.5.0 and 1.10, '
                     '`autocast` is only available in gpu mode')
 
-    elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
+    elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >=
+          digit_version('1.10.0')):
+        if torch.cuda.is_available():
+            kwargs.setdefault('device_type', 'cuda')
+        else:
+            kwargs.setdefault('device_type', 'cpu')
+            # torch.autocast only support `dtype=torch.bfloat16` in
+            # pytorch 1.10
+            kwargs.setdefault('dtype', torch.bfloat16)
+
+        with torch.autocast(enabled=enabled, **kwargs):
+            yield
+
+    elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
         if torch.cuda.is_available():
             kwargs.setdefault('device_type', 'cuda')
         else:
-- 
GitLab