diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index be1abc30bda2352f5aceee00f3ccc11e1cd5d2c2..4326be1d23f9635ffdd22463435eabc44efe24ed 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -187,8 +187,7 @@ class ExponentialMovingAverage(BaseAveragedModel): steps (int): The number of times the parameters have been updated. """ - averaged_param.mul_(1 - self.momentum).add_( - source_param, alpha=self.momentum) + averaged_param.lerp_(source_param, self.momentum) @MODELS.register_module() @@ -242,4 +241,4 @@ class MomentumAnnealingEMA(ExponentialMovingAverage): updated. """ momentum = max(self.momentum, self.gamma / (self.gamma + self.steps)) - averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum) + averaged_param.lerp_(source_param, momentum)