Skip to content
Snippets Groups Projects
Unverified commit 090104df, authored by Z-Fran, committed by GitHub
Browse files

[Fix] Fix the calculation error of eta_min in CosineRestart (#639)


* [Fix] fix CosineRestart eta_min

* add ut case

* Enhance unit test

Enhance unit test

* remove unused code

Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
parent 64ac1430
No related branches found
No related tags found
No related merge requests found
......@@ -1224,16 +1224,11 @@ class CosineRestartParamScheduler(_ParamScheduler):
self.optimizer.param_groups):
eta_max = base_value * current_weight
if self.eta_min_ratio is None:
eta_min = self.eta_min * (1 - current_weight)
eta_min = self.eta_min
else:
eta_min = base_value * self.eta_min_ratio * (1 -
current_weight)
eta_min = base_value * self.eta_min_ratio
if step == 0:
values.append(eta_max)
elif (step - 1 - current_periods) % (2 * current_periods) == 0:
values.append(group[self.param_name] + (eta_max - eta_min) *
(1 - math.cos(math.pi / current_periods)) / 2)
else:
values.append(
(1 + math.cos(math.pi * step / current_periods)) /
......
......@@ -430,6 +430,8 @@ class TestParameterScheduler(TestCase):
targets = [
single_targets, [t * self.layer2_mult for t in single_targets]
]
# Test with non-zero eta-min.
scheduler = CosineRestartParamScheduler(
self.optimizer,
param_name='lr',
......@@ -438,6 +440,26 @@ class TestParameterScheduler(TestCase):
eta_min=0)
self._test_scheduler_value(scheduler, targets, epochs=10)
epochs = 10
t = 10
eta_min = 5e-3
targets1 = [
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2
for x in range(epochs)
]
targets2 = [
eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2
for x in range(epochs)
]
targets = [targets1, targets2]
scheduler = CosineRestartParamScheduler(
self.optimizer,
param_name='lr',
periods=[t],
restart_weights=[1],
eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs=10)
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment