# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import Any, Union

import numpy as np
import torch

from mmengine.visualization.utils import check_type
from .base_global_accsessible import BaseGlobalAccessible
from .log_buffer import LogBuffer


class MessageHub(BaseGlobalAccessible):
    """Message hub for component interaction.

    ``MessageHub`` is created and accessed in the same way as
    ``BaseGlobalAccessible``. It records two kinds of information:

    - Log information, e.g. the learning rate and loss of the model during
      training. Each log is accumulated in its own ``LogBuffer``.
    - Runtime information, e.g. the iteration count or meta information of
      the runner. Each entry is overwritten by the next update.

    Args:
        name (str): Name of message hub, for global access. Defaults to ''.
    """

    def __init__(self, name: str = ''):
        # log name -> LogBuffer accumulating every logged value.
        self._log_buffers: OrderedDict = OrderedDict()
        # key -> latest runtime value; replaced on every update.
        self._runtime_info: OrderedDict = OrderedDict()
        super().__init__(name)

    def update_log(self, key: str, value: Union[int, float],
                   count: int = 1) -> None:
        """Update log buffer.

        A new ``LogBuffer`` is created the first time ``key`` is seen;
        afterwards ``value`` and ``count`` are appended to that buffer.

        Args:
            key (str): Key of ``LogBuffer``.
            value (int or float): Value of log.
            count (int): Accumulation times of log, defaults to 1.
                ``count`` will be used in smooth statistics.
        """
        if key in self._log_buffers:
            self._log_buffers[key].update(value, count)
        else:
            self._log_buffers[key] = LogBuffer([value], [count])

    def update_log_vars(self, log_dict: dict) -> None:
        """Update :attr:`_log_buffers` with a dict.

        Each value of ``log_dict`` is either a scalar (int, float, numpy
        scalar, or single-element ``np.ndarray`` / ``torch.Tensor``), which
        is logged with a count of 1, or a dict with a required ``value`` key
        and an optional ``count`` key (defaults to 1).

        Args:
            log_dict (dict): Used for batch updating :attr:`_log_buffers`.

        Raises:
            AssertionError: If ``log_dict`` is not a dict, or a nested dict
                has no ``value`` key.

        Examples:
            >>> message_hub = MessageHub.create_instance()
            >>> log_dict = dict(a=1, b=2, c=3)
            >>> message_hub.update_log_vars(log_dict)
            >>> # The default count of `a`, `b` and `c` is 1.
            >>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2))
            >>> message_hub.update_log_vars(log_dict)
            >>> # The count of `c` is 2.
        """
        assert isinstance(log_dict, dict), ('`log_dict` must be a dict!, '
                                            f'but got {type(log_dict)}')
        for log_name, log_val in log_dict.items():
            if isinstance(log_val, dict):
                assert 'value' in log_val, \
                    f'value must be defined in {log_val}'
                raw_value = log_val['value']
                count = log_val.get('count', 1)
            else:
                raw_value = log_val
                count = 1
            value = self._get_valid_value(log_name, raw_value)
            self.update_log(log_name, value, count)

    def update_info(self, key: str, value: Any) -> None:
        """Update runtime information.

        Any previous value stored under ``key`` is overwritten.

        Args:
            key (str): Key of runtime information.
            value (Any): Value of runtime information.
        """
        self._runtime_info[key] = value

    @property
    def log_buffers(self) -> OrderedDict:
        """Get all ``LogBuffer`` instances.

        Note:
            Log buffers can take a lot of memory late in training, so
            ``MessageHub.log_buffers`` returns the internal dict itself
            rather than a ``copy.deepcopy`` of it.

        Returns:
            OrderedDict: All ``LogBuffer`` instances.
        """
        return self._log_buffers

    @property
    def runtime_info(self) -> OrderedDict:
        """Get all runtime information.

        Returns:
            OrderedDict: A deep copy of all runtime information.
        """
        return copy.deepcopy(self._runtime_info)

    def get_log(self, key: str) -> LogBuffer:
        """Get ``LogBuffer`` instance by key.

        Note:
            Log buffers can take a lot of memory late in training, so
            ``MessageHub.get_log`` returns the stored ``LogBuffer`` itself
            rather than a ``copy.deepcopy`` of it.

        Args:
            key (str): Key of ``LogBuffer``.

        Returns:
            LogBuffer: Corresponding ``LogBuffer`` instance if the key exists.

        Raises:
            KeyError: If ``key`` has never been logged.
        """
        if key not in self._log_buffers:
            # Look the name up on the instance: a class-level lookup would
            # not reflect this hub's own name (presumably set by
            # BaseGlobalAccessible.__init__ -- confirm).
            raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
                           f'instance name is: {self.instance_name}')
        return self._log_buffers[key]

    def get_info(self, key: str) -> Any:
        """Get runtime information by key.

        Args:
            key (str): Key of runtime information.

        Returns:
            Any: A deep copy of corresponding runtime information if the key
            exists.

        Raises:
            KeyError: If ``key`` has never been set via ``update_info``.
        """
        # Check the private dict directly: the public ``runtime_info``
        # property deep-copies everything, which is wasteful for a simple
        # membership test.
        if key not in self._runtime_info:
            raise KeyError(f'{key} is not found in Messagehub.runtime_info: '
                           f'instance name is: {self.instance_name}')
        # Only the requested entry is copied, so callers cannot mutate
        # the stored value.
        return copy.deepcopy(self._runtime_info[key])

    def _get_valid_value(
        self, key: str,
        value: Union[torch.Tensor, np.ndarray, np.integer, np.floating, int,
                     float]
    ) -> Union[int, float]:
        """Convert value to python built-in type.

        Args:
            key (str): Name of log, used in error messages.
            value (torch.Tensor or np.ndarray or numpy scalar or int or
                float): Value of log. Arrays and tensors must hold exactly
                one element.

        Returns:
            float or int: python built-in type value.

        Raises:
            AssertionError: If an array or tensor has more than one element.
        """
        # Numpy integer/floating scalars (e.g. ``np.float32``) are unwrapped
        # too. Some of them are not subclasses of Python ``int``/``float``,
        # and ``np.float64`` would otherwise leak through un-converted.
        if isinstance(value, (np.ndarray, np.integer, np.floating)):
            assert value.size == 1, (
                f'The value of {key} must contain exactly one element, '
                f'but got {value.size} elements')
            value = value.item()
        elif isinstance(value, torch.Tensor):
            assert value.numel() == 1, (
                f'The value of {key} must contain exactly one element, '
                f'but got {value.numel()} elements')
            value = value.item()
        else:
            # Plain Python values are validated by ``check_type``
            # (presumably raises on a type mismatch -- confirm).
            check_type(key, value, (int, float))
        return value