Commit 2fd6beb9 authored by Mashiro, committed by GitHub

[Fix] Fix UT of optimizer wrapper failed in pytorch1.6 (#340)

parent bbe00274
@@ -79,9 +79,9 @@ class TestOptimWrapper(MultiProcessTestCase):
         # Test update params every iteration.
         optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=1)
         self._mock_method(optim_wrapper)
-        loss = torch.tensor(1)
+        loss = torch.tensor(1.)
         optim_wrapper.update_params(loss)
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
         optim_wrapper.step.assert_called_with()
         optim_wrapper.zero_grad.assert_called_with()
@@ -89,15 +89,15 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
         self._mock_method(optim_wrapper)
         # `iter=0`, accumulate gradient and do not update params.
-        loss = torch.tensor(1)
+        loss = torch.tensor(1.)
         optim_wrapper.update_params(loss)
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3)
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3.)
         optim_wrapper.step.assert_not_called()
         optim_wrapper.zero_grad.assert_not_called()
         # gradient accumulate
         optim_wrapper.update_params(loss)
-        self.assertEqual(optim_wrapper._inner_count, 2)
+        self.assertEqual(optim_wrapper._inner_count, 2.)
         # `iter=2`, update params.
         optim_wrapper.update_params(loss)
@@ -110,7 +110,7 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim_wrapper.update_params(loss)
         optim_wrapper.step.assert_not_called()
         optim_wrapper.zero_grad.assert_not_called()
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1) / 3)
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3.)
         self._mock_method(optim_wrapper)
         # After calling `initialize_iter_status`, params will be updated at the
...
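The change replaces the integer scalar tensors in the test with float scalars. The commit message does not spell out the exact failure, but a plausible cause (an assumption, not stated in the commit) is that PyTorch 1.6 refused true division of integer tensors with the `/` operator, which the `scaled_loss` assertions above rely on. A minimal, hypothetical reproduction sketch:

import torch

# Float scalar, as used after this fix: `/` always performs true division.
float_loss = torch.tensor(1.)
print(float_loss / 3.)  # tensor(0.3333)

# Integer scalar, as used before the fix. On PyTorch 1.6, dividing an
# integer tensor with `/` raised a RuntimeError suggesting `true_divide`
# or `floor_divide`; on recent versions it simply returns a float tensor,
# which is why the old test only broke on 1.6.
int_loss = torch.tensor(1)
try:
    print(int_loss / 3)       # tensor(0.3333) on recent PyTorch
except RuntimeError as err:
    print(err)                # division error on PyTorch 1.6

Using float literals such as `torch.tensor(1.)` sidesteps the version-dependent integer-division behavior entirely, so the assertions compare float tensors on every supported PyTorch release.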