diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index 00fff3a9d396aa4659738dab24eec7d0a9b3f557..3c91f2105b637b05c93fda9238c43cea39f4ff4a 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -1,16 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from contextlib import contextmanager
+from typing import Optional
 
 import torch
 
+from mmengine.logging import print_log
 from mmengine.utils import TORCH_VERSION, digit_version
 
 
 @contextmanager
-def autocast(enabled: bool = True, **kwargs):
+def autocast(device_type: Optional[str] = None,
+             dtype: Optional[torch.dtype] = None,
+             enabled: bool = True,
+             cache_enabled: Optional[bool] = None):
     """A wrapper of ``torch.autocast`` and ``toch.cuda.amp.autocast``.
 
-    Pytorch 1.6.0 provide ``torch.cuda.amp.autocast`` for running in
+    Pytorch 1.5.0 provides ``torch.cuda.amp.autocast`` for running in
     mixed precision, which was generalized to ``torch.autocast`` in 1.10.0.
     Both interfaces have different arguments, and ``torch.autocast``
     additionally supports running on cpu.
@@ -49,9 +55,13 @@ def autocast(enabled: bool = True, **kwargs):
          >>>    pass
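+         >>> # Pytorch version >= 1.10.0: dtype is forwarded straight to
+         >>> # ``torch.autocast`` (illustrative, bfloat16 works on cpu)
+         >>> with autocast(device_type='cpu', dtype=torch.bfloat16):
+         >>>    pass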
 
     Args:
-        enabled (bool): Whether autocasting should be enabled in the region.
-            Defaults to True.
-        kwargs (dict): Arguments of torch.autocast except for ``enabled``.
+        device_type (str, optional): Device type to run autocast on, either
+            'cuda' or 'cpu'. If not specified, it is inferred from cuda
+            availability. Defaults to None.
+        dtype (torch.dtype, optional): Data type for autocast, e.g.
+            ``torch.float16`` or ``torch.bfloat16``. Defaults to None.
+        enabled (bool): Whether autocasting should be enabled in the region.
+            Defaults to True.
+        cache_enabled (bool, optional): Whether the weight cache inside
+            autocast should be enabled. Defaults to None.
     """
     # If `enabled` is False, autocast acts as an empty context and all
     # calculations are performed in fp32.
@@ -63,9 +73,17 @@ def autocast(enabled: bool = True, **kwargs):
             digit_version('1.10.0')):
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
-        assert not kwargs, (
-            f'autocast under pytorch {TORCH_VERSION} only accept `enabled` '
-            'arguments.')
+        assert device_type == 'cuda' or device_type is None, (
+            'Pytorch versions between 1.5.0 and 1.10.0 only support '
+            'running automatic mixed precision training with cuda')
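+        # `torch.cuda.amp.autocast` before 1.10 does not accept `dtype` or
+        # `cache_enabled`, so both are ignored here with a warning.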
+        if dtype is not None or cache_enabled is not None:
+            print_log(
+                f'`dtype` ({dtype}) and `cache_enabled` ({cache_enabled}) '
+                'will be ignored by `autocast`, since your Pytorch version '
+                f'({TORCH_VERSION}) is lower than 1.10.0',
+                logger='current',
+                level=logging.WARNING)
+
         if torch.cuda.is_available():
             with torch.cuda.amp.autocast(enabled=enabled):
                 yield
@@ -77,24 +95,23 @@ def autocast(enabled: bool = True, **kwargs):
                     'If pytorch version is between 1.5.0 and 1.10, '
                     '`autocast` is only available in gpu mode')
 
-    elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >=
-          digit_version('1.10.0')):
-        if torch.cuda.is_available():
-            kwargs.setdefault('device_type', 'cuda')
-        else:
-            kwargs.setdefault('device_type', 'cpu')
-            # torch.autocast only support `dtype=torch.bfloat16` in
-            # pytorch 1.10
-            kwargs.setdefault('dtype', torch.bfloat16)
-
-        with torch.autocast(enabled=enabled, **kwargs):
-            yield
-
-    elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
+    else:
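+        # From pytorch 1.10.0 onwards `torch.autocast` also supports cpu,
+        # so only a default `device_type` needs to be inferred here.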
         if torch.cuda.is_available():
-            kwargs.setdefault('device_type', 'cuda')
+            device_type = 'cuda' if device_type is None else device_type
         else:
-            kwargs.setdefault('device_type', 'cpu')
+            device_type = 'cpu' if device_type is None else device_type
 
-        with torch.autocast(enabled=enabled, **kwargs):
+        if digit_version(TORCH_VERSION) < digit_version('1.11.0'):
+            # Pytorch 1.10 only works with `torch.bfloat16` here, so fall
+            # back to it and warn if a different dtype was requested.
+            if dtype is not None and dtype != torch.bfloat16:
+                print_log(
+                    f'`dtype` must be `torch.bfloat16` with Pytorch '
+                    f'version {TORCH_VERSION}, but got {dtype}. '
+                    '`torch.bfloat16` will be used instead',
+                    logger='current',
+                    level=logging.WARNING)
+            dtype = torch.bfloat16
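+        # If `dtype` or `cache_enabled` is still None, `torch.autocast`
+        # falls back to its own defaults.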
+        with torch.autocast(
+                device_type=device_type,
+                enabled=enabled,
+                dtype=dtype,
+                cache_enabled=cache_enabled):
             yield
diff --git a/tests/test_runner/test_amp.py b/tests/test_runner/test_amp.py
index 8ac0dd55226d699a1d7120ea4e194d1b67e4791a..38801bdab8f24a2fe87eaa6429608e39a5f20935 100644
--- a/tests/test_runner/test_amp.py
+++ b/tests/test_runner/test_amp.py
@@ -42,16 +42,16 @@ class TestAmp(unittest.TestCase):
             else:
                 devices = ['cpu', 'cuda']
             for device in devices:
-                with autocast():
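+                # Pass the device explicitly so autocast does not infer it
+                # from cuda availability.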
+                with autocast(device_type=device):
                     # torch.autocast supports both cpu and cuda mode.
                     layer = nn.Conv2d(1, 1, 1).to(device)
                     res = layer(torch.randn(1, 1, 1, 1).to(device))
                     self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
-                    with autocast(enabled=False):
+                    with autocast(enabled=False, device_type=device):
                         res = layer(torch.randn(1, 1, 1, 1).to(device))
                         self.assertEqual(res.dtype, torch.float32)
                 # Test with fp32_enabled
-                with autocast(enabled=False):
+                with autocast(enabled=False, device_type=device):
                     layer = nn.Conv2d(1, 1, 1).to(device)
                     res = layer(torch.randn(1, 1, 1, 1).to(device))
                     self.assertEqual(res.dtype, torch.float32)