# Copyright (c) OpenMMLab. All rights reserved.
# from mmengine.dist import get_dist_info, all_reduce
from collections import OrderedDict
from typing import Generator, List
from unittest.mock import MagicMock, Mock

import torch
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)

from mmengine.registry import HOOKS
from .hook import Hook

# TODO, replace with import mmengine.dist as dist
dist = Mock()
dist.IS_DIST = MagicMock(return_value=True)

# TODO, replace with mmengine.dist.get_dist_info
get_dist_info = MagicMock(return_value=(0, 1))

# TODO, replace with mmengine.dist.all_reduce
all_reduce = MagicMock()


# TODO, may need to move to dist.utils after implementing dist module
def _allreduce_coalesced(tensors: List[torch.Tensor],
                         world_size: int,
                         bucket_size_mb: int = -1) -> None:
    """All-reduce a sequence of tensors as a whole.

    Args:
        tensors (List[torch.Tensor]): A sequence of tensors to be
            all-reduced.
        world_size (int): The world size of the process group.
        bucket_size_mb (int): The limit of each chunk in megabytes for
            grouping tensors into chunks. Defaults to -1.
    """
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        # Group tensors by type so that tensors of the same dtype/device
        # can be flattened together.
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        # Flatten each bucket into one contiguous tensor, reduce it in a
        # single communication call, average it, then scatter the synced
        # values back into the original tensors.
        flat_tensors = _flatten_dense_tensors(bucket)
        all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)


def allreduce_params(params: Generator[torch.Tensor, None, None],
                     coalesce: bool = True,
                     bucket_size_mb: int = -1) -> None:
    """All-reduce parameters.

    Args:
        params (Generator[torch.Tensor, None, None]): A generator of
            parameters or buffers of a model.
        coalesce (bool, optional): Whether to reduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        return
    params_data = [param.data for param in params]
    if coalesce:
        _allreduce_coalesced(params_data, world_size, bucket_size_mb)
    else:
        for tensor in params_data:
            all_reduce(tensor.div_(world_size))


@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Synchronize model buffers such as running_mean and running_var in BN
    at the end of each epoch."""

    priority = 'NORMAL'

    def __init__(self) -> None:
        self.distributed = dist.IS_DIST

    def after_epoch(self, runner: object) -> None:
        """All-reduce model buffers at the end of each epoch.

        Args:
            runner (object): The runner of the training process.
        """
        if self.distributed:
            allreduce_params(runner.model.buffers())  # type: ignore
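

# Usage sketch (not part of the hook itself): the snippet below illustrates
# how ``allreduce_params`` is meant to be called on a model's buffers, which
# is exactly what ``SyncBuffersHook.after_epoch`` does with
# ``runner.model.buffers()``. With the mocked ``get_dist_info`` above
# reporting a world size of 1, the call returns early and is effectively a
# no-op; under a real distributed backend it would average the BatchNorm
# running statistics across ranks.
if __name__ == '__main__':
    bn = torch.nn.BatchNorm2d(4)
    bn(torch.randn(2, 4, 8, 8))  # populate running_mean / running_var
    allreduce_params(bn.buffers(), coalesce=True, bucket_size_mb=1)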