Unverified commit 149248ce authored by Yifei Yang, committed by GitHub

[Feature] Add Sync Buffer Hook (#57)


* [Feature]: Add Part3 of Hooks

* [Feature]: Add Hook

* [Fix]: Add docstring and type hint for base hook

* add sync buffer hook

* update typing hint and docs

* fix lint

* fix mypy

* fix lint

* use mock from unittest

Co-authored-by: seuyou <3463423099@qq.com>
parent 12b916cf
# Copyright (c) OpenMMLab. All rights reserved.
from .checkpoint_hook import CheckpointHook
from .empty_cache_hook import EmptyCacheHook
from .hook import Hook
from .iter_timer_hook import IterTimerHook
from .optimizer_hook import OptimizerHook
from .param_scheduler_hook import ParamSchedulerHook
from .sampler_seed_hook import DistSamplerSeedHook
from .sync_buffer_hook import SyncBuffersHook

__all__ = [
    'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
# from mmengine.dist import get_dist_info, all_reduce
from collections import OrderedDict
from typing import Generator, List
from unittest.mock import MagicMock, Mock

import torch
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)

from mmengine.registry import HOOKS
from .hook import Hook

# TODO, replace with import mmengine.dist as dist
dist = Mock()
dist.IS_DIST = MagicMock(return_value=True)
# TODO, replace with mmengine.dist.get_dist_info
get_dist_info = MagicMock(return_value=(0, 1))
# TODO, replace with mmengine.dist.all_reduce
all_reduce = MagicMock()


# TODO, may need to move to dist.utils after implementing dist module
def _allreduce_coalesced(tensors: List[torch.Tensor],
                         world_size: int,
                         bucket_size_mb: int = -1) -> None:
    """All-reduce a sequence of tensors as a whole.

    Args:
        tensors (List[torch.Tensor]): A sequence of tensors to be
            all-reduced.
        world_size (int): The world size of the process group.
        bucket_size_mb (int): The limit of each chunk in megabytes
            for grouping tensors into chunks. Defaults to -1.
    """
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        flat_tensors = _flatten_dense_tensors(bucket)
        all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)
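
# A minimal usage sketch (illustrative only; `all_reduce` is mocked above,
# a real backend would reduce each flattened bucket in a single call):
#
#     tensors = [torch.ones(2), torch.ones(2), torch.ones(2).double()]
#     _allreduce_coalesced(tensors, world_size=2)
#
# With the default bucket_size_mb=-1, tensors are bucketed by dtype: the two
# float32 tensors are flattened into one bucket and the float64 tensor forms
# its own, so only two all_reduce calls are issued in total.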


def allreduce_params(params: Generator[torch.Tensor, None, None],
                     coalesce: bool = True,
                     bucket_size_mb: int = -1) -> None:
    """All-reduce parameters.

    Args:
        params (Generator[torch.Tensor, None, None]): List of parameters or
            buffers of a model.
        coalesce (bool, optional): Whether to reduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        return
    params_data = [param.data for param in params]
    if coalesce:
        _allreduce_coalesced(params_data, world_size, bucket_size_mb)
    else:
        for tensor in params_data:
            all_reduce(tensor.div_(world_size))
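
# Sketch of the intended use (illustrative): averaging a model's buffers,
# e.g. BatchNorm running statistics, across ranks. Note that with the mocked
# get_dist_info above, world_size is 1, so this call returns early:
#
#     model = torch.nn.BatchNorm1d(4)
#     allreduce_params(model.buffers())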


@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Synchronize model buffers such as running_mean and running_var in BN
    at the end of each epoch."""

    def __init__(self) -> None:
        self.distributed = dist.IS_DIST

    def after_epoch(self, runner: object) -> None:
        """All-reduce model buffers at the end of each epoch.

        Args:
            runner (object): The runner of the training process.
        """
        if self.distributed:
            allreduce_params(runner.model.buffers())  # type: ignore
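
A quick smoke test of the new hook, sketched under the assumption that the
HOOKS registry is the intended construction path; `DummyRunner` below is an
illustrative stand-in, not an mmengine API:

    import torch.nn as nn

    from mmengine.hooks import SyncBuffersHook  # triggers registration
    from mmengine.registry import HOOKS


    class DummyRunner:
        """Stand-in exposing the single attribute the hook reads."""

        def __init__(self, model: nn.Module) -> None:
            self.model = model


    runner = DummyRunner(nn.BatchNorm2d(8))  # has running_mean/var buffers
    hook = HOOKS.build(dict(type='SyncBuffersHook'))
    hook.after_epoch(runner)  # all-reduces the buffers when distributed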

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock

from mmengine.hooks import SyncBuffersHook


class TestSyncBuffersHook:

    def test_sync_buffers_hook(self):
        runner = Mock()
        runner.model = Mock()
        hook = SyncBuffersHook()
        hook.after_epoch(runner)
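
A slightly stronger variant of this test is possible with unittest.mock.patch,
sketched here under the assumption that allreduce_params stays importable from
mmengine.hooks.sync_buffer_hook:

    from unittest.mock import Mock, patch

    from mmengine.hooks import SyncBuffersHook


    def test_sync_buffers_hook_calls_allreduce():
        runner = Mock()
        with patch(
                'mmengine.hooks.sync_buffer_hook.allreduce_params') as mock_ar:
            hook = SyncBuffersHook()
            hook.after_epoch(runner)
        mock_ar.assert_called_once_with(runner.model.buffers())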