Skip to content
Snippets Groups Projects
Unverified Commit 7154df26 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhance] LogProcessor support custom significant digit (#311)

* LogProcessor support custom significant digit

* rename to num_digits
parent 2086bc45
No related branches found
No related tags found
No related merge requests found
...@@ -47,6 +47,8 @@ class LogProcessor: ...@@ -47,6 +47,8 @@ class LogProcessor:
- For those statistic methods with the ``window_size`` argument, - For those statistic methods with the ``window_size`` argument,
if ``by_epoch`` is set to False, ``windows_size`` should not be if ``by_epoch`` is set to False, ``windows_size`` should not be
`epoch` to statistics log value by epoch. `epoch` to statistics log value by epoch.
num_digits (int): The number of significant digit shown in the
logging message.
Examples: Examples:
>>> # `log_name` is defined, `loss_large_window` will be an additional >>> # `log_name` is defined, `loss_large_window` will be an additional
...@@ -92,10 +94,12 @@ class LogProcessor: ...@@ -92,10 +94,12 @@ class LogProcessor:
def __init__(self, def __init__(self,
window_size=10, window_size=10,
by_epoch=True, by_epoch=True,
custom_cfg: Optional[List[dict]] = None): custom_cfg: Optional[List[dict]] = None,
num_digits: int = 4):
self.window_size = window_size self.window_size = window_size
self.by_epoch = by_epoch self.by_epoch = by_epoch
self.custom_cfg = custom_cfg if custom_cfg else [] self.custom_cfg = custom_cfg if custom_cfg else []
self.num_digits = num_digits
self._check_custom_cfg() self._check_custom_cfg()
def get_log_after_iter(self, runner, batch_idx: int, def get_log_after_iter(self, runner, batch_idx: int,
...@@ -124,9 +128,9 @@ class LogProcessor: ...@@ -124,9 +128,9 @@ class LogProcessor:
# Record learning rate. # Record learning rate.
lr_str_list = [] lr_str_list = []
for key, value in tag.items(): for key, value in tag.items():
if key.startswith('lr'): if key.endswith('lr'):
log_tag.pop(key) log_tag.pop(key)
lr_str_list.append(f'{key}: {value:.3e}') lr_str_list.append(f'{key}: ' f'{value:.{self.num_digits}e}')
lr_str = ' '.join(lr_str_list) lr_str = ' '.join(lr_str_list)
# Format log header. # Format log header.
# by_epoch == True # by_epoch == True
...@@ -159,8 +163,9 @@ class LogProcessor: ...@@ -159,8 +163,9 @@ class LogProcessor:
eta = runner.message_hub.get_info('eta') eta = runner.message_hub.get_info('eta')
eta_str = str(datetime.timedelta(seconds=int(eta))) eta_str = str(datetime.timedelta(seconds=int(eta)))
log_str += f'eta: {eta_str} ' log_str += f'eta: {eta_str} '
log_str += (f'time: {tag["time"]:.3f} ' log_str += (f'time: {tag["time"]:.{self.num_digits}f} '
f'data_time: {tag["data_time"]:.3f} ') f'data_time: '
f'{tag["data_time"]:.{self.num_digits}f} ')
# Pop recorded keys # Pop recorded keys
log_tag.pop('time') log_tag.pop('time')
log_tag.pop('data_time') log_tag.pop('data_time')
...@@ -175,7 +180,7 @@ class LogProcessor: ...@@ -175,7 +180,7 @@ class LogProcessor:
if mode == 'val' and not name.startswith('val/loss'): if mode == 'val' and not name.startswith('val/loss'):
continue continue
if isinstance(val, float): if isinstance(val, float):
val = f'{val:.4f}' val = f'{val:.{self.num_digits}f}'
log_items.append(f'{name}: {val}') log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items) log_str += ' '.join(log_items)
return tag, log_str return tag, log_str
...@@ -228,7 +233,7 @@ class LogProcessor: ...@@ -228,7 +233,7 @@ class LogProcessor:
log_items = [] log_items = []
for name, val in tag.items(): for name, val in tag.items():
if isinstance(val, float): if isinstance(val, float):
val = f'{val:.4f}' val = f'{val:.{self.num_digits}f}'
log_items.append(f'{name}: {val}') log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items) log_str += ' '.join(log_items)
return tag, log_str return tag, log_str
......
...@@ -96,13 +96,13 @@ class TestLogProcessor: ...@@ -96,13 +96,13 @@ class TestLogProcessor:
log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ') log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ')
if mode == 'train': if mode == 'train':
log_str += f"lr: {train_logs['lr']:.3e} " log_str += f"lr: {train_logs['lr']:.4e} "
else: else:
log_str += ' ' log_str += ' '
log_str += (f'eta: 0:00:40 ' log_str += (f'eta: 0:00:40 '
f"time: {train_logs['time']:.3f} " f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.3f} ") f"data_time: {train_logs['data_time']:.4f} ")
if torch.cuda.is_available(): if torch.cuda.is_available():
log_str += 'memory: 100 ' log_str += 'memory: 100 '
...@@ -118,13 +118,13 @@ class TestLogProcessor: ...@@ -118,13 +118,13 @@ class TestLogProcessor:
log_str = f'Iter({mode}) [2/{max_iters}] ' log_str = f'Iter({mode}) [2/{max_iters}] '
if mode == 'train': if mode == 'train':
log_str += f"lr: {train_logs['lr']:.3e} " log_str += f"lr: {train_logs['lr']:.4e} "
else: else:
log_str += ' ' log_str += ' '
log_str += (f'eta: 0:00:40 ' log_str += (f'eta: 0:00:40 '
f"time: {train_logs['time']:.3f} " f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.3f} ") f"data_time: {train_logs['data_time']:.4f} ")
if torch.cuda.is_available(): if torch.cuda.is_available():
log_str += 'memory: 100 ' log_str += 'memory: 100 '
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment