From 7154df26183b737aee33d0570c7b414d52e78e15 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Wed, 22 Jun 2022 19:35:52 +0800
Subject: [PATCH] [Enhance] LogProcessor supports custom significant digits
 (#311)

* LogProcessor supports custom significant digits

* rename to num_digits
---
 mmengine/logging/log_processor.py        | 19 ++++++++++++-------
 tests/test_logging/test_log_processor.py | 12 ++++++------
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/mmengine/logging/log_processor.py b/mmengine/logging/log_processor.py
index 23c0c0e4..4b743668 100644
--- a/mmengine/logging/log_processor.py
+++ b/mmengine/logging/log_processor.py
@@ -47,6 +47,8 @@ class LogProcessor:
             - For those statistic methods with the ``window_size`` argument,
               if ``by_epoch`` is set to False, ``windows_size`` should not be
               `epoch` to statistics log value by epoch.
+        num_digits (int): The number of significant digits shown in the
+            logging message.
 
     Examples:
         >>> # `log_name` is defined, `loss_large_window` will be an additional
@@ -92,10 +94,12 @@ class LogProcessor:
     def __init__(self,
                  window_size=10,
                  by_epoch=True,
-                 custom_cfg: Optional[List[dict]] = None):
+                 custom_cfg: Optional[List[dict]] = None,
+                 num_digits: int = 4):
         self.window_size = window_size
         self.by_epoch = by_epoch
         self.custom_cfg = custom_cfg if custom_cfg else []
+        self.num_digits = num_digits
         self._check_custom_cfg()
 
     def get_log_after_iter(self, runner, batch_idx: int,
@@ -124,9 +128,9 @@ class LogProcessor:
         # Record learning rate.
         lr_str_list = []
         for key, value in tag.items():
-            if key.startswith('lr'):
+            if key.endswith('lr'):
                 log_tag.pop(key)
-                lr_str_list.append(f'{key}: {value:.3e}')
+                lr_str_list.append(f'{key}: ' f'{value:.{self.num_digits}e}')
         lr_str = ' '.join(lr_str_list)
         # Format log header.
         # by_epoch == True
@@ -159,8 +163,9 @@ class LogProcessor:
             eta = runner.message_hub.get_info('eta')
             eta_str = str(datetime.timedelta(seconds=int(eta)))
             log_str += f'eta: {eta_str}  '
-            log_str += (f'time: {tag["time"]:.3f}  '
-                        f'data_time: {tag["data_time"]:.3f}  ')
+            log_str += (f'time: {tag["time"]:.{self.num_digits}f}  '
+                        f'data_time: '
+                        f'{tag["data_time"]:.{self.num_digits}f}  ')
             # Pop recorded keys
             log_tag.pop('time')
             log_tag.pop('data_time')
@@ -175,7 +180,7 @@ class LogProcessor:
                 if mode == 'val' and not name.startswith('val/loss'):
                     continue
                 if isinstance(val, float):
-                    val = f'{val:.4f}'
+                    val = f'{val:.{self.num_digits}f}'
                 log_items.append(f'{name}: {val}')
             log_str += '  '.join(log_items)
         return tag, log_str
@@ -228,7 +233,7 @@ class LogProcessor:
         log_items = []
         for name, val in tag.items():
             if isinstance(val, float):
-                val = f'{val:.4f}'
+                val = f'{val:.{self.num_digits}f}'
             log_items.append(f'{name}: {val}')
         log_str += '  '.join(log_items)
         return tag, log_str
diff --git a/tests/test_logging/test_log_processor.py b/tests/test_logging/test_log_processor.py
index 288e23e4..454ecc29 100644
--- a/tests/test_logging/test_log_processor.py
+++ b/tests/test_logging/test_log_processor.py
@@ -96,13 +96,13 @@ class TestLogProcessor:
                 log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}]  ')
 
             if mode == 'train':
-                log_str += f"lr: {train_logs['lr']:.3e}  "
+                log_str += f"lr: {train_logs['lr']:.4e}  "
             else:
                 log_str += '  '
 
             log_str += (f'eta: 0:00:40  '
-                        f"time: {train_logs['time']:.3f}  "
-                        f"data_time: {train_logs['data_time']:.3f}  ")
+                        f"time: {train_logs['time']:.4f}  "
+                        f"data_time: {train_logs['data_time']:.4f}  ")
 
             if torch.cuda.is_available():
                 log_str += 'memory: 100  '
@@ -118,13 +118,13 @@ class TestLogProcessor:
                 log_str = f'Iter({mode}) [2/{max_iters}]  '
 
             if mode == 'train':
-                log_str += f"lr: {train_logs['lr']:.3e}  "
+                log_str += f"lr: {train_logs['lr']:.4e}  "
             else:
                 log_str += '  '
 
             log_str += (f'eta: 0:00:40  '
-                        f"time: {train_logs['time']:.3f}  "
-                        f"data_time: {train_logs['data_time']:.3f}  ")
+                        f"time: {train_logs['time']:.4f}  "
+                        f"data_time: {train_logs['data_time']:.4f}  ")
 
             if torch.cuda.is_available():
                 log_str += 'memory: 100  '
-- 
GitLab