From fff4742e0b37db781837b268d065cbec18c24a79 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Sun, 13 Mar 2022 21:28:47 +0800
Subject: [PATCH] [Enhancement] Messagehub supports update dict logs. (#120)

* Messagehub support update dict values

* Messagehub support update dict values

* fix assertion error message
---
 mmengine/logging/message_hub.py        | 54 ++++++++++++++++++++++++++
 tests/test_logging/test_message_hub.py | 26 +++++++++++++
 2 files changed, 80 insertions(+)

diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py
index 2b5087e8..a93d6038 100644
--- a/mmengine/logging/message_hub.py
+++ b/mmengine/logging/message_hub.py
@@ -3,6 +3,10 @@ import copy
 from collections import OrderedDict
 from typing import Any, Union
 
+import numpy as np
+import torch
+
+from mmengine.visualization.utils import check_type
 from .base_global_accsessible import BaseGlobalAccessible
 from .log_buffer import LogBuffer
 
@@ -41,6 +45,34 @@ class MessageHub(BaseGlobalAccessible):
         else:
             self._log_buffers[key] = LogBuffer([value], [count])
 
+    def update_log_vars(self, log_dict: dict) -> None:
+        """Update :attr:`_log_buffers` with a dict.
+
+        Args:
+            log_dict (str): Used for batch updating :attr:`_log_buffers`.
+
+        Examples:
+            >>> message_hub = MessageHub.create_instance()
+            >>> log_dict = dict(a=1, b=2, c=3)
+            >>> message_hub.update_log_vars(log_dict)
+            >>> # The default count of  `a`, `b` and `c` is 1.
+            >>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2))
+            >>> message_hub.update_log_vars(log_dict)
+            >>> # The count of `c` is 2.
+        """
+        assert isinstance(log_dict, dict), ('`log_dict` must be a dict!, '
+                                            f'but got {type(log_dict)}')
+        for log_name, log_val in log_dict.items():
+            if isinstance(log_val, dict):
+                assert 'value' in log_val, \
+                    f'value must be defined in {log_val}'
+                count = log_val.get('count', 1)
+                value = self._get_valid_value(log_name, log_val['value'])
+            else:
+                value = self._get_valid_value(log_name, log_val)
+                count = 1
+            self.update_log(log_name, value, count)
+
     def update_info(self, key: str, value: Any) -> None:
         """Update runtime information.
 
@@ -105,3 +137,25 @@ class MessageHub(BaseGlobalAccessible):
             raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
                            f'instance name is: {MessageHub.instance_name}')
         return copy.deepcopy(self._runtime_info[key])
+
+    def _get_valid_value(self, key: str,
+                         value: Union[torch.Tensor, np.ndarray, int, float])\
+            -> Union[int, float]:
+        """Convert value to python built-in type.
+
+        Args:
+            key (str): name of log.
+            value (torch.Tensor or np.ndarray or int or float): value of log.
+
+        Returns:
+            float or int: python built-in type value.
+        """
+        if isinstance(value, np.ndarray):
+            assert value.size == 1
+            value = value.item()
+        elif isinstance(value, torch.Tensor):
+            assert value.numel() == 1
+            value = value.item()
+        else:
+            check_type(key, value, (int, float))
+        return value
diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py
index 3b4b6452..c9ec25a1 100644
--- a/tests/test_logging/test_message_hub.py
+++ b/tests/test_logging/test_message_hub.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
 import pytest
+import torch
 
 from mmengine import MessageHub
 
@@ -55,3 +56,28 @@ class TestMessageHub:
         recorded_dict = dict(a=1, b=2)
         message_hub.update_info('test_value', recorded_dict)
         assert message_hub.get_info('test_value') == recorded_dict
+
+    def test_get_log_vars(self):
+        message_hub = MessageHub.create_instance()
+        log_dict = dict(
+            loss=1,
+            loss_cls=torch.tensor(2),
+            loss_bbox=np.array(3),
+            loss_iou=dict(value=1, count=2))
+        message_hub.update_log_vars(log_dict)
+        loss = message_hub.get_log('loss')
+        loss_cls = message_hub.get_log('loss_cls')
+        loss_bbox = message_hub.get_log('loss_bbox')
+        loss_iou = message_hub.get_log('loss_iou')
+        assert loss.current() == 1
+        assert loss_cls.current() == 2
+        assert loss_bbox.current() == 3
+        assert loss_iou.mean() == 0.5
+
+        with pytest.raises(TypeError):
+            loss_dict = dict(error_type=[])
+            message_hub.update_log_vars(loss_dict)
+
+        with pytest.raises(AssertionError):
+            loss_dict = dict(error_type=dict(count=1))
+            message_hub.update_log_vars(loss_dict)
-- 
GitLab