# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import Any, Union
import numpy as np
import torch
from mmengine.visualization.utils import check_type
from .base_global_accsessible import BaseGlobalAccessible
from .log_buffer import LogBuffer
class MessageHub(BaseGlobalAccessible):
"""Message hub for component interaction. MessageHub is created and
accessed in the same way as BaseGlobalAccessible.
    ``MessageHub`` records log information and runtime information. Log
    information refers to quantities such as the learning rate and loss of
    the model during training, which are stored as ``LogBuffer``. Runtime
    information refers to the iteration number, meta information of the
    runner, etc., and is overwritten by the next update.
Args:
name (str): Name of message hub, for global access. Defaults to ''.
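
    Examples:
        >>> # Illustrative usage sketch.
        >>> message_hub = MessageHub.create_instance()
        >>> message_hub.update_log('loss', 0.1)
        >>> message_hub.update_info('iter', 1)
        >>> message_hub.get_info('iter')
        1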
"""
def __init__(self, name: str = ''):
self._log_buffers: OrderedDict = OrderedDict()
self._runtime_info: OrderedDict = OrderedDict()
super().__init__(name)
def update_log(self, key: str, value: Union[int, float], count: int = 1) \
-> None:
"""Update log buffer.
Args:
key (str): Key of ``LogBuffer``.
value (int or float): Value of log.
            count (int): Accumulation count of the log value. Defaults to 1.
                ``count`` will be used for smoothed statistics.
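
        Examples:
            >>> # Illustrative sketch: values logged under the same key are
            >>> # accumulated in a single ``LogBuffer``.
            >>> message_hub = MessageHub.create_instance()
            >>> message_hub.update_log('loss', 0.5)
            >>> message_hub.update_log('loss', 0.3)
            >>> 'loss' in message_hub.log_buffers
            True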
"""
if key in self._log_buffers:
self._log_buffers[key].update(value, count)
else:
self._log_buffers[key] = LogBuffer([value], [count])
def update_log_vars(self, log_dict: dict) -> None:
"""Update :attr:`_log_buffers` with a dict.
Args:
            log_dict (dict): Used for batch updating :attr:`_log_buffers`.
Examples:
>>> message_hub = MessageHub.create_instance()
>>> log_dict = dict(a=1, b=2, c=3)
>>> message_hub.update_log_vars(log_dict)
>>> # The default count of `a`, `b` and `c` is 1.
>>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2))
>>> message_hub.update_log_vars(log_dict)
>>> # The count of `c` is 2.
"""
        assert isinstance(log_dict, dict), ('`log_dict` must be a dict, '
                                            f'but got {type(log_dict)}')
for log_name, log_val in log_dict.items():
if isinstance(log_val, dict):
assert 'value' in log_val, \
f'value must be defined in {log_val}'
count = log_val.get('count', 1)
value = self._get_valid_value(log_name, log_val['value'])
else:
value = self._get_valid_value(log_name, log_val)
count = 1
self.update_log(log_name, value, count)
def update_info(self, key: str, value: Any) -> None:
"""Update runtime information.
Args:
key (str): Key of runtime information.
value (Any): Value of runtime information.
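
        Examples:
            >>> # Illustrative sketch: runtime information is overwritten by
            >>> # the next update rather than accumulated.
            >>> message_hub = MessageHub.create_instance()
            >>> message_hub.update_info('iter', 1)
            >>> message_hub.update_info('iter', 2)
            >>> message_hub.get_info('iter')
            2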
"""
self._runtime_info[key] = value
@property
def log_buffers(self) -> OrderedDict:
"""Get all ``LogBuffer`` instances.
        Note:
            Considering the large memory footprint of ``log_buffers`` after
            training, ``MessageHub.log_buffers`` does not return a
            ``copy.deepcopy`` of them.
Returns:
OrderedDict: All ``LogBuffer`` instances.
"""
return self._log_buffers
@property
def runtime_info(self) -> OrderedDict:
"""Get all runtime information.
Returns:
OrderedDict: A copy of all runtime information.
"""
return copy.deepcopy(self._runtime_info)
def get_log(self, key: str) -> LogBuffer:
"""Get ``LogBuffer`` instance by key.
        Note:
            Considering the large memory footprint of ``log_buffers`` after
            training, ``MessageHub.get_log`` does not return a
            ``copy.deepcopy`` of the buffer.
Args:
key (str): Key of ``LogBuffer``.
Returns:
LogBuffer: Corresponding ``LogBuffer`` instance if the key exists.
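
        Examples:
            >>> # Illustrative sketch; the key must have been logged before,
            >>> # otherwise ``KeyError`` is raised.
            >>> message_hub = MessageHub.create_instance()
            >>> message_hub.update_log('loss', 0.2)
            >>> isinstance(message_hub.get_log('loss'), LogBuffer)
            True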
"""
if key not in self.log_buffers:
            raise KeyError(f'{key} is not found in MessageHub.log_buffers: '
                           f'instance name is: {self.instance_name}')
return self._log_buffers[key]
def get_info(self, key: str) -> Any:
"""Get runtime information by key.
Args:
key (str): Key of runtime information.
Returns:
Any: A copy of corresponding runtime information if the key exists.
"""
        if key not in self._runtime_info:
            raise KeyError(f'{key} is not found in MessageHub.runtime_info: '
                           f'instance name is: {self.instance_name}')
return copy.deepcopy(self._runtime_info[key])
def _get_valid_value(self, key: str,
value: Union[torch.Tensor, np.ndarray, int, float])\
-> Union[int, float]:
"""Convert value to python built-in type.
Args:
            key (str): Name of the log.
            value (torch.Tensor or np.ndarray or int or float): Value of the
                log.
        Returns:
            int or float: The converted Python built-in value.
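
        Examples:
            >>> # Illustrative sketch: scalar tensors and arrays are converted
            >>> # to plain Python numbers.
            >>> import numpy as np
            >>> import torch
            >>> message_hub = MessageHub.create_instance()
            >>> message_hub._get_valid_value('loss', torch.tensor(0.5))
            0.5
            >>> message_hub._get_valid_value('iter', np.array(10))
            10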
"""
if isinstance(value, np.ndarray):
assert value.size == 1
value = value.item()
elif isinstance(value, torch.Tensor):
assert value.numel() == 1
value = value.item()
else:
check_type(key, value, (int, float))
return value