Skip to content
Snippets Groups Projects
Unverified Commit fff4742e authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhancement] Messagehub supports update dict logs. (#120)

* Messagehub support update dict values

* Messagehub support update dict values

* fix assertion error message
parent f548c818
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,10 @@ import copy
from collections import OrderedDict
from typing import Any, Union
import numpy as np
import torch
from mmengine.visualization.utils import check_type
from .base_global_accsessible import BaseGlobalAccessible
from .log_buffer import LogBuffer
......@@ -41,6 +45,34 @@ class MessageHub(BaseGlobalAccessible):
else:
self._log_buffers[key] = LogBuffer([value], [count])
def update_log_vars(self, log_dict: dict) -> None:
"""Update :attr:`_log_buffers` with a dict.
Args:
log_dict (str): Used for batch updating :attr:`_log_buffers`.
Examples:
>>> message_hub = MessageHub.create_instance()
>>> log_dict = dict(a=1, b=2, c=3)
>>> message_hub.update_log_vars(log_dict)
>>> # The default count of `a`, `b` and `c` is 1.
>>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2))
>>> message_hub.update_log_vars(log_dict)
>>> # The count of `c` is 2.
"""
assert isinstance(log_dict, dict), ('`log_dict` must be a dict!, '
f'but got {type(log_dict)}')
for log_name, log_val in log_dict.items():
if isinstance(log_val, dict):
assert 'value' in log_val, \
f'value must be defined in {log_val}'
count = log_val.get('count', 1)
value = self._get_valid_value(log_name, log_val['value'])
else:
value = self._get_valid_value(log_name, log_val)
count = 1
self.update_log(log_name, value, count)
def update_info(self, key: str, value: Any) -> None:
"""Update runtime information.
......@@ -105,3 +137,25 @@ class MessageHub(BaseGlobalAccessible):
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
f'instance name is: {MessageHub.instance_name}')
return copy.deepcopy(self._runtime_info[key])
def _get_valid_value(self, key: str,
value: Union[torch.Tensor, np.ndarray, int, float])\
-> Union[int, float]:
"""Convert value to python built-in type.
Args:
key (str): name of log.
value (torch.Tensor or np.ndarray or int or float): value of log.
Returns:
float or int: python built-in type value.
"""
if isinstance(value, np.ndarray):
assert value.size == 1
value = value.item()
elif isinstance(value, torch.Tensor):
assert value.numel() == 1
value = value.item()
else:
check_type(key, value, (int, float))
return value
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmengine import MessageHub
......@@ -55,3 +56,28 @@ class TestMessageHub:
recorded_dict = dict(a=1, b=2)
message_hub.update_info('test_value', recorded_dict)
assert message_hub.get_info('test_value') == recorded_dict
def test_get_log_vars(self):
message_hub = MessageHub.create_instance()
log_dict = dict(
loss=1,
loss_cls=torch.tensor(2),
loss_bbox=np.array(3),
loss_iou=dict(value=1, count=2))
message_hub.update_log_vars(log_dict)
loss = message_hub.get_log('loss')
loss_cls = message_hub.get_log('loss_cls')
loss_bbox = message_hub.get_log('loss_bbox')
loss_iou = message_hub.get_log('loss_iou')
assert loss.current() == 1
assert loss_cls.current() == 2
assert loss_bbox.current() == 3
assert loss_iou.mean() == 0.5
with pytest.raises(TypeError):
loss_dict = dict(error_type=[])
message_hub.update_log_vars(loss_dict)
with pytest.raises(AssertionError):
loss_dict = dict(error_type=dict(count=1))
message_hub.update_log_vars(loss_dict)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment