From fff4742e0b37db781837b268d065cbec18c24a79 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Sun, 13 Mar 2022 21:28:47 +0800 Subject: [PATCH] [Enhancement] Messagehub supports update dict logs. (#120) * Messagehub support update dict values * Messagehub support update dict values * fix assertion error message --- mmengine/logging/message_hub.py | 54 ++++++++++++++++++++++++++ tests/test_logging/test_message_hub.py | 26 +++++++++++++ 2 files changed, 80 insertions(+) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 2b5087e8..a93d6038 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -3,6 +3,10 @@ import copy from collections import OrderedDict from typing import Any, Union +import numpy as np +import torch + +from mmengine.visualization.utils import check_type from .base_global_accsessible import BaseGlobalAccessible from .log_buffer import LogBuffer @@ -41,6 +45,34 @@ class MessageHub(BaseGlobalAccessible): else: self._log_buffers[key] = LogBuffer([value], [count]) + def update_log_vars(self, log_dict: dict) -> None: + """Update :attr:`_log_buffers` with a dict. + + Args: + log_dict (str): Used for batch updating :attr:`_log_buffers`. + + Examples: + >>> message_hub = MessageHub.create_instance() + >>> log_dict = dict(a=1, b=2, c=3) + >>> message_hub.update_log_vars(log_dict) + >>> # The default count of `a`, `b` and `c` is 1. + >>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2)) + >>> message_hub.update_log_vars(log_dict) + >>> # The count of `c` is 2. + """ + assert isinstance(log_dict, dict), ('`log_dict` must be a dict!, ' + f'but got {type(log_dict)}') + for log_name, log_val in log_dict.items(): + if isinstance(log_val, dict): + assert 'value' in log_val, \ + f'value must be defined in {log_val}' + count = log_val.get('count', 1) + value = self._get_valid_value(log_name, log_val['value']) + else: + value = self._get_valid_value(log_name, log_val) + count = 1 + self.update_log(log_name, value, count) + def update_info(self, key: str, value: Any) -> None: """Update runtime information. @@ -105,3 +137,25 @@ class MessageHub(BaseGlobalAccessible): raise KeyError(f'{key} is not found in Messagehub.log_buffers: ' f'instance name is: {MessageHub.instance_name}') return copy.deepcopy(self._runtime_info[key]) + + def _get_valid_value(self, key: str, + value: Union[torch.Tensor, np.ndarray, int, float])\ + -> Union[int, float]: + """Convert value to python built-in type. + + Args: + key (str): name of log. + value (torch.Tensor or np.ndarray or int or float): value of log. + + Returns: + float or int: python built-in type value. + """ + if isinstance(value, np.ndarray): + assert value.size == 1 + value = value.item() + elif isinstance(value, torch.Tensor): + assert value.numel() == 1 + value = value.item() + else: + check_type(key, value, (int, float)) + return value diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index 3b4b6452..c9ec25a1 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np import pytest +import torch from mmengine import MessageHub @@ -55,3 +56,28 @@ class TestMessageHub: recorded_dict = dict(a=1, b=2) message_hub.update_info('test_value', recorded_dict) assert message_hub.get_info('test_value') == recorded_dict + + def test_get_log_vars(self): + message_hub = MessageHub.create_instance() + log_dict = dict( + loss=1, + loss_cls=torch.tensor(2), + loss_bbox=np.array(3), + loss_iou=dict(value=1, count=2)) + message_hub.update_log_vars(log_dict) + loss = message_hub.get_log('loss') + loss_cls = message_hub.get_log('loss_cls') + loss_bbox = message_hub.get_log('loss_bbox') + loss_iou = message_hub.get_log('loss_iou') + assert loss.current() == 1 + assert loss_cls.current() == 2 + assert loss_bbox.current() == 3 + assert loss_iou.mean() == 0.5 + + with pytest.raises(TypeError): + loss_dict = dict(error_type=[]) + message_hub.update_log_vars(loss_dict) + + with pytest.raises(AssertionError): + loss_dict = dict(error_type=dict(count=1)) + message_hub.update_log_vars(loss_dict) -- GitLab