From 3ba50ce1375067145b03ee5bb7d26766761e1019 Mon Sep 17 00:00:00 2001 From: RangiLyu <lyuchqi@gmail.com> Date: Fri, 9 Sep 2022 11:44:58 +0800 Subject: [PATCH] [Enhance] Use torch.lerp_() to speed up EMA. (#519) --- mmengine/model/averaged_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index be1abc30..4326be1d 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -187,8 +187,7 @@ class ExponentialMovingAverage(BaseAveragedModel): steps (int): The number of times the parameters have been updated. """ - averaged_param.mul_(1 - self.momentum).add_( - source_param, alpha=self.momentum) + averaged_param.lerp_(source_param, self.momentum) @MODELS.register_module() @@ -242,4 +241,4 @@ class MomentumAnnealingEMA(ExponentialMovingAverage): updated. """ momentum = max(self.momentum, self.gamma / (self.gamma + self.steps)) - averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum) + averaged_param.lerp_(source_param, momentum) -- GitLab