From 381c5f103c3c77b55ee8ecf233ce0916eac67b61 Mon Sep 17 00:00:00 2001
From: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Date: Thu, 8 Dec 2022 13:15:42 +0800
Subject: [PATCH] [Enhance] Support passing kwargs to update_params (#796)

* [Enhance]

Support passing step and zero_grad arguments to update_params
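
A minimal usage sketch (assumes a plain torch.optim.SGD optimizer; the
set_to_none flag is a standard zero_grad argument, used here only for
illustration):

    import torch
    import torch.nn as nn
    from mmengine.optim import OptimWrapper

    model = nn.Linear(2, 1)
    optim_wrapper = OptimWrapper(
        optimizer=torch.optim.SGD(model.parameters(), lr=0.01))
    loss = model(torch.randn(4, 2)).sum()
    # step_kwargs/zero_kwargs are forwarded to optimizer.step() and
    # optimizer.zero_grad() respectively.
    optim_wrapper.update_params(loss, zero_kwargs=dict(set_to_none=True))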

* Update mmengine/optim/optimizer/optimizer_wrapper.py

* Update mmengine/optim/optimizer/optimizer_wrapper.py

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/optim/optimizer/optimizer_wrapper.py | 19 ++++++++++++++++---
 .../optim/optimizer/optimizer_wrapper_dict.py |  7 +++++--
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py
index 58dbc051..644d5739 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper.py
@@ -161,20 +161,33 @@ class OptimWrapper:
         # the loss factor will always be the same as `_accumulative_counts`.
         self._remainder_counts = -1
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update parameters in :attr:`optimizer`.
 
         Args:
             loss (torch.Tensor): A tensor for back propagation.
+            step_kwargs (dict): Arguments for optimizer.step.
+                Defaults to None.
+                New in version v0.4.0.
+            zero_kwargs (dict): Arguments for optimizer.zero_grad.
+                Defaults to None.
+                New in version v0.4.0.
         """
+        if step_kwargs is None:
+            step_kwargs = {}
+        if zero_kwargs is None:
+            zero_kwargs = {}
         loss = self.scale_loss(loss)
         self.backward(loss)
         # Update parameters only if `self._inner_count` is divisible by
         # `self._accumulative_counts` or `self._inner_count` equals to
         # `self._max_counts`
         if self.should_update():
-            self.step()
-            self.zero_grad()
+            self.step(**step_kwargs)
+            self.zero_grad(**zero_kwargs)
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
         """Perform gradient back propagation.
diff --git a/mmengine/optim/optimizer/optimizer_wrapper_dict.py b/mmengine/optim/optimizer/optimizer_wrapper_dict.py
index 6155b62d..8a4b2580 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper_dict.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper_dict.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -46,7 +46,10 @@ class OptimWrapperDict(OptimWrapper):
                 f'but got {key}: {type(value)}')
         self.optim_wrappers = optim_wrapper_dict
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update all optimizer wrappers would lead to a duplicate backward
         errors, and OptimWrapperDict does not know which optimizer wrapper
         should be updated.
-- 
GitLab