diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index 631736e8e01161e9201ac687208c68000ca21b93..90faad7741c09e27fc8f21d0f9045166255bf4ce 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -299,17 +299,16 @@ class TestLoggerHook: assert list(tag.keys()) == ['metric'] log_buffers['val/metric'].current.assert_called() - @patch('torch.distributed.reduce', MagicMock()) + @patch('torch.cuda.max_memory_allocated', MagicMock()) + @patch('torch.cuda.reset_peak_memory_stats', MagicMock()) def test_get_max_memory(self): logger_hook = LoggerHook() runner = MagicMock() runner.world_size = 1 runner.model = torch.nn.Linear(1, 1) logger_hook._get_max_memory(runner) - torch.distributed.reduce.assert_not_called() - runner.world_size = 2 - logger_hook._get_max_memory(runner) - torch.distributed.reduce.assert_called() + torch.cuda.max_memory_allocated.assert_called() + torch.cuda.reset_peak_memory_stats.assert_called() def test_get_iter(self): runner = self._setup_runner()