diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py
index 58dbc051d20699ae69b5c1006a0b84ab4b6a5368..644d5739fcafff8e8fa364c395b243b5966b29e7 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper.py
@@ -161,20 +161,33 @@ class OptimWrapper:
         # the loss factor will always be the same as `_accumulative_counts`.
         self._remainder_counts = -1
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update parameters in :attr:`optimizer`.
 
         Args:
             loss (torch.Tensor): A tensor for back propagation.
+            step_kwargs (dict): Arguments for optimizer.step.
+                Defaults to None.
+                New in version v0.4.0.
+            zero_kwargs (dict): Arguments for optimizer.zero_grad.
+                Defaults to None.
+                New in version v0.4.0.
         """
+        if step_kwargs is None:
+            step_kwargs = {}
+        if zero_kwargs is None:
+            zero_kwargs = {}
         loss = self.scale_loss(loss)
         self.backward(loss)
         # Update parameters only if `self._inner_count` is divisible by
         # `self._accumulative_counts` or `self._inner_count` equals to
         # `self._max_counts`
         if self.should_update():
-            self.step()
-            self.zero_grad()
+            self.step(**step_kwargs)
+            self.zero_grad(**zero_kwargs)
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
         """Perform gradient back propagation.
diff --git a/mmengine/optim/optimizer/optimizer_wrapper_dict.py b/mmengine/optim/optimizer/optimizer_wrapper_dict.py
index 6155b62df0611b578183dd899a44d16b926738d5..8a4b25800360fcab30941886ed80ac3948fde57f 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper_dict.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper_dict.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -46,7 +46,10 @@ class OptimWrapperDict(OptimWrapper):
                 f'but got {key}: {type(value)}')
         self.optim_wrappers = optim_wrapper_dict
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update all optimizer wrappers would lead to a duplicate backward
         errors, and OptimWrapperDict does not know which optimizer wrapper
         should be updated.
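
For reference, below is a minimal usage sketch of the new signature. The model, optimizer, and data are illustrative placeholders, not part of this change; the sketch only shows that `step_kwargs` and `zero_kwargs` default to `None` and are expanded into the existing `step()`/`zero_grad()` calls when the wrapper decides to update.

```python
import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper

# Placeholder model and optimizer, used only to exercise the new arguments.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer=optimizer)

inputs = torch.randn(8, 4)
loss = model(inputs).sum()

# `zero_kwargs` is forwarded to `optimizer.zero_grad`; `set_to_none=True` is a
# standard torch.optim.Optimizer.zero_grad argument. `step_kwargs` is forwarded
# to `optimizer.step` in the same way.
optim_wrapper.update_params(loss, zero_kwargs=dict(set_to_none=True))
```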