Skip to content
Snippets Groups Projects
Unverified Commit 4d49de7d authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix LoggerHook save multiple ranks scalar in the same json file. (#124)

* use master_only to decorate _log_train and _log_val

* fix resolved TODO

fix resolved TODO

fix resolved TODO

* fix typo in raised error message

* ensure log item is a Python scalar
parent a7961407
No related branches found
No related tags found
No related merge requests found
...@@ -10,6 +10,7 @@ from typing import Any, Optional, Sequence, Tuple, Union ...@@ -10,6 +10,7 @@ from typing import Any, Optional, Sequence, Tuple, Union
import torch import torch
from mmengine.data import BaseDataSample from mmengine.data import BaseDataSample
from mmengine.dist import master_only
from mmengine.fileio import FileClient from mmengine.fileio import FileClient
from mmengine.hooks import Hook from mmengine.hooks import Hook
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
...@@ -232,6 +233,7 @@ class LoggerHook(Hook): ...@@ -232,6 +233,7 @@ class LoggerHook(Hook):
runner.logger.info((f'{local_filepath} was removed due to the ' runner.logger.info((f'{local_filepath} was removed due to the '
'`self.keep_local=False`')) '`self.keep_local=False`'))
@master_only
def _log_train(self, runner) -> None: def _log_train(self, runner) -> None:
"""Collect and record training logs which start named with "train/*". """Collect and record training logs which start named with "train/*".
...@@ -295,6 +297,7 @@ class LoggerHook(Hook): ...@@ -295,6 +297,7 @@ class LoggerHook(Hook):
runner.writer.add_scalars( runner.writer.add_scalars(
tag, step=runner.iter + 1, file_path=self.json_log_path) tag, step=runner.iter + 1, file_path=self.json_log_path)
@master_only
def _log_val(self, runner) -> None: def _log_val(self, runner) -> None:
"""Collect and record training logs which start named with "val/*". """Collect and record training logs which start named with "val/*".
...@@ -402,7 +405,7 @@ class LoggerHook(Hook): ...@@ -402,7 +405,7 @@ class LoggerHook(Hook):
for cfg in log_cfg: for cfg in log_cfg:
log_name = cfg.get('log_name', None) log_name = cfg.get('log_name', None)
if log_name in log_names: if log_name in log_names:
raise KeyError(f'{cfg["log_name"]} cannot be Redefined in ' raise KeyError(f'{cfg["log_name"]} cannot be redefined in '
'log_key') 'log_key')
if log_name is not None: if log_name is not None:
log_names.add(log_name) log_names.add(log_name)
...@@ -418,7 +421,7 @@ class LoggerHook(Hook): ...@@ -418,7 +421,7 @@ class LoggerHook(Hook):
name = log_cfg.pop('log_name') name = log_cfg.pop('log_name')
else: else:
name = log_key name = log_key
tag[name] = log_buffers[log_key].statistics(**log_cfg) tag[name] = log_buffers[log_key].statistics(**log_cfg).item()
else: else:
raise ValueError('The structure of `LoggerHook.custom key` is ' raise ValueError('The structure of `LoggerHook.custom key` is '
'wrong, please make sure the type of each key is ' 'wrong, please make sure the type of each key is '
...@@ -435,7 +438,6 @@ class LoggerHook(Hook): ...@@ -435,7 +438,6 @@ class LoggerHook(Hook):
The maximum GPU memory occupied by tensors in megabytes for a given The maximum GPU memory occupied by tensors in megabytes for a given
device. device.
""" """
# TODO use `mmengine.dist.max_memory_allocated` to count mem_mb
device = getattr(runner.model, 'output_device', None) device = getattr(runner.model, 'output_device', None)
mem = torch.cuda.max_memory_allocated(device=device) mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)], mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment