Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import Any, Union
from .base_global_accsessible import BaseGlobalAccessible
from .log_buffer import LogBuffer
class MessageHub(BaseGlobalAccessible):
    """Message hub for component interaction. MessageHub is created and
    accessed in the same way as BaseGlobalAccessible.

    ``MessageHub`` records log information and runtime information. The log
    information refers to the learning rate, loss, etc. of the model when
    training a model, and is stored as ``LogBuffer`` instances. The runtime
    information refers to the iteration count, meta information of the
    runner, etc., and each update overwrites the previous value.

    Args:
        name (str): Name of message hub, for global access. Defaults to ''.
    """

    def __init__(self, name: str = ''):
        # One ``LogBuffer`` per log key, kept in insertion order.
        self._log_buffers: OrderedDict = OrderedDict()
        # Runtime values keyed by name; ``update_info`` overwrites entries.
        self._runtime_info: OrderedDict = OrderedDict()
        super().__init__(name)

    def update_log(self,
                   key: str,
                   value: Union[int, float],
                   count: int = 1) -> None:
        """Update log buffer.

        If ``key`` already has a ``LogBuffer``, ``value`` and ``count`` are
        forwarded to its ``update`` method; otherwise a new ``LogBuffer`` is
        created from ``[value]`` and ``[count]``.

        Args:
            key (str): Key of ``LogBuffer``.
            value (int or float): Value of log.
            count (int): Accumulation times of log, defaults to 1. `count`
                will be used in smooth statistics.
        """
        if key in self._log_buffers:
            self._log_buffers[key].update(value, count)
        else:
            self._log_buffers[key] = LogBuffer([value], [count])

    def update_info(self, key: str, value: Any) -> None:
        """Update runtime information.

        Any existing value stored under ``key`` is overwritten.

        Args:
            key (str): Key of runtime information.
            value (Any): Value of runtime information.
        """
        self._runtime_info[key] = value

    @property
    def log_buffers(self) -> OrderedDict:
        """Get all ``LogBuffer`` instances.

        Note:
            Considering the large memory footprint of ``log_buffers`` in the
            post-training, ``MessageHub.log_buffers`` will not return the
            result of ``copy.deepcopy``. The returned dict is the hub's
            internal storage, so mutating it mutates the hub.

        Returns:
            OrderedDict: All ``LogBuffer`` instances.
        """
        return self._log_buffers

    @property
    def runtime_info(self) -> OrderedDict:
        """Get all runtime information.

        Returns:
            OrderedDict: A deep copy of all runtime information.
        """
        return copy.deepcopy(self._runtime_info)

    def get_log(self, key: str) -> LogBuffer:
        """Get ``LogBuffer`` instance by key.

        Note:
            Considering the large memory footprint of ``log_buffers`` in the
            post-training, ``MessageHub.get_log`` will not return the
            result of ``copy.deepcopy``.

        Args:
            key (str): Key of ``LogBuffer``.

        Returns:
            LogBuffer: Corresponding ``LogBuffer`` instance if the key exists.

        Raises:
            KeyError: If ``key`` has never been passed to ``update_log``.
        """
        if key not in self._log_buffers:
            # Read ``instance_name`` through the instance. If it is a property
            # on the base class (defined outside this file -- confirm),
            # class-level access would format the property object rather
            # than this hub's name.
            raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
                           f'instance name is: {self.instance_name}')
        return self._log_buffers[key]

    def get_info(self, key: str) -> Any:
        """Get runtime information by key.

        Args:
            key (str): Key of runtime information.

        Returns:
            Any: A deep copy of corresponding runtime information if the key
            exists.

        Raises:
            KeyError: If ``key`` has never been passed to ``update_info``.
        """
        # Test membership on the private dict: going through the
        # ``runtime_info`` property would deep-copy every entry just to
        # check one key.
        if key not in self._runtime_info:
            raise KeyError(f'{key} is not found in Messagehub.runtime_info: '
                           f'instance name is: {self.instance_name}')
        return copy.deepcopy(self._runtime_info[key])