From 3ba50ce1375067145b03ee5bb7d26766761e1019 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Fri, 9 Sep 2022 11:44:58 +0800
Subject: [PATCH] [Enhance] Use torch.lerp_() to speed up EMA. (#519)

---
 mmengine/model/averaged_model.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py
index be1abc30..4326be1d 100644
--- a/mmengine/model/averaged_model.py
+++ b/mmengine/model/averaged_model.py
@@ -187,8 +187,7 @@ class ExponentialMovingAverage(BaseAveragedModel):
             steps (int): The number of times the parameters have been
                 updated.
         """
-        averaged_param.mul_(1 - self.momentum).add_(
-            source_param, alpha=self.momentum)
+        averaged_param.lerp_(source_param, self.momentum)
 
 
 @MODELS.register_module()
@@ -242,4 +241,4 @@ class MomentumAnnealingEMA(ExponentialMovingAverage):
                 updated.
         """
         momentum = max(self.momentum, self.gamma / (self.gamma + self.steps))
-        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)
+        averaged_param.lerp_(source_param, momentum)
-- 
GitLab