From 6015fd35e5cb8f3a06486802682b5a3fd5542ecd Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Tue, 28 Jun 2022 11:04:55 +0800
Subject: [PATCH] Fix docstring format (#337)

---
 mmengine/logging/log_processor.py |  7 +++---
 mmengine/logging/logger.py        |  1 +
 mmengine/model/__init__.py        | 16 ++++++-------
 mmengine/model/averaged_model.py  | 37 ++++++++++++++++++-------------
 4 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/mmengine/logging/log_processor.py b/mmengine/logging/log_processor.py
index 4b743668..d1fd8a4b 100644
--- a/mmengine/logging/log_processor.py
+++ b/mmengine/logging/log_processor.py
@@ -27,7 +27,8 @@ class LogProcessor:
         custom_cfg (list[dict], optional): Contains multiple log config dict,
             in which key means the data source name of log and value means the
             statistic method and corresponding arguments used to count the
-            data source. Defaults to None
+            data source. Defaults to None.
+
             - If custom_cfg is None, all logs will be formatted via default
               methods, such as smoothing loss by default window_size. If
               custom_cfg is defined as a list of config dict, for example:
@@ -35,12 +36,12 @@
               window_size='global')]. It means the log item ``loss`` will be
               counted as global mean and additionally logged as ``global_loss``
               (defined by ``log_name``). If ``log_name`` is not defined in
-            config dict, the original logged key will be overwritten.
+              config dict, the original logged key will be overwritten.
 
             - The original log item cannot be overwritten twice. Here is
               an error example:
              [dict(data_src=loss, method='mean', window_size='global'),
-            dict(data_src=loss, method='mean', window_size='epoch')].
+              dict(data_src=loss, method='mean', window_size='epoch')].
              Both log config dict in custom_cfg do not have ``log_name`` key,
              which means the loss item will be overwritten twice.
 
diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py
index 85025dba..059110c4 100644
--- a/mmengine/logging/logger.py
+++ b/mmengine/logging/logger.py
@@ -246,6 +246,7 @@ def print_log(msg,
         logger (Logger or str, optional): If the type of logger is
             ``logging.Logger``, we directly use logger to log messages.
             Some special loggers are:
+
             - "silent": No message will be printed.
             - "current": Use latest created logger to log message.
             - other str: Instance name of logger. The corresponding logger
diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py
index 8b8203fb..3da94604 100644
--- a/mmengine/model/__init__.py
+++ b/mmengine/model/__init__.py
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmengine.utils.parrots_wrapper import TORCH_VERSION
 from mmengine.utils.version_utils import digit_version
-from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
-                             StochasticWeightAverage)
+from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
+                             MomentumAnnealingEMA, StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
 from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .utils import detect_anomalous_params, merge_dict, stack_batch
@@ -10,12 +10,12 @@ from .wrappers import (MMDistributedDataParallel,
                         MMSeparateDistributedDataParallel, is_model_wrapper)
 
 __all__ = [
-    'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage',
-    'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
-    'BaseDataPreprocessor', 'ImgDataPreprocessor',
-    'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
-    'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
-    'Sequential'
+    'MMDistributedDataParallel', 'is_model_wrapper', 'BaseAveragedModel',
+    'StochasticWeightAverage', 'ExponentialMovingAverage',
+    'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor',
+    'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule',
+    'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList',
+    'ModuleDict', 'Sequential'
 ]
 
 if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py
index e2f75e99..8a47e8e0 100644
--- a/mmengine/model/averaged_model.py
+++ b/mmengine/model/averaged_model.py
@@ -17,25 +17,28 @@ class BaseAveragedModel(nn.Module):
     training neural networks. This class implements the averaging process
     for a model. All subclasses must implement the `avg_func` method.
     This class creates a copy of the provided module :attr:`model`
-    on the device :attr:`device` and allows computing running averages of the
+    on the :attr:`device` and allows computing running averages of the
     parameters of the :attr:`model`.
+
     The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.
+
     Different from the `AveragedModel` in PyTorch, we use in-place operation
     to improve the parameter updating speed, which is about 5 times faster
     than the non-in-place version.
 
     In mmengine, we provide two ways to use the model averaging:
+
     1. Use the model averaging module in hook:
-    We provide an EMAHook to apply the model averaging during training.
-    Add ``custom_hooks=[dict(type='EMAHook')]`` to the config or the runner.
-    The hook is implemented in mmengine/hooks/ema_hook.py
+       We provide an EMAHook to apply the model averaging during training.
+       Add ``custom_hooks=[dict(type='EMAHook')]`` to the config or the runner.
+       The hook is implemented in mmengine/hooks/ema_hook.py
 
     2. Use the model averaging module directly in the algorithm. Take the ema
     teacher in semi-supervise as an example:
-    >>> from mmengine.model import ExponentialMovingAverage
-    >>> student = ResNet(depth=50)
-    >>> # use ema model as teacher
-    >>> ema_teacher = ExponentialMovingAverage(student)
+       >>> from mmengine.model import ExponentialMovingAverage
+       >>> student = ResNet(depth=50)
+       >>> # use ema model as teacher
+       >>> ema_teacher = ExponentialMovingAverage(student)
 
     Args:
         model (nn.Module): The model to be averaged.
@@ -134,7 +137,7 @@ class StochasticWeightAverage(BaseAveragedModel):
 
 @MODELS.register_module()
 class ExponentialMovingAverage(BaseAveragedModel):
-    """Implements the exponential moving average (EMA) of the model.
+    r"""Implements the exponential moving average (EMA) of the model.
 
     All parameters are updated by the formula as below:
 
@@ -145,9 +148,10 @@ class ExponentialMovingAverage(BaseAveragedModel):
     Args:
         model (nn.Module): The model to be averaged.
         momentum (float): The momentum used for updating ema parameter.
-            Ema's parameter are updated with the formula:
-            `averaged_param = (1-momentum) * averaged_param + momentum *
-            source_param`. Defaults to 0.0002.
+            Defaults to 0.0002.
+            Ema's parameter are updated with the formula
+            :math:`averaged\_param = (1-momentum) * averaged\_param +
+            momentum * source\_param`.
         interval (int): Interval between two updates. Defaults to 1.
         device (torch.device, optional): If provided, the averaged model will
             be stored on the :attr:`device`. Defaults to None.
@@ -184,14 +188,15 @@
 
 @MODELS.register_module()
 class MomentumAnnealingEMA(ExponentialMovingAverage):
-    """Exponential moving average (EMA) with momentum annealing strategy.
+    r"""Exponential moving average (EMA) with momentum annealing strategy.
 
     Args:
         model (nn.Module): The model to be averaged.
         momentum (float): The momentum used for updating ema parameter.
-            Ema's parameter are updated with the formula:
-            `averaged_param = (1-momentum) * averaged_param + momentum *
-            source_param`. Defaults to 0.0002.
+            Defaults to 0.0002.
+            Ema's parameter are updated with the formula
+            :math:`averaged\_param = (1-momentum) * averaged\_param +
+            momentum * source\_param`.
         gamma (int): Use a larger momentum early in training and gradually
             annealing to a smaller value to update the ema model smoothly. The
             momentum is calculated as max(momentum, gamma / (gamma + steps))
--
GitLab
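
For reference, the update rule that the reformatted docstrings above describe, including the annealed momentum used by ``MomentumAnnealingEMA``, can be illustrated with a minimal sketch. This is an illustration only, not the mmengine implementation: the helper ``ema_update``, the toy ``nn.Linear`` modules, and the chosen ``gamma``/``steps`` values are hypothetical, and the real classes additionally handle buffers, the ``interval`` argument, and device placement.

```python
import torch
import torch.nn as nn


def ema_update(averaged: nn.Module, source: nn.Module,
               momentum: float = 0.0002,
               gamma: int = 100, steps: int = 0) -> None:
    """Sketch of the formula from the docstrings above.

    averaged_param = (1 - momentum) * averaged_param + momentum * source_param,
    where MomentumAnnealingEMA replaces momentum with
    max(momentum, gamma / (gamma + steps)).
    """
    m = max(momentum, gamma / (gamma + steps))
    with torch.no_grad():
        for p_avg, p_src in zip(averaged.parameters(), source.parameters()):
            # In-place update, mirroring the "in-place operation" note
            # in the BaseAveragedModel docstring.
            p_avg.mul_(1 - m).add_(p_src, alpha=m)


# Usage sketch with hypothetical toy modules.
student = nn.Linear(4, 2)
teacher = nn.Linear(4, 2)
teacher.load_state_dict(student.state_dict())
ema_update(teacher, student, momentum=0.0002, gamma=100, steps=10)
```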