diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py
index 631736e8e01161e9201ac687208c68000ca21b93..90faad7741c09e27fc8f21d0f9045166255bf4ce 100644
--- a/tests/test_hook/test_logger_hook.py
+++ b/tests/test_hook/test_logger_hook.py
@@ -299,17 +299,16 @@ class TestLoggerHook:
         assert list(tag.keys()) == ['metric']
         log_buffers['val/metric'].current.assert_called()
 
-    @patch('torch.distributed.reduce', MagicMock())
+    @patch('torch.cuda.max_memory_allocated', MagicMock())
+    @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
     def test_get_max_memory(self):
         logger_hook = LoggerHook()
         runner = MagicMock()
         runner.world_size = 1
         runner.model = torch.nn.Linear(1, 1)
         logger_hook._get_max_memory(runner)
-        torch.distributed.reduce.assert_not_called()
-        runner.world_size = 2
-        logger_hook._get_max_memory(runner)
-        torch.distributed.reduce.assert_called()
+        torch.cuda.max_memory_allocated.assert_called()
+        torch.cuda.reset_peak_memory_stats.assert_called()
 
     def test_get_iter(self):
         runner = self._setup_runner()