From 2e4ae21f2b812a4e91284a360f1526becaf47745 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Wed, 9 Mar 2022 14:05:25 +0800
Subject: [PATCH] Fix test_get_max_memory (#113)

---
 tests/test_hook/test_logger_hook.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py
index 631736e8..90faad77 100644
--- a/tests/test_hook/test_logger_hook.py
+++ b/tests/test_hook/test_logger_hook.py
@@ -299,17 +299,16 @@ class TestLoggerHook:
         assert list(tag.keys()) == ['metric']
         log_buffers['val/metric'].current.assert_called()
 
-    @patch('torch.distributed.reduce', MagicMock())
+    @patch('torch.cuda.max_memory_allocated', MagicMock())
+    @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
     def test_get_max_memory(self):
         logger_hook = LoggerHook()
         runner = MagicMock()
         runner.world_size = 1
         runner.model = torch.nn.Linear(1, 1)
         logger_hook._get_max_memory(runner)
-        torch.distributed.reduce.assert_not_called()
-        runner.world_size = 2
-        logger_hook._get_max_memory(runner)
-        torch.distributed.reduce.assert_called()
+        torch.cuda.max_memory_allocated.assert_called()
+        torch.cuda.reset_peak_memory_stats.assert_called()
 
     def test_get_iter(self):
         runner = self._setup_runner()
-- 
GitLab