Skip to content
Snippets Groups Projects
Unverified Commit 6f321f88 authored by RangiLyu, committed by GitHub
Browse files

[Enhance] Optimize parameter updating speed in AveragedModel. (#281)

* [Enhance] Optimize parameter updating speed in AveragedModel.

* add docstring
parent 6ee67543
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from abc import abstractmethod
from copy import deepcopy
from typing import Optional
......@@ -20,7 +19,10 @@ class BaseAveragedModel(nn.Module):
This class creates a copy of the provided module :attr:`model`
on the device :attr:`device` and allows computing running averages of the
parameters of the :attr:`model`.
The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py
The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.
Different from the `AveragedModel` in PyTorch, we use in-place operation
to improve the parameter updating speed, which is about 5 times faster
than the non-in-place version.
In mmengine, we provide two ways to use the model averaging:
1. Use the model averaging module in hook:
......@@ -51,19 +53,23 @@ class BaseAveragedModel(nn.Module):
device: Optional[torch.device] = None,
update_buffers: bool = False) -> None:
super().__init__()
self.module = deepcopy(model)
self.module = deepcopy(model).requires_grad_(False)
self.interval = interval
if device is not None:
self.module = self.module.to(device)
self.register_buffer('steps',
torch.tensor(0, dtype=torch.long, device=device))
self.update_buffers = update_buffers
if update_buffers:
self.avg_parameters = self.module.state_dict()
else:
self.avg_parameters = dict(self.module.named_parameters())
@abstractmethod
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int) -> Tensor:
"""Compute the average of the parameters. All subclasses must implement
this method.
steps: int) -> None:
"""Use in-place operation to compute the average of the parameters. All
subclasses must implement this method.
Args:
averaged_param (Tensor): The averaged parameters.
......@@ -84,23 +90,19 @@ class BaseAveragedModel(nn.Module):
Args:
model (nn.Module): The model whose parameters will be averaged.
"""
if self.steps % self.interval == 0:
avg_param = (
itertools.chain(self.module.parameters(),
self.module.buffers())
if self.update_buffers else self.parameters())
src_param = (
itertools.chain(model.parameters(), model.buffers())
if self.update_buffers else model.parameters())
for p_avg, p_src in zip(avg_param, src_param):
device = p_avg.device
p_src_ = p_src.detach().to(device)
if self.steps == 0:
p_avg.detach().copy_(p_src_)
else:
p_avg.detach().copy_(
self.avg_func(p_avg.detach(), p_src_,
self.steps.to(device)))
src_parameters = (
model.state_dict()
if self.update_buffers else dict(model.named_parameters()))
if self.steps == 0:
for k, p_avg in self.avg_parameters.items():
p_avg.data.copy_(src_parameters[k].data)
elif self.steps % self.interval == 0:
for k, p_avg in self.avg_parameters.items():
if p_avg.dtype.is_floating_point:
device = p_avg.device
self.avg_func(p_avg.data,
src_parameters[k].data.to(device),
self.steps)
self.steps += 1
......@@ -115,7 +117,7 @@ class StochasticWeightAverage(BaseAveragedModel):
"""
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int) -> Tensor:
steps: int) -> None:
"""Compute the average of the parameters using stochastic weight
average.
......@@ -124,11 +126,10 @@ class StochasticWeightAverage(BaseAveragedModel):
source_param (Tensor): The source parameters.
steps (int): The number of times the parameters have been
updated.
Returns:
Tensor: The averaged parameters.
"""
return averaged_param + (source_param - averaged_param) / (
steps // self.interval + 1)
averaged_param.add_(
source_param - averaged_param,
alpha=1 / (steps // self.interval + 1))
@MODELS.register_module()
......@@ -167,7 +168,7 @@ class ExponentialMovingAverage(BaseAveragedModel):
self.momentum = momentum
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int) -> Tensor:
steps: int) -> None:
"""Compute the moving average of the parameters using exponential
moving average.
......@@ -176,11 +177,9 @@ class ExponentialMovingAverage(BaseAveragedModel):
source_param (Tensor): The source parameters.
steps (int): The number of times the parameters have been
updated.
Returns:
Tensor: The averaged parameters.
"""
return averaged_param * (1 -
self.momentum) + source_param * self.momentum
averaged_param.mul_(1 - self.momentum).add_(
source_param, alpha=self.momentum)
@MODELS.register_module()
......@@ -222,7 +221,7 @@ class MomentumAnnealingEMA(ExponentialMovingAverage):
self.gamma = gamma
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int) -> Tensor:
steps: int) -> None:
"""Compute the moving average of the parameters using the linear
momentum strategy.
......@@ -231,8 +230,6 @@ class MomentumAnnealingEMA(ExponentialMovingAverage):
source_param (Tensor): The source parameters.
steps (int): The number of times the parameters have been
updated.
Returns:
Tensor: The averaged parameters.
"""
momentum = max(self.momentum, self.gamma / (self.gamma + self.steps))
return averaged_param * (1 - momentum) + source_param * momentum
averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment