diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index 7fefc79eea3940e152b6ba61ba28e87e456548de..ab9ab9c0ea49e8fd6659b053be2a6fc58fc5bced 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import itertools from abc import abstractmethod from copy import deepcopy from typing import Optional @@ -20,7 +19,10 @@ class BaseAveragedModel(nn.Module): This class creates a copy of the provided module :attr:`model` on the device :attr:`device` and allows computing running averages of the parameters of the :attr:`model`. - The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py + The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py. + Different from the `AveragedModel` in PyTorch, we use in-place operation + to improve the parameter updating speed, which is about 5 times faster + than the non-in-place version. In mmengine, we provide two ways to use the model averaging: 1. Use the model averaging module in hook: @@ -51,19 +53,23 @@ class BaseAveragedModel(nn.Module): device: Optional[torch.device] = None, update_buffers: bool = False) -> None: super().__init__() - self.module = deepcopy(model) + self.module = deepcopy(model).requires_grad_(False) self.interval = interval if device is not None: self.module = self.module.to(device) self.register_buffer('steps', torch.tensor(0, dtype=torch.long, device=device)) self.update_buffers = update_buffers + if update_buffers: + self.avg_parameters = self.module.state_dict() + else: + self.avg_parameters = dict(self.module.named_parameters()) @abstractmethod def avg_func(self, averaged_param: Tensor, source_param: Tensor, - steps: int) -> Tensor: - """Compute the average of the parameters. All subclasses must implement - this method. + steps: int) -> None: + """Use in-place operation to compute the average of the parameters. All + subclasses must implement this method. Args: averaged_param (Tensor): The averaged parameters. @@ -84,23 +90,19 @@ class BaseAveragedModel(nn.Module): Args: model (nn.Module): The model whose parameters will be averaged. """ - if self.steps % self.interval == 0: - avg_param = ( - itertools.chain(self.module.parameters(), - self.module.buffers()) - if self.update_buffers else self.parameters()) - src_param = ( - itertools.chain(model.parameters(), model.buffers()) - if self.update_buffers else model.parameters()) - for p_avg, p_src in zip(avg_param, src_param): - device = p_avg.device - p_src_ = p_src.detach().to(device) - if self.steps == 0: - p_avg.detach().copy_(p_src_) - else: - p_avg.detach().copy_( - self.avg_func(p_avg.detach(), p_src_, - self.steps.to(device))) + src_parameters = ( + model.state_dict() + if self.update_buffers else dict(model.named_parameters())) + if self.steps == 0: + for k, p_avg in self.avg_parameters.items(): + p_avg.data.copy_(src_parameters[k].data) + elif self.steps % self.interval == 0: + for k, p_avg in self.avg_parameters.items(): + if p_avg.dtype.is_floating_point: + device = p_avg.device + self.avg_func(p_avg.data, + src_parameters[k].data.to(device), + self.steps) self.steps += 1 @@ -115,7 +117,7 @@ class StochasticWeightAverage(BaseAveragedModel): """ def avg_func(self, averaged_param: Tensor, source_param: Tensor, - steps: int) -> Tensor: + steps: int) -> None: """Compute the average of the parameters using stochastic weight average. @@ -124,11 +126,10 @@ class StochasticWeightAverage(BaseAveragedModel): source_param (Tensor): The source parameters. steps (int): The number of times the parameters have been updated. - Returns: - Tensor: The averaged parameters. """ - return averaged_param + (source_param - averaged_param) / ( - steps // self.interval + 1) + averaged_param.add_( + source_param - averaged_param, + alpha=1 / (steps // self.interval + 1)) @MODELS.register_module() @@ -167,7 +168,7 @@ class ExponentialMovingAverage(BaseAveragedModel): self.momentum = momentum def avg_func(self, averaged_param: Tensor, source_param: Tensor, - steps: int) -> Tensor: + steps: int) -> None: """Compute the moving average of the parameters using exponential moving average. @@ -176,11 +177,9 @@ class ExponentialMovingAverage(BaseAveragedModel): source_param (Tensor): The source parameters. steps (int): The number of times the parameters have been updated. - Returns: - Tensor: The averaged parameters. """ - return averaged_param * (1 - - self.momentum) + source_param * self.momentum + averaged_param.mul_(1 - self.momentum).add_( + source_param, alpha=self.momentum) @MODELS.register_module() @@ -222,7 +221,7 @@ class MomentumAnnealingEMA(ExponentialMovingAverage): self.gamma = gamma def avg_func(self, averaged_param: Tensor, source_param: Tensor, - steps: int) -> Tensor: + steps: int) -> None: """Compute the moving average of the parameters using the linear momentum strategy. @@ -231,8 +230,6 @@ class MomentumAnnealingEMA(ExponentialMovingAverage): source_param (Tensor): The source parameters. steps (int): The number of times the parameters have been updated. - Returns: - Tensor: The averaged parameters. """ momentum = max(self.momentum, self.gamma / (self.gamma + self.steps)) - return averaged_param * (1 - momentum) + source_param * momentum + averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)