Skip to content
Snippets Groups Projects
Unverified Commit 2e4ae21f authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

Fix test_get_max_memory (#113)

parent c499d726
No related branches found
No related tags found
No related merge requests found
......@@ -299,17 +299,16 @@ class TestLoggerHook:
assert list(tag.keys()) == ['metric']
log_buffers['val/metric'].current.assert_called()
@patch('torch.distributed.reduce', MagicMock())
@patch('torch.cuda.max_memory_allocated', MagicMock())
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
def test_get_max_memory(self):
    """Check that ``_get_max_memory`` queries CUDA peak memory and only
    reduces the value across ranks when running distributed.

    All CUDA/distributed calls are patched out, so this runs without a GPU.
    """
    hook = LoggerHook()
    fake_runner = MagicMock()
    fake_runner.model = torch.nn.Linear(1, 1)

    # Single-process run: no cross-rank reduction should be attempted.
    fake_runner.world_size = 1
    hook._get_max_memory(fake_runner)
    torch.distributed.reduce.assert_not_called()

    # Multi-process run: the peak-memory value must be reduced across ranks.
    fake_runner.world_size = 2
    hook._get_max_memory(fake_runner)
    torch.distributed.reduce.assert_called()

    # Regardless of world size, the CUDA memory counters must be read
    # and then reset for the next measurement interval.
    torch.cuda.max_memory_allocated.assert_called()
    torch.cuda.reset_peak_memory_stats.assert_called()
def test_get_iter(self):
runner = self._setup_runner()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment