From 6f321f88eec9b501a4763f2519e7ea0c08913761 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Wed, 8 Jun 2022 16:38:27 +0800
Subject: [PATCH] [Enhance] Optimize parameter updating speed in AveragedModel.
 (#281)

* [Enhance] Optimize parameter updating speed in AveragedModel.

* add docstring
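
The speed-up comes from replacing the out-of-place average (which
allocates temporaries and then copies the result back into the averaged
parameter) with fused in-place ops. A minimal sketch of the two patterns
on standalone tensors (illustrative only, not the patched code):

    import torch

    avg = torch.randn(1024, 1024)
    src = torch.randn(1024, 1024)
    momentum = 0.001

    # Old pattern: builds two temporaries, then copies back into `avg`.
    avg.copy_(avg * (1 - momentum) + src * momentum)

    # New pattern: mutate `avg` directly; no copy back, fewer temporaries.
    avg.mul_(1 - momentum).add_(src, alpha=momentum)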
---
 mmengine/model/averaged_model.py | 71 +++++++++++++++-----------------
 1 file changed, 34 insertions(+), 37 deletions(-)

diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py
index 7fefc79e..ab9ab9c0 100644
--- a/mmengine/model/averaged_model.py
+++ b/mmengine/model/averaged_model.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import itertools
 from abc import abstractmethod
 from copy import deepcopy
 from typing import Optional
@@ -20,7 +19,10 @@ class BaseAveragedModel(nn.Module):
     This class creates a copy of the provided module :attr:`model`
     on the device :attr:`device` and allows computing running averages of the
     parameters of the :attr:`model`.
-    The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py
+    The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.
+    Unlike the `AveragedModel` in PyTorch, we update the parameters with
+    in-place operations, which makes the update about 5 times faster than
+    the out-of-place version.
 
     In mmengine, we provide two ways to use the model averaging:
     1. Use the model averaging module in hook:
@@ -51,19 +53,23 @@ class BaseAveragedModel(nn.Module):
                  device: Optional[torch.device] = None,
                  update_buffers: bool = False) -> None:
         super().__init__()
-        self.module = deepcopy(model)
+        self.module = deepcopy(model).requires_grad_(False)
         self.interval = interval
         if device is not None:
             self.module = self.module.to(device)
         self.register_buffer('steps',
                              torch.tensor(0, dtype=torch.long, device=device))
         self.update_buffers = update_buffers
+        if update_buffers:
+            self.avg_parameters = self.module.state_dict()
+        else:
+            self.avg_parameters = dict(self.module.named_parameters())
 
     @abstractmethod
     def avg_func(self, averaged_param: Tensor, source_param: Tensor,
-                 steps: int) -> Tensor:
-        """Compute the average of the parameters. All subclasses must implement
-        this method.
+                 steps: int) -> None:
+        """Use in-place operation to compute the average of the parameters. All
+        subclasses must implement this method.
 
         Args:
             averaged_param (Tensor): The averaged parameters.
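
Note: both branches above cache name-keyed references to live tensors.
`named_parameters()` yields the parameter objects themselves, and
`state_dict()` returns detached tensors that still share storage, so
in-place writes through `self.avg_parameters` update `self.module`
directly. A toy check of this aliasing (not from the patch):

    import torch
    import torch.nn as nn

    m = nn.Linear(2, 2)
    cached = dict(m.named_parameters())
    cached['weight'].data.fill_(0.5)
    assert torch.all(m.weight == 0.5)  # the module sees the write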
@@ -84,23 +90,19 @@ class BaseAveragedModel(nn.Module):
         Args:
             model (nn.Module): The model whose parameters will be averaged.
         """
-        if self.steps % self.interval == 0:
-            avg_param = (
-                itertools.chain(self.module.parameters(),
-                                self.module.buffers())
-                if self.update_buffers else self.parameters())
-            src_param = (
-                itertools.chain(model.parameters(), model.buffers())
-                if self.update_buffers else model.parameters())
-            for p_avg, p_src in zip(avg_param, src_param):
-                device = p_avg.device
-                p_src_ = p_src.detach().to(device)
-                if self.steps == 0:
-                    p_avg.detach().copy_(p_src_)
-                else:
-                    p_avg.detach().copy_(
-                        self.avg_func(p_avg.detach(), p_src_,
-                                      self.steps.to(device)))
+        src_parameters = (
+            model.state_dict()
+            if self.update_buffers else dict(model.named_parameters()))
+        if self.steps == 0:
+            for k, p_avg in self.avg_parameters.items():
+                p_avg.data.copy_(src_parameters[k].data)
+        elif self.steps % self.interval == 0:
+            for k, p_avg in self.avg_parameters.items():
+                if p_avg.dtype.is_floating_point:
+                    device = p_avg.device
+                    self.avg_func(p_avg.data,
+                                  src_parameters[k].data.to(device),
+                                  self.steps)
         self.steps += 1
 
 
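A hypothetical usage sketch of the rewritten loop, assuming a concrete
subclass such as `ExponentialMovingAverage` below and that it forwards
`interval` to the base class: the first call copies every source tensor,
and later calls average only the floating-point entries whenever the
internal step counter is a multiple of `interval`.

    import torch.nn as nn
    from mmengine.model import ExponentialMovingAverage  # assumed export path

    model = nn.Linear(4, 4)
    ema = ExponentialMovingAverage(model, momentum=0.002, interval=2)
    for step in range(6):
        # ... an optimizer step would update `model` here ...
        ema.update_parameters(model)  # copy on the 1st call, then average
                                      # when steps % interval == 0
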
@@ -115,7 +117,7 @@ class StochasticWeightAverage(BaseAveragedModel):
     """
 
     def avg_func(self, averaged_param: Tensor, source_param: Tensor,
-                 steps: int) -> Tensor:
+                 steps: int) -> None:
         """Compute the average of the parameters using stochastic weight
         averaging.
 
@@ -124,11 +126,10 @@ class StochasticWeightAverage(BaseAveragedModel):
             source_param (Tensor): The source parameters.
             steps (int): The number of times the parameters have been
                 updated.
-        Returns:
-            Tensor: The averaged parameters.
         """
-        return averaged_param + (source_param - averaged_param) / (
-            steps // self.interval + 1)
+        averaged_param.add_(
+            source_param - averaged_param,
+            alpha=1 / (steps // self.interval + 1))
 
 
 @MODELS.register_module()
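
The in-place form above is the standard incremental running mean: with
n = steps // interval, it computes avg <- avg + (src - avg) / (n + 1),
so after n updates `avg` equals the mean of the n + 1 tensors seen so
far. A tiny numeric check of the recurrence (plain tensors, not the
actual parameters):

    import torch

    xs = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(9.0)]
    avg = xs[0].clone()                      # step 0: plain copy
    for n, x in enumerate(xs[1:], start=1):
        avg.add_(x - avg, alpha=1 / (n + 1))
    assert torch.isclose(avg, torch.tensor(5.0))  # (2 + 4 + 9) / 3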
@@ -167,7 +168,7 @@ class ExponentialMovingAverage(BaseAveragedModel):
         self.momentum = momentum
 
     def avg_func(self, averaged_param: Tensor, source_param: Tensor,
-                 steps: int) -> Tensor:
+                 steps: int) -> None:
         """Compute the moving average of the parameters using exponential
         moving average.
 
@@ -176,11 +177,9 @@ class ExponentialMovingAverage(BaseAveragedModel):
             source_param (Tensor): The source parameters.
             steps (int): The number of times the parameters have been
                 updated.
-        Returns:
-            Tensor: The averaged parameters.
         """
-        return averaged_param * (1 -
-                                 self.momentum) + source_param * self.momentum
+        averaged_param.mul_(1 - self.momentum).add_(
+            source_param, alpha=self.momentum)
 
 
 @MODELS.register_module()
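
The chained `mul_`/`add_` is algebraically identical to the expression
it replaces, avg * (1 - momentum) + src * momentum, just without the
temporaries. A quick equivalence check on standalone tensors:

    import torch

    avg, src, momentum = torch.randn(8), torch.randn(8), 0.1
    expected = avg * (1 - momentum) + src * momentum  # old form
    avg.mul_(1 - momentum).add_(src, alpha=momentum)  # new in-place form
    assert torch.allclose(avg, expected)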
@@ -222,7 +221,7 @@ class MomentumAnnealingEMA(ExponentialMovingAverage):
         self.gamma = gamma
 
     def avg_func(self, averaged_param: Tensor, source_param: Tensor,
-                 steps: int) -> Tensor:
+                 steps: int) -> None:
         """Compute the moving average of the parameters using the linear
         momentum strategy.
 
@@ -231,8 +230,6 @@ class MomentumAnnealingEMA(ExponentialMovingAverage):
             source_param (Tensor): The source parameters.
             steps (int): The number of times the parameters have been
                 updated.
-        Returns:
-            Tensor: The averaged parameters.
         """
         momentum = max(self.momentum, self.gamma / (self.gamma + self.steps))
-        return averaged_param * (1 - momentum) + source_param * momentum
+        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)
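
For intuition on the annealing above: the effective momentum starts at
1.0 (the averaged model simply tracks the source) and decays as
gamma / (gamma + steps) until it reaches the configured floor
`self.momentum`. With hypothetical values gamma = 100 and a floor of
0.001:

    gamma, floor = 100, 0.001
    for steps in (0, 100, 900, 99900):
        print(steps, max(floor, gamma / (gamma + steps)))
    # 0 -> 1.0, 100 -> 0.5, 900 -> 0.1, 99900 -> 0.001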
-- 
GitLab