From 4d49de7d814c6e256d0873dfe10bcc1ece5b2a0f Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Sun, 13 Mar 2022 16:56:29 +0800
Subject: [PATCH] [Fix] Fix LoggerHook saving scalars from multiple ranks in
 the same json file. (#124)

* use master_only to decorate _log_train and _log_val
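
  Only the master process should write scalars to the json log. As a
  rough sketch (the real `master_only` helper in mmengine.dist may
  resolve the rank differently), the decorator behaves like:

      import functools

      import torch.distributed as dist

      def master_only(func):
          """Call `func` on the master process (rank 0) only."""

          @functools.wraps(func)
          def wrapper(*args, **kwargs):
              # Without an initialized process group there is only a
              # single process, which counts as the master.
              rank = dist.get_rank() if dist.is_initialized() else 0
              if rank == 0:
                  return func(*args, **kwargs)

          return wrapper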

* remove resolved TODO
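
  The removed TODO concerned GPU memory accounting. For reference, the
  existing logic boils down to the following sketch (`max_memory_mb` is
  a hypothetical name; the logic mirrors the last hunk below):

      import torch

      def max_memory_mb(device=None):
          # Peak memory occupied by tensors on `device`, in megabytes.
          mem = torch.cuda.max_memory_allocated(device=device)
          return int(mem) // (1024 * 1024)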

* fix typo in raised error message

* ensure log item is python scalar
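
  A zero-dim torch.Tensor is not json-serializable, whereas `.item()`
  returns a plain Python scalar. A minimal illustration (assuming the
  statistic comes back as a tensor):

      import json

      import torch

      stat = torch.tensor(0.25)  # e.g. a mean returned as a tensor
      try:
          json.dumps({'train/loss': stat})
      except TypeError:
          pass  # Tensor objects cannot be dumped to json
      print(json.dumps({'train/loss': stat.item()}))
      # -> {"train/loss": 0.25}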
---
 mmengine/hooks/logger_hook.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index 9e2518bf..64c7868e 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -10,6 +10,7 @@ from typing import Any, Optional, Sequence, Tuple, Union
 import torch
 
 from mmengine.data import BaseDataSample
+from mmengine.dist import master_only
 from mmengine.fileio import FileClient
 from mmengine.hooks import Hook
 from mmengine.registry import HOOKS
@@ -232,6 +233,7 @@ class LoggerHook(Hook):
                 runner.logger.info((f'{local_filepath} was removed due to the '
                                     '`self.keep_local=False`'))
 
+    @master_only
     def _log_train(self, runner) -> None:
         """Collect and record training logs which start named with "train/*".
 
@@ -295,6 +297,7 @@ class LoggerHook(Hook):
         runner.writer.add_scalars(
             tag, step=runner.iter + 1, file_path=self.json_log_path)
 
+    @master_only
     def _log_val(self, runner) -> None:
         """Collect and record training logs which start named with "val/*".
 
@@ -402,7 +405,7 @@ class LoggerHook(Hook):
             for cfg in log_cfg:
                 log_name = cfg.get('log_name', None)
                 if log_name in log_names:
-                    raise KeyError(f'{cfg["log_name"]} cannot be Redefined in '
+                    raise KeyError(f'{cfg["log_name"]} cannot be redefined in '
                                    'log_key')
                 if log_name is not None:
                     log_names.add(log_name)
@@ -418,7 +421,7 @@ class LoggerHook(Hook):
                 name = log_cfg.pop('log_name')
             else:
                 name = log_key
-            tag[name] = log_buffers[log_key].statistics(**log_cfg)
+            tag[name] = log_buffers[log_key].statistics(**log_cfg).item()
         else:
             raise ValueError('The structure of `LoggerHook.custom key` is '
                              'wrong, please make sure the type of each key is '
@@ -435,7 +438,6 @@ class LoggerHook(Hook):
             The maximum GPU memory occupied by tensors in megabytes for a given
             device.
         """
-        # TODO use `mmengine.dist.max_memory_allocated` to count mem_mb
         device = getattr(runner.model, 'output_device', None)
         mem = torch.cuda.max_memory_allocated(device=device)
         mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
-- 
GitLab