Unverified commit 6015fd35, authored by Zaida Zhou, committed by GitHub

Fix docstring format (#337)

parent 2fd6beb9
@@ -27,7 +27,8 @@ class LogProcessor:
custom_cfg (list[dict], optional): Contains multiple log config dict,
in which key means the data source name of log and value means the
statistic method and corresponding arguments used to count the
-data source. Defaults to None
+data source. Defaults to None.
- If custom_cfg is None, all logs will be formatted via default
methods, such as smoothing loss by default window_size. If
custom_cfg is defined as a list of config dict, for example:
@@ -35,12 +36,12 @@ class LogProcessor:
window_size='global')]. It means the log item ``loss`` will be
counted as global mean and additionally logged as ``global_loss``
(defined by ``log_name``). If ``log_name`` is not defined in
-config dict, the original logged key will be overwritten.
+config dict, the original logged key will be overwritten.
- The original log item cannot be overwritten twice. Here is
an error example:
[dict(data_src=loss, method='mean', window_size='global'),
-dict(data_src=loss, method='mean', window_size='epoch')].
+dict(data_src=loss, method='mean', window_size='epoch')].
Both log config dict in custom_cfg do not have ``log_name`` key,
which means the loss item will be overwritten twice.
......
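The custom_cfg behaviour documented in the hunk above can be illustrated with a small config sketch. This is a hypothetical example, not code from the commit; the key names (data_src, log_name, method, window_size) mirror the docstring's own example and may differ across mmengine versions.

log_processor = dict(
    custom_cfg=[
        # No ``log_name``: the smoothed value replaces the original
        # ``loss`` key, which is allowed exactly once.
        dict(data_src='loss', method='mean', window_size=10),
        # With ``log_name``: additionally logged as ``global_loss``,
        # so the original ``loss`` key is not overwritten a second time.
        dict(data_src='loss', log_name='global_loss', method='mean',
             window_size='global'),
    ])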
@@ -246,6 +246,7 @@ def print_log(msg,
logger (Logger or str, optional): If the type of logger is
``logging.Logger``, we directly use logger to log messages.
Some special loggers are:
- "silent": No message will be printed.
- "current": Use latest created logger to log message.
- other str: Instance name of logger. The corresponding logger
......
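The special ``logger`` values documented above can be exercised directly. A minimal usage sketch, assuming the public mmengine.logging.print_log API:

import logging
from mmengine.logging import print_log

print_log('plain message', logger=None, level=logging.INFO)  # falls back to print
print_log('resuming training', logger='current')  # latest created logger
print_log('verbose detail', logger='silent')      # nothing is printed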
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils.parrots_wrapper import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
-from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
-                             StochasticWeightAverage)
+from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
+                             MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .utils import detect_anomalous_params, merge_dict, stack_batch
@@ -10,12 +10,12 @@ from .wrappers import (MMDistributedDataParallel,
MMSeparateDistributedDataParallel, is_model_wrapper)
__all__ = [
-    'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage',
-    'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
-    'BaseDataPreprocessor', 'ImgDataPreprocessor',
-    'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
-    'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
-    'Sequential'
+    'MMDistributedDataParallel', 'is_model_wrapper', 'BaseAveragedModel',
+    'StochasticWeightAverage', 'ExponentialMovingAverage',
+    'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor',
+    'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule',
+    'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList',
+    'ModuleDict', 'Sequential'
]
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
......
@@ -17,25 +17,28 @@ class BaseAveragedModel(nn.Module):
training neural networks. This class implements the averaging process
for a model. All subclasses must implement the `avg_func` method.
This class creates a copy of the provided module :attr:`model`
-on the device :attr:`device` and allows computing running averages of the
+on the :attr:`device` and allows computing running averages of the
parameters of the :attr:`model`.
The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.
Different from the `AveragedModel` in PyTorch, we use in-place operation
to improve the parameter updating speed, which is about 5 times faster
than the non-in-place version.
In mmengine, we provide two ways to use the model averaging:
1. Use the model averaging module in hook:
-We provide an EMAHook to apply the model averaging during training.
-Add ``custom_hooks=[dict(type='EMAHook')]`` to the config or the runner.
-The hook is implemented in mmengine/hooks/ema_hook.py
+We provide an EMAHook to apply the model averaging during training.
+Add ``custom_hooks=[dict(type='EMAHook')]`` to the config or the runner.
+The hook is implemented in mmengine/hooks/ema_hook.py
2. Use the model averaging module directly in the algorithm. Take the ema
teacher in semi-supervise as an example:
->>> from mmengine.model import ExponentialMovingAverage
->>> student = ResNet(depth=50)
->>> # use ema model as teacher
->>> ema_teacher = ExponentialMovingAverage(student)
+>>> from mmengine.model import ExponentialMovingAverage
+>>> student = ResNet(depth=50)
+>>> # use ema model as teacher
+>>> ema_teacher = ExponentialMovingAverage(student)
Args:
model (nn.Module): The model to be averaged.
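Since the docstring above says every subclass must implement ``avg_func``, a hedged sketch of a custom averaged model may help. It assumes ``avg_func`` receives the averaged parameter, the source parameter, and the step count, and updates the averaged parameter in place, matching the in-place design described above; the class and update rule here are hypothetical, and BaseAveragedModel is importable per the __all__ change above.

from torch import Tensor
from mmengine.model import BaseAveragedModel


class RunningMeanModel(BaseAveragedModel):
    """Hypothetical average: plain running mean over every update."""

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        # In-place running mean: avg += (src - avg) / (steps + 1).
        averaged_param.add_(
            (source_param - averaged_param) / float(steps + 1))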
@@ -134,7 +137,7 @@ class StochasticWeightAverage(BaseAveragedModel):
@MODELS.register_module()
class ExponentialMovingAverage(BaseAveragedModel):
"""Implements the exponential moving average (EMA) of the model.
r"""Implements the exponential moving average (EMA) of the model.
All parameters are updated by the formula as below:
@@ -145,9 +148,10 @@ class ExponentialMovingAverage(BaseAveragedModel):
Args:
model (nn.Module): The model to be averaged.
momentum (float): The momentum used for updating ema parameter.
-Ema's parameter are updated with the formula:
-`averaged_param = (1-momentum) * averaged_param + momentum *
-source_param`. Defaults to 0.0002.
+Defaults to 0.0002.
+Ema's parameter are updated with the formula
+:math:`averaged\_param = (1-momentum) * averaged\_param +
+momentum * source\_param`.
interval (int): Interval between two updates. Defaults to 1.
device (torch.device, optional): If provided, the averaged model will
be stored on the :attr:`device`. Defaults to None.
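The update rule in the hunk above can be checked numerically. A self-contained sketch of one in-place EMA step on bare tensors (not the actual ExponentialMovingAverage internals), using the in-place style the class docstring describes:

import torch

momentum = 0.0002
averaged_param = torch.ones(3)
source_param = torch.zeros(3)

# averaged_param = (1 - momentum) * averaged_param + momentum * source_param
averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)
print(averaged_param)  # tensor([0.9998, 0.9998, 0.9998])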
@@ -184,14 +188,15 @@ class ExponentialMovingAverage(BaseAveragedModel):
@MODELS.register_module()
class MomentumAnnealingEMA(ExponentialMovingAverage):
"""Exponential moving average (EMA) with momentum annealing strategy.
r"""Exponential moving average (EMA) with momentum annealing strategy.
Args:
model (nn.Module): The model to be averaged.
momentum (float): The momentum used for updating ema parameter.
-Ema's parameter are updated with the formula:
-`averaged_param = (1-momentum) * averaged_param + momentum *
-source_param`. Defaults to 0.0002.
+Defaults to 0.0002.
+Ema's parameter are updated with the formula
+:math:`averaged\_param = (1-momentum) * averaged\_param +
+momentum * source\_param`.
gamma (int): Use a larger momentum early in training and gradually
annealing to a smaller value to update the ema model smoothly. The
momentum is calculated as max(momentum, gamma / (gamma + steps))
......
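The annealing rule quoted in the last hunk, max(momentum, gamma / (gamma + steps)), is easy to tabulate. A small sketch (the helper name is ours, not part of the MomentumAnnealingEMA API):

def annealed_momentum(momentum: float, gamma: int, steps: int) -> float:
    # Large early (gamma / gamma = 1.0 at step 0), decaying toward the
    # ``momentum`` floor as steps grow, so the EMA changes slowly later on.
    return max(momentum, gamma / (gamma + steps))

print(annealed_momentum(0.0002, 100, 0))      # 1.0
print(annealed_momentum(0.0002, 100, 900))    # 0.1
print(annealed_momentum(0.0002, 100, 10**7))  # 0.0002 (floor reached)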