Skip to content
Snippets Groups Projects
message_hub.py 3.61 KiB
Newer Older
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import Any, Union

from .base_global_accsessible import BaseGlobalAccessible
from .log_buffer import LogBuffer


class MessageHub(BaseGlobalAccessible):
    """Message hub for component interaction. MessageHub is created and
    accessed in the same way as BaseGlobalAccessible.

    ``MessageHub`` will record log information and runtime information. The
    log information refers to the learning rate, loss, etc. of the model
    when training a model, which will be stored as ``LogBuffer``. The runtime
    information refers to the iter times, meta information of runner etc.,
    which will be overwritten by next update.

    Args:
        name (str): Name of message hub, for global access. Defaults to ''.
    """

    def __init__(self, name: str = ''):
        self._log_buffers: OrderedDict = OrderedDict()
        self._runtime_info: OrderedDict = OrderedDict()
        super().__init__(name)

    def update_log(self, key: str, value: Union[int, float], count: int = 1) \
            -> None:
        """Update log buffer.

        Args:
            key (str): Key of ``LogBuffer``.
            value (int or float): Value of log.
            count (int): Accumulation times of log, defaults to 1. `count`
                will be used in smooth statistics.
        """
        if key in self._log_buffers:
            self._log_buffers[key].update(value, count)
        else:
            self._log_buffers[key] = LogBuffer([value], [count])

    def update_info(self, key: str, value: Any) -> None:
        """Update runtime information.

        Args:
            key (str): Key of runtime information.
            value (Any): Value of runtime information.
        """
        self._runtime_info[key] = value

    @property
    def log_buffers(self) -> OrderedDict:
        """Get all ``LogBuffer`` instances.

        Note:
            Considering the large memory footprint of ``log_buffers`` in the
            post-training, ``MessageHub.log_buffers`` will not return the
            result of ``copy.deepcopy``.

        Returns:
            OrderedDict: All ``LogBuffer`` instances.
        """
        return self._log_buffers

    @property
    def runtime_info(self) -> OrderedDict:
        """Get all runtime information.

        Returns:
            OrderedDict: A copy of all runtime information.
        """
        return copy.deepcopy(self._runtime_info)

    def get_log(self, key: str) -> LogBuffer:
        """Get ``LogBuffer`` instance by key.

        Note:
            Considering the large memory footprint of ``log_buffers`` in the
            post-training, ``MessageHub.get_log`` will not return the
            result of ``copy.deepcopy``.

        Args:
            key (str): Key of ``LogBuffer``.

        Returns:
            LogBuffer: Corresponding ``LogBuffer`` instance if the key exists.
        """
        if key not in self.log_buffers:
            raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
                           f'instance name is: {MessageHub.instance_name}')
        return self._log_buffers[key]

    def get_info(self, key: str) -> Any:
        """Get runtime information by key.

        Args:
            key (str): Key of runtime information.

        Returns:
            Any: A copy of corresponding runtime information if the key exists.
        """
        if key not in self.runtime_info:
            raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
                           f'instance name is: {MessageHub.instance_name}')
        return copy.deepcopy(self._runtime_info[key])