Unverified commit 381c5f10 authored by Ming-Hsuan-Tu, committed by GitHub

[Enhance] Support passing kwargs to update_params (#796)


* [Enhance]

Support step arguments and zero_grad arguments with update_params

* Update mmengine/optim/optimizer/optimizer_wrapper.py

* Update mmengine/optim/optimizer/optimizer_wrapper.py

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 57f6644e
@@ -161,20 +161,33 @@ class OptimWrapper:
         # the loss factor will always be the same as `_accumulative_counts`.
         self._remainder_counts = -1
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update parameters in :attr:`optimizer`.
 
         Args:
             loss (torch.Tensor): A tensor for back propagation.
+            step_kwargs (dict): Arguments for optimizer.step.
+                Defaults to None.
+                New in version v0.4.0.
+            zero_kwargs (dict): Arguments for optimizer.zero_grad.
+                Defaults to None.
+                New in version v0.4.0.
         """
+        if step_kwargs is None:
+            step_kwargs = {}
+        if zero_kwargs is None:
+            zero_kwargs = {}
         loss = self.scale_loss(loss)
         self.backward(loss)
         # Update parameters only if `self._inner_count` is divisible by
         # `self._accumulative_counts` or `self._inner_count` equals to
         # `self._max_counts`
         if self.should_update():
-            self.step()
-            self.zero_grad()
+            self.step(**step_kwargs)
+            self.zero_grad(**zero_kwargs)
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
         """Perform gradient back propagation.
...
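For context, here is a minimal usage sketch of the new signature (mmengine >= v0.4.0, per the "New in version" notes above). It is not part of the commit; the toy model and the `set_to_none` choice are illustrative, with `set_to_none` being a standard argument of `torch.optim.Optimizer.zero_grad`.

# Minimal usage sketch (illustrative, not part of the commit): forwarding
# kwargs through OptimWrapper.update_params to optimizer.zero_grad.
import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper

model = nn.Linear(4, 2)  # toy model for illustration
optim_wrapper = OptimWrapper(
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01))

loss = model(torch.randn(8, 4)).sum()

# `zero_kwargs` is forwarded to `optimizer.zero_grad`; `set_to_none=True`
# clears gradients by setting them to None instead of zeroing them in place.
optim_wrapper.update_params(loss, zero_kwargs={'set_to_none': True})

The second changed file, shown next, adds the `Optional` import and applies the same signature extension to `OptimWrapperDict.update_params`.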
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
...
@@ -46,7 +46,10 @@ class OptimWrapperDict(OptimWrapper):
                 f'but got {key}: {type(value)}')
         self.optim_wrappers = optim_wrapper_dict
 
-    def update_params(self, loss: torch.Tensor) -> None:
+    def update_params(self,
+                      loss: torch.Tensor,
+                      step_kwargs: Optional[Dict] = None,
+                      zero_kwargs: Optional[Dict] = None) -> None:
         """Update all optimizer wrappers would lead to a duplicate backward
         errors, and OptimWrapperDict does not know which optimizer wrapper
         should be updated.
...
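Because `OptimWrapperDict.update_params` still cannot know which wrapper to update (per its docstring), the new kwargs matter when calling `update_params` on an individual wrapper. A hedged sketch follows, assuming the dict is built from keyword arguments; only the `optim_wrappers` attribute is shown in the diff, the sub-model names are invented for illustration.

# Sketch for the OptimWrapperDict change (illustrative, not part of the commit):
# update one wrapper at a time and forward the new kwargs as in OptimWrapper.
import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper, OptimWrapperDict

gen, disc = nn.Linear(4, 2), nn.Linear(2, 1)  # toy sub-models
optim_dict = OptimWrapperDict(  # keyword-argument construction is an assumption
    generator=OptimWrapper(torch.optim.SGD(gen.parameters(), lr=0.01)),
    discriminator=OptimWrapper(torch.optim.SGD(disc.parameters(), lr=0.01)))

gen_loss = gen(torch.randn(8, 4)).sum()
# Update only the generator's parameters; `zero_kwargs` reaches its zero_grad.
optim_dict.optim_wrappers['generator'].update_params(
    gen_loss, zero_kwargs={'set_to_none': True})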