From 2e4ae21f2b812a4e91284a360f1526becaf47745 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 9 Mar 2022 14:05:25 +0800 Subject: [PATCH] Fix test_get_max_memory (#113) --- tests/test_hook/test_logger_hook.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index 631736e8..90faad77 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -299,17 +299,16 @@ class TestLoggerHook: assert list(tag.keys()) == ['metric'] log_buffers['val/metric'].current.assert_called() - @patch('torch.distributed.reduce', MagicMock()) + @patch('torch.cuda.max_memory_allocated', MagicMock()) + @patch('torch.cuda.reset_peak_memory_stats', MagicMock()) def test_get_max_memory(self): logger_hook = LoggerHook() runner = MagicMock() runner.world_size = 1 runner.model = torch.nn.Linear(1, 1) logger_hook._get_max_memory(runner) - torch.distributed.reduce.assert_not_called() - runner.world_size = 2 - logger_hook._get_max_memory(runner) - torch.distributed.reduce.assert_called() + torch.cuda.max_memory_allocated.assert_called() + torch.cuda.reset_peak_memory_stats.assert_called() def test_get_iter(self): runner = self._setup_runner() -- GitLab