From 7154df26183b737aee33d0570c7b414d52e78e15 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 22 Jun 2022 19:35:52 +0800 Subject: [PATCH] [Enhance] LogProcessor support custom significant digit (#311) * LogProcessor support custom significant digit * rename to num_digits --- mmengine/logging/log_processor.py | 19 ++++++++++++------- tests/test_logging/test_log_processor.py | 12 ++++++------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/mmengine/logging/log_processor.py b/mmengine/logging/log_processor.py index 23c0c0e4..4b743668 100644 --- a/mmengine/logging/log_processor.py +++ b/mmengine/logging/log_processor.py @@ -47,6 +47,8 @@ class LogProcessor: - For those statistic methods with the ``window_size`` argument, if ``by_epoch`` is set to False, ``windows_size`` should not be `epoch` to statistics log value by epoch. + num_digits (int): The number of significant digit shown in the + logging message. Examples: >>> # `log_name` is defined, `loss_large_window` will be an additional @@ -92,10 +94,12 @@ class LogProcessor: def __init__(self, window_size=10, by_epoch=True, - custom_cfg: Optional[List[dict]] = None): + custom_cfg: Optional[List[dict]] = None, + num_digits: int = 4): self.window_size = window_size self.by_epoch = by_epoch self.custom_cfg = custom_cfg if custom_cfg else [] + self.num_digits = num_digits self._check_custom_cfg() def get_log_after_iter(self, runner, batch_idx: int, @@ -124,9 +128,9 @@ class LogProcessor: # Record learning rate. lr_str_list = [] for key, value in tag.items(): - if key.startswith('lr'): + if key.endswith('lr'): log_tag.pop(key) - lr_str_list.append(f'{key}: {value:.3e}') + lr_str_list.append(f'{key}: ' f'{value:.{self.num_digits}e}') lr_str = ' '.join(lr_str_list) # Format log header. # by_epoch == True @@ -159,8 +163,9 @@ class LogProcessor: eta = runner.message_hub.get_info('eta') eta_str = str(datetime.timedelta(seconds=int(eta))) log_str += f'eta: {eta_str} ' - log_str += (f'time: {tag["time"]:.3f} ' - f'data_time: {tag["data_time"]:.3f} ') + log_str += (f'time: {tag["time"]:.{self.num_digits}f} ' + f'data_time: ' + f'{tag["data_time"]:.{self.num_digits}f} ') # Pop recorded keys log_tag.pop('time') log_tag.pop('data_time') @@ -175,7 +180,7 @@ class LogProcessor: if mode == 'val' and not name.startswith('val/loss'): continue if isinstance(val, float): - val = f'{val:.4f}' + val = f'{val:.{self.num_digits}f}' log_items.append(f'{name}: {val}') log_str += ' '.join(log_items) return tag, log_str @@ -228,7 +233,7 @@ class LogProcessor: log_items = [] for name, val in tag.items(): if isinstance(val, float): - val = f'{val:.4f}' + val = f'{val:.{self.num_digits}f}' log_items.append(f'{name}: {val}') log_str += ' '.join(log_items) return tag, log_str diff --git a/tests/test_logging/test_log_processor.py b/tests/test_logging/test_log_processor.py index 288e23e4..454ecc29 100644 --- a/tests/test_logging/test_log_processor.py +++ b/tests/test_logging/test_log_processor.py @@ -96,13 +96,13 @@ class TestLogProcessor: log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ') if mode == 'train': - log_str += f"lr: {train_logs['lr']:.3e} " + log_str += f"lr: {train_logs['lr']:.4e} " else: log_str += ' ' log_str += (f'eta: 0:00:40 ' - f"time: {train_logs['time']:.3f} " - f"data_time: {train_logs['data_time']:.3f} ") + f"time: {train_logs['time']:.4f} " + f"data_time: {train_logs['data_time']:.4f} ") if torch.cuda.is_available(): log_str += 'memory: 100 ' @@ -118,13 +118,13 @@ class TestLogProcessor: log_str = f'Iter({mode}) [2/{max_iters}] ' if mode == 'train': - log_str += f"lr: {train_logs['lr']:.3e} " + log_str += f"lr: {train_logs['lr']:.4e} " else: log_str += ' ' log_str += (f'eta: 0:00:40 ' - f"time: {train_logs['time']:.3f} " - f"data_time: {train_logs['data_time']:.3f} ") + f"time: {train_logs['time']:.4f} " + f"data_time: {train_logs['data_time']:.4f} ") if torch.cuda.is_available(): log_str += 'memory: 100 ' -- GitLab