From 149248ce52589741f5eab019c9898435d89d5f3e Mon Sep 17 00:00:00 2001
From: Yifei Yang <2744335995@qq.com>
Date: Thu, 3 Mar 2022 17:13:43 +0800
Subject: [PATCH] [Feature] Add Sync Buffer Hook (#57)

* [Feature]: Add Part3 of Hooks

* [Feature]: Add Hook

* [Fix]: Add docstring and type hint for base hook

* add sync buffer hook

* update typing hint and docs

* fix lint

* fix mypy

* fix lint

* use mock from unittest

Co-authored-by: seuyou <3463423099@qq.com>
---
 mmengine/hooks/__init__.py                |  5 +-
 mmengine/hooks/sync_buffer_hook.py        | 97 +++++++++++++++++++++++
 tests/test_hook/test_sync_buffers_hook.py | 13 +++
 3 files changed, 113 insertions(+), 2 deletions(-)
 create mode 100644 mmengine/hooks/sync_buffer_hook.py
 create mode 100644 tests/test_hook/test_sync_buffers_hook.py

diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index a91f093b..1cb5b535 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -1,13 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .empty_cache_hook import EmptyCacheHook
 from .checkpoint_hook import CheckpointHook
+from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
 from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
 from .sampler_seed_hook import DistSamplerSeedHook
+from .sync_buffer_hook import SyncBuffersHook
 
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'OptimizerHook', 'EmptyCacheHook', 'CheckpointHook'
+    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook'
 ]
diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py
new file mode 100644
index 00000000..89edb55d
--- /dev/null
+++ b/mmengine/hooks/sync_buffer_hook.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# from mmengine.dist import get_dist_info, all_reduce
+from collections import OrderedDict
+from typing import Generator, List
+from unittest.mock import MagicMock, Mock
+
+import torch
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
+
+from mmengine.registry import HOOKS
+from .hook import Hook
+
+# TODO, replace with import mmengine.dist as dist
+dist = Mock()
+dist.IS_DIST = MagicMock(return_value=True)
+
+# TODO, replace with mmengine.dist.get_dist_info
+get_dist_info = MagicMock(return_value=(0, 1))
+# TODO, replace with mmengine.dist.all_reduce
+all_reduce = MagicMock()
+
+
+# TODO, may need to move to dist.utils after implementing dist module
+def _allreduce_coalesced(tensors: List[torch.Tensor],
+                         world_size: int,
+                         bucket_size_mb: int = -1) -> None:
+    """All-reduce a sequence of tensors as a whole.
+
+    Args:
+        tensors (List[torch.Tensor]): A sequence of tensors to be
+            all-reduced.
+        world_size (int): The world size of the process group.
+        bucket_size_mb (int): The limit of each chunk in megabytes
+            for grouping tensors into chunks. Defaults to -1.
+ """ + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) + + +def allreduce_params(params: Generator[torch.Tensor, None, None], + coalesce: bool = True, + bucket_size_mb: int = -1) -> None: + """All-reduce parameters. + + Args: + params (Generator[torch.Tensor, None, None]): List of parameters or + buffers of a model. + coalesce (bool, optional): Whether to reduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. + """ + _, world_size = get_dist_info() + if world_size == 1: + return + params_data = [param.data for param in params] + if coalesce: + _allreduce_coalesced(params_data, world_size, bucket_size_mb) + else: + for tensor in params_data: + all_reduce(tensor.div_(world_size)) + + +@HOOKS.register_module() +class SyncBuffersHook(Hook): + """Synchronize model buffers such as running_mean and running_var in BN at + the end of each epoch.""" + + def __init__(self) -> None: + self.distributed = dist.IS_DIST + + def after_epoch(self, runner: object) -> None: + """All-reduce model buffers at the end of each epoch. + + Args: + runner (object): The runner of the training process. + """ + if self.distributed: + allreduce_params(runner.model.buffers()) # type: ignore diff --git a/tests/test_hook/test_sync_buffers_hook.py b/tests/test_hook/test_sync_buffers_hook.py new file mode 100644 index 00000000..6bba7de5 --- /dev/null +++ b/tests/test_hook/test_sync_buffers_hook.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import Mock + +from mmengine.hooks import SyncBuffersHook + + +class TestSyncBuffersHook: + + def test_sync_buffers_hook(self): + Runner = Mock() + Runner.model = Mock() + Hook = SyncBuffersHook() + Hook.after_epoch(Runner) -- GitLab