Skip to content
Snippets Groups Projects
Unverified Commit 605e78be authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

Haochenye/fix logging (#167)

* remove LoggerHook master_only

* remove \t in log string

* fix lint

* Fix lint
parent 6e4bcc99
No related branches found
No related tags found
No related merge requests found
...@@ -9,7 +9,6 @@ from typing import Optional, Sequence, Union ...@@ -9,7 +9,6 @@ from typing import Optional, Sequence, Union
import torch import torch
from mmengine.dist import master_only
from mmengine.fileio import FileClient from mmengine.fileio import FileClient
from mmengine.hooks import Hook from mmengine.hooks import Hook
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
...@@ -239,7 +238,6 @@ class LoggerHook(Hook): ...@@ -239,7 +238,6 @@ class LoggerHook(Hook):
runner.logger.info((f'{local_filepath} was removed due to the ' runner.logger.info((f'{local_filepath} was removed due to the '
'`self.keep_local=False`')) '`self.keep_local=False`'))
@master_only
def _log_train(self, runner) -> None: def _log_train(self, runner) -> None:
"""Collect and record training logs which start named with "train/*". """Collect and record training logs which start named with "train/*".
...@@ -272,9 +270,9 @@ class LoggerHook(Hook): ...@@ -272,9 +270,9 @@ class LoggerHook(Hook):
# by iter: Iter [100/100000] # by iter: Iter [100/100000]
if self.by_epoch: if self.by_epoch:
log_str = f'Epoch [{cur_epoch}]' \ log_str = f'Epoch [{cur_epoch}]' \
f'[{cur_iter}/{len(runner.train_loop.dataloader)}]\t' f'[{cur_iter}/{len(runner.train_loop.dataloader)}] '
else: else:
log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t' log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}] '
log_str += f'{lr_momentum_str}, ' log_str += f'{lr_momentum_str}, '
# Calculate eta time. # Calculate eta time.
self.time_sec_tot += (tag['time'] * self.interval) self.time_sec_tot += (tag['time'] * self.interval)
...@@ -303,7 +301,6 @@ class LoggerHook(Hook): ...@@ -303,7 +301,6 @@ class LoggerHook(Hook):
runner.writer.add_scalars( runner.writer.add_scalars(
tag, step=runner.iter + 1, file_path=self.json_log_path) tag, step=runner.iter + 1, file_path=self.json_log_path)
@master_only
def _log_val(self, runner) -> None: def _log_val(self, runner) -> None:
"""Collect and record training logs which start named with "val/*". """Collect and record training logs which start named with "val/*".
...@@ -321,9 +318,9 @@ class LoggerHook(Hook): ...@@ -321,9 +318,9 @@ class LoggerHook(Hook):
# by iter: Iter[val] [1000] # by iter: Iter[val] [1000]
if self.by_epoch: if self.by_epoch:
# runner.epoch += 1 has been done before val workflow # runner.epoch += 1 has been done before val workflow
log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t' log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}] '
else: else:
log_str = f'Iter(val) [{eval_iter}]\t' log_str = f'Iter(val) [{eval_iter}] '
log_items = [] log_items = []
for name, val in tag.items(): for name, val in tag.items():
......
...@@ -10,6 +10,7 @@ import numpy as np ...@@ -10,6 +10,7 @@ import numpy as np
import torch import torch
from mmengine.data import BaseDataElement from mmengine.data import BaseDataElement
from mmengine.dist import master_only
from mmengine.fileio import dump from mmengine.fileio import dump
from mmengine.registry import VISUALIZERS, WRITERS from mmengine.registry import VISUALIZERS, WRITERS
from mmengine.utils import TORCH_VERSION, ManagerMixin from mmengine.utils import TORCH_VERSION, ManagerMixin
...@@ -796,6 +797,7 @@ class ComposedWriter(ManagerMixin): ...@@ -796,6 +797,7 @@ class ComposedWriter(ManagerMixin):
for writer in self._writers: for writer in self._writers:
writer.add_scalar(name, value, step, **kwargs) writer.add_scalar(name, value, step, **kwargs)
@master_only
def add_scalars(self, def add_scalars(self,
scalar_dict: dict, scalar_dict: dict,
step: int = 0, step: int = 0,
......
...@@ -161,7 +161,7 @@ class TestLoggerHook: ...@@ -161,7 +161,7 @@ class TestLoggerHook:
eta_str = str(datetime.timedelta(seconds=int(eta_second))) eta_str = str(datetime.timedelta(seconds=int(eta_second)))
if by_epoch: if by_epoch:
if torch.cuda.is_available(): if torch.cuda.is_available():
log_str = 'Epoch [2][2/5]\t' \ log_str = 'Epoch [2][2/5] ' \
f"lr: {train_infos['lr']:.3e} " \ f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \ f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \ f'eta: {eta_str}, ' \
...@@ -170,7 +170,7 @@ class TestLoggerHook: ...@@ -170,7 +170,7 @@ class TestLoggerHook:
f'memory: 100, ' \ f'memory: 100, ' \
f"loss_cls: {train_infos['loss_cls']:.4f}\n" f"loss_cls: {train_infos['loss_cls']:.4f}\n"
else: else:
log_str = 'Epoch [2][2/5]\t' \ log_str = 'Epoch [2][2/5] ' \
f"lr: {train_infos['lr']:.3e} " \ f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \ f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \ f'eta: {eta_str}, ' \
...@@ -180,7 +180,7 @@ class TestLoggerHook: ...@@ -180,7 +180,7 @@ class TestLoggerHook:
assert out == log_str assert out == log_str
else: else:
if torch.cuda.is_available(): if torch.cuda.is_available():
log_str = 'Iter [11/50]\t' \ log_str = 'Iter [11/50] ' \
f"lr: {train_infos['lr']:.3e} " \ f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \ f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \ f'eta: {eta_str}, ' \
...@@ -189,7 +189,7 @@ class TestLoggerHook: ...@@ -189,7 +189,7 @@ class TestLoggerHook:
f'memory: 100, ' \ f'memory: 100, ' \
f"loss_cls: {train_infos['loss_cls']:.4f}\n" f"loss_cls: {train_infos['loss_cls']:.4f}\n"
else: else:
log_str = 'Iter [11/50]\t' \ log_str = 'Iter [11/50] ' \
f"lr: {train_infos['lr']:.3e} " \ f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \ f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \ f'eta: {eta_str}, ' \
...@@ -212,11 +212,11 @@ class TestLoggerHook: ...@@ -212,11 +212,11 @@ class TestLoggerHook:
runner.writer.add_scalars.assert_called_with( runner.writer.add_scalars.assert_called_with(
metric, step=11, file_path='tmp.json') metric, step=11, file_path='tmp.json')
if by_epoch: if by_epoch:
assert out == 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \ assert out == 'Epoch(val) [1][5] accuracy: 0.9000, ' \
'data_time: 1.0000\n' 'data_time: 1.0000\n'
else: else:
assert out == 'Iter(val) [5]\taccuracy: 0.9000, ' \ assert out == 'Iter(val) [5] accuracy: 0.9000, ' \
'data_time: 1.0000\n' 'data_time: 1.0000\n'
def test_get_window_size(self): def test_get_window_size(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment