Skip to content
Snippets Groups Projects
Unverified Commit 7154df26 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhance] LogProcessor support custom significant digit (#311)

* LogProcessor support custom significant digit

* rename to num_digits
parent 2086bc45
No related branches found
No related tags found
No related merge requests found
......@@ -47,6 +47,8 @@ class LogProcessor:
- For those statistic methods with the ``window_size`` argument,
if ``by_epoch`` is set to False, ``windows_size`` should not be
`epoch` to statistics log value by epoch.
num_digits (int): The number of significant digit shown in the
logging message.
Examples:
>>> # `log_name` is defined, `loss_large_window` will be an additional
......@@ -92,10 +94,12 @@ class LogProcessor:
def __init__(self,
window_size=10,
by_epoch=True,
custom_cfg: Optional[List[dict]] = None):
custom_cfg: Optional[List[dict]] = None,
num_digits: int = 4):
self.window_size = window_size
self.by_epoch = by_epoch
self.custom_cfg = custom_cfg if custom_cfg else []
self.num_digits = num_digits
self._check_custom_cfg()
def get_log_after_iter(self, runner, batch_idx: int,
......@@ -124,9 +128,9 @@ class LogProcessor:
# Record learning rate.
lr_str_list = []
for key, value in tag.items():
if key.startswith('lr'):
if key.endswith('lr'):
log_tag.pop(key)
lr_str_list.append(f'{key}: {value:.3e}')
lr_str_list.append(f'{key}: ' f'{value:.{self.num_digits}e}')
lr_str = ' '.join(lr_str_list)
# Format log header.
# by_epoch == True
......@@ -159,8 +163,9 @@ class LogProcessor:
eta = runner.message_hub.get_info('eta')
eta_str = str(datetime.timedelta(seconds=int(eta)))
log_str += f'eta: {eta_str} '
log_str += (f'time: {tag["time"]:.3f} '
f'data_time: {tag["data_time"]:.3f} ')
log_str += (f'time: {tag["time"]:.{self.num_digits}f} '
f'data_time: '
f'{tag["data_time"]:.{self.num_digits}f} ')
# Pop recorded keys
log_tag.pop('time')
log_tag.pop('data_time')
......@@ -175,7 +180,7 @@ class LogProcessor:
if mode == 'val' and not name.startswith('val/loss'):
continue
if isinstance(val, float):
val = f'{val:.4f}'
val = f'{val:.{self.num_digits}f}'
log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items)
return tag, log_str
......@@ -228,7 +233,7 @@ class LogProcessor:
log_items = []
for name, val in tag.items():
if isinstance(val, float):
val = f'{val:.4f}'
val = f'{val:.{self.num_digits}f}'
log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items)
return tag, log_str
......
......@@ -96,13 +96,13 @@ class TestLogProcessor:
log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ')
if mode == 'train':
log_str += f"lr: {train_logs['lr']:.3e} "
log_str += f"lr: {train_logs['lr']:.4e} "
else:
log_str += ' '
log_str += (f'eta: 0:00:40 '
f"time: {train_logs['time']:.3f} "
f"data_time: {train_logs['data_time']:.3f} ")
f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.4f} ")
if torch.cuda.is_available():
log_str += 'memory: 100 '
......@@ -118,13 +118,13 @@ class TestLogProcessor:
log_str = f'Iter({mode}) [2/{max_iters}] '
if mode == 'train':
log_str += f"lr: {train_logs['lr']:.3e} "
log_str += f"lr: {train_logs['lr']:.4e} "
else:
log_str += ' '
log_str += (f'eta: 0:00:40 '
f"time: {train_logs['time']:.3f} "
f"data_time: {train_logs['data_time']:.3f} ")
f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.4f} ")
if torch.cuda.is_available():
log_str += 'memory: 100 '
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment