From 149248ce52589741f5eab019c9898435d89d5f3e Mon Sep 17 00:00:00 2001
From: Yifei Yang <2744335995@qq.com>
Date: Thu, 3 Mar 2022 17:13:43 +0800
Subject: [PATCH] [Feature] Add Sync Buffer Hook (#57)

* [Feature]: Add Part 3 of Hooks

* [Feature]: Add Hook

* [Fix]: Add docstring and type hint for base hook

* add sync buffer hook

* update type hints and docs

* fix lint

* fix mypy

* fix lint

* use mock from unittest

Co-authored-by: seuyou <3463423099@qq.com>
---
 mmengine/hooks/__init__.py                |   5 +-
 mmengine/hooks/sync_buffer_hook.py        | 114 +++++++++++++++++++++++
 tests/test_hook/test_sync_buffers_hook.py |  15 +++
 3 files changed, 132 insertions(+), 2 deletions(-)
 create mode 100644 mmengine/hooks/sync_buffer_hook.py
 create mode 100644 tests/test_hook/test_sync_buffers_hook.py

diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index a91f093b..1cb5b535 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -1,13 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .empty_cache_hook import EmptyCacheHook
 from .checkpoint_hook import CheckpointHook
+from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
 from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
 from .sampler_seed_hook import DistSamplerSeedHook
+from .sync_buffer_hook import SyncBuffersHook
 
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'OptimizerHook', 'EmptyCacheHook', 'CheckpointHook'
+    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook'
 ]
diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py
new file mode 100644
index 00000000..89edb55d
--- /dev/null
+++ b/mmengine/hooks/sync_buffer_hook.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# from mmengine.dist import get_dist_info, all_reduce
+from collections import OrderedDict
+from typing import Generator, List
+from unittest.mock import MagicMock, Mock
+
+import torch
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
+
+from mmengine.registry import HOOKS
+from .hook import Hook
+
+# TODO: replace with `import mmengine.dist as dist`
+dist = Mock()
+dist.IS_DIST = MagicMock(return_value=True)
+
+# TODO: replace with mmengine.dist.get_dist_info
+get_dist_info = MagicMock(return_value=(0, 1))
+# TODO: replace with mmengine.dist.all_reduce
+all_reduce = MagicMock()
+
+
+# TODO: may need to move to dist.utils once the dist module is implemented
+def _allreduce_coalesced(tensors: List[torch.Tensor],
+                         world_size: int,
+                         bucket_size_mb: int = -1) -> None:
+    """All-reduce a sequence of tensors as a whole.
+
+    Args:
+        tensors (List[torch.Tensor]): A sequence of tensors to be
+            all-reduced.
+        world_size (int): The world size of the process group.
+        bucket_size_mb (int): The limit of each chunk in megabytes
+            for grouping tensors into chunks. Defaults to -1.
+    """
+    if bucket_size_mb > 0:
+        bucket_size_bytes = bucket_size_mb * 1024 * 1024
+        buckets = _take_tensors(tensors, bucket_size_bytes)
+    else:
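+        # Group tensors by type so each flattened bucket is homogeneous.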
+        buckets = OrderedDict()
+        for tensor in tensors:
+            tp = tensor.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(tensor)
+        buckets = buckets.values()
+
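+    # Flatten each bucket into one tensor, all-reduce it in a single call,
+    # then average by world size and copy the synced values back.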
+    for bucket in buckets:
+        flat_tensors = _flatten_dense_tensors(bucket)
+        all_reduce(flat_tensors)
+        flat_tensors.div_(world_size)
+        for tensor, synced in zip(
+                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+            tensor.copy_(synced)
+
+
+def allreduce_params(params: Generator[torch.Tensor, None, None],
+                     coalesce: bool = True,
+                     bucket_size_mb: int = -1) -> None:
+    """All-reduce parameters.
+
+    Args:
+        params (Generator[torch.Tensor, None, None]): A generator that
+            yields the parameters or buffers of a model.
+        coalesce (bool, optional): Whether to reduce parameters as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of each bucket in megabytes;
+            only used when ``coalesce`` is True. Defaults to -1.
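+
+    Examples:
+        >>> # A hypothetical sketch: assumes an initialized process group
+        >>> # and a ``model`` whose buffers should be synchronized.
+        >>> allreduce_params(model.buffers())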
+    """
+    _, world_size = get_dist_info()
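+    # A single process has nothing to synchronize.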
+    if world_size == 1:
+        return
+    params_data = [param.data for param in params]
+    if coalesce:
+        _allreduce_coalesced(params_data, world_size, bucket_size_mb)
+    else:
+        for tensor in params_data:
+            all_reduce(tensor.div_(world_size))
+
+
+@HOOKS.register_module()
+class SyncBuffersHook(Hook):
+    """Synchronize model buffers such as running_mean and running_var in BN at
+    the end of each epoch."""
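+
+    Examples:
+        >>> # Hypothetical config usage; assumes the runner builds hooks
+        >>> # registered in ``HOOKS`` from such config dicts.
+        >>> custom_hooks = [dict(type='SyncBuffersHook')]
+    """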
+
+    def __init__(self) -> None:
+        self.distributed = dist.IS_DIST()
+
+    def after_epoch(self, runner: object) -> None:
+        """All-reduce model buffers at the end of each epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        if self.distributed:
+            allreduce_params(runner.model.buffers())  # type: ignore
diff --git a/tests/test_hook/test_sync_buffers_hook.py b/tests/test_hook/test_sync_buffers_hook.py
new file mode 100644
index 00000000..6bba7de5
--- /dev/null
+++ b/tests/test_hook/test_sync_buffers_hook.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import Mock
+
+from mmengine.hooks import SyncBuffersHook
+
+
+class TestSyncBuffersHook:
+
+    def test_sync_buffers_hook(self):
+        runner = Mock()
+        runner.model = Mock()
+        hook = SyncBuffersHook()
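+        # get_dist_info is mocked to return world_size=1 inside the hook
+        # module, so allreduce_params returns early without touching dist.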
+        hook.after_epoch(runner)
-- 
GitLab