Skip to content
Snippets Groups Projects
Unverified Commit 3ba50ce1 authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

[Enhance] Use torch.lerp_() to speed up EMA. (#519)

parent 9bbbd7dc
No related branches found
No related tags found
No related merge requests found
...@@ -187,8 +187,7 @@ class ExponentialMovingAverage(BaseAveragedModel): ...@@ -187,8 +187,7 @@ class ExponentialMovingAverage(BaseAveragedModel):
steps (int): The number of times the parameters have been steps (int): The number of times the parameters have been
updated. updated.
""" """
averaged_param.mul_(1 - self.momentum).add_( averaged_param.lerp_(source_param, self.momentum)
source_param, alpha=self.momentum)
@MODELS.register_module() @MODELS.register_module()
...@@ -242,4 +241,4 @@ class MomentumAnnealingEMA(ExponentialMovingAverage): ...@@ -242,4 +241,4 @@ class MomentumAnnealingEMA(ExponentialMovingAverage):
updated. updated.
""" """
momentum = max(self.momentum, self.gamma / (self.gamma + self.steps)) momentum = max(self.momentum, self.gamma / (self.gamma + self.steps))
averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum) averaged_param.lerp_(source_param, momentum)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment