diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py
index 58dbc051d20699ae69b5c1006a0b84ab4b6a5368..644d5739fcafff8e8fa364c395b243b5966b29e7 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper.py
@@ -161,20 +161,33 @@ class OptimWrapper:
         # the loss factor will always be the same as `_accumulative_counts`.
         self._remainder_counts = -1
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update parameters in :attr:`optimizer`.
 
         Args:
             loss (torch.Tensor): A tensor for back propagation.
+            step_kwargs (dict, optional): Keyword arguments passed to
+                ``optimizer.step``. Defaults to None.
+                New in version 0.4.0.
+            zero_kwargs (dict, optional): Keyword arguments passed to
+                ``optimizer.zero_grad``. Defaults to None.
+                New in version 0.4.0.
         """
+        if step_kwargs is None:
+            step_kwargs = {}
+        if zero_kwargs is None:
+            zero_kwargs = {}
         loss = self.scale_loss(loss)
         self.backward(loss)
         # Update parameters only if `self._inner_count` is divisible by
         # `self._accumulative_counts` or `self._inner_count` equals
         # `self._max_counts`
         if self.should_update():
-            self.step()
-            self.zero_grad()
+            self.step(**step_kwargs)
+            self.zero_grad(**zero_kwargs)
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
         """Perform gradient back propagation.
diff --git a/mmengine/optim/optimizer/optimizer_wrapper_dict.py b/mmengine/optim/optimizer/optimizer_wrapper_dict.py
index 6155b62df0611b578183dd899a44d16b926738d5..8a4b25800360fcab30941886ed80ac3948fde57f 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper_dict.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper_dict.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -46,7 +46,10 @@ class OptimWrapperDict(OptimWrapper):
                 f'but got {key}: {type(value)}')
         self.optim_wrappers = optim_wrapper_dict
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update all optimizer wrappers would lead to a duplicate backward
         errors, and OptimWrapperDict does not know which optimizer wrapper
         should be updated.
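
Since `OptimWrapperDict` cannot decide which wrapper a single loss belongs to, here is a minimal sketch of the intended per-wrapper update, assuming a hypothetical two-optimizer setup with keys `generator` and `discriminator`; the individual wrappers are reached through the `optim_wrappers` attribute assigned above:

```python
import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper, OptimWrapperDict

# Hypothetical two-optimizer setup (e.g. GAN-style training).
gen, disc = nn.Linear(2, 2), nn.Linear(2, 1)
optim_wrapper_dict = OptimWrapperDict(
    generator=OptimWrapper(torch.optim.SGD(gen.parameters(), lr=0.01)),
    discriminator=OptimWrapper(torch.optim.SGD(disc.parameters(), lr=0.01)))

x = torch.randn(4, 2)
loss_gen = gen(x).sum()
loss_disc = disc(gen(x).detach()).sum()

# Each wrapped OptimWrapper is updated with its own loss; the new
# step_kwargs/zero_kwargs are forwarded per wrapper as well.
optim_wrapper_dict.optim_wrappers['generator'].update_params(
    loss_gen, zero_kwargs={'set_to_none': True})
optim_wrapper_dict.optim_wrappers['discriminator'].update_params(loss_disc)
```
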