From 936c4ebc581f7383f4e4d156554ab8f257778e57 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Fri, 8 Jul 2022 15:01:47 +0800
Subject: [PATCH] [Fix] Fix missing device ids in wrap_model (#351)

* fix missing device ids in wrap_model

* clean the code

* use default broadcast_buffers

* refine MMSeparateDistributedDataParallel

* rename tmp variable

* refine docstring

* add type hints

* refactor docstring of ddp model

* add arg in docstring

* minor refine

* better ddp link
---
 mmengine/model/wrappers/distributed.py          | 29 +++++++++++++--
 .../model/wrappers/seperate_distributed.py      | 37 +++++++++++++++++--
 mmengine/runner/runner.py                       | 10 ++++-
 3 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py
index 813f42c6..3ac9dd3e 100644
--- a/mmengine/model/wrappers/distributed.py
+++ b/mmengine/model/wrappers/distributed.py
@@ -47,9 +47,29 @@ class MMDistributedDataParallel(DistributedDataParallel):
                   loss.
             Default: False.
 
-        *args: list arguments passed to ``DistributedDataParallel``
         **kwargs: keyword arguments passed to ``DistributedDataParallel``.
 
+            - device_ids (List[int] or torch.device, optional): CUDA devices
+              for module.
+            - output_device (int or torch.device, optional): Device location of
+              output for single-device CUDA modules.
+            - dim (int): Defaults to 0.
+            - broadcast_buffers (bool): Flag that enables syncing
+              (broadcasting) buffers of the module at the beginning of
+              the ``forward`` function. Defaults to True.
+            - find_unused_parameters (bool): Whether to find parameters of
+              module, which are not in the forward graph. Defaults to False.
+            - process_group (ProcessGroup, optional): The process group to be
+              used for distributed data all-reduction.
+            - bucket_cap_mb (int): Bucket size in MegaBytes (MB). Defaults
+              to 25.
+            - check_reduction (bool): This argument is deprecated. Defaults
+              to False.
+            - gradient_as_bucket_view (bool): Defaults to False.
+            - static_graph (bool): Defaults to False.
+
+    See more information about arguments in `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_.  # noqa: E501
+
     Note:
         If model has multiple submodules and each module has
        separate optimization strategies,
@@ -63,8 +83,11 @@ class MMDistributedDataParallel(DistributedDataParallel):
     override the ``train_step`` method.
     """
 
-    def __init__(self, detect_anomalous_params: bool = False, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self,
+                 module,
+                 detect_anomalous_params: bool = False,
+                 **kwargs):
+        super().__init__(module=module, **kwargs)
         self.detect_anomalous_params = detect_anomalous_params
 
     def train_step(self, data: List[dict],
diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py
index d1e2caa2..dbe133f9 100644
--- a/mmengine/model/wrappers/seperate_distributed.py
+++ b/mmengine/model/wrappers/seperate_distributed.py
@@ -36,11 +36,37 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
     Args:
         module (nn.Module): model contain multiple submodules which have
             separately updating strategy.
-        *args: list arguments passed to ``MMDistributedDataParallel``
-        **kwargs: keyword arguments passed to ``MMDistributedDataParallel``.
+        broadcast_buffers (bool): Same as that in
+            ``torch.nn.parallel.distributed.DistributedDataParallel``.
+            Defaults to False.
+        find_unused_parameters (bool): Same as that in
+            ``torch.nn.parallel.distributed.DistributedDataParallel``.
+            Traverse the autograd graph of all tensors in the return value
+            of the wrapped module's forward function. Defaults to False.
+        **kwargs: Keyword arguments passed to ``MMDistributedDataParallel``.
+
+            - device_ids (List[int] or torch.device, optional): CUDA devices
+              for module.
+            - output_device (int or torch.device, optional): Device location of
+              output for single-device CUDA modules.
+            - dim (int): Defaults to 0.
+            - process_group (ProcessGroup, optional): The process group to be
+              used for distributed data all-reduction.
+            - bucket_cap_mb (int): Bucket size in MegaBytes (MB). Defaults
+              to 25.
+            - check_reduction (bool): This argument is deprecated. Defaults
+              to False.
+            - gradient_as_bucket_view (bool): Defaults to False.
+            - static_graph (bool): Defaults to False.
+
+    See more information about arguments in `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_.  # noqa: E501
     """
 
-    def __init__(self, module: nn.Module, *args, **kwargs):
+    def __init__(self,
+                 module: nn.Module,
+                 broadcast_buffers: bool = False,
+                 find_unused_parameters: bool = False,
+                 **kwargs):
         super(DistributedDataParallel, self).__init__()
         self.module = module
         device = get_device()
@@ -54,7 +80,10 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
                 sub_module = sub_module.to(device)
             else:
                 sub_module = MMDistributedDataParallel(
-                    module=sub_module.to(device), *args, **kwargs)
+                    module=sub_module.to(device),
+                    broadcast_buffers=broadcast_buffers,
+                    find_unused_parameters=find_unused_parameters,
+                    **kwargs)
             module._modules[name] = sub_module
 
     def train_step(self, data: List[dict],
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 5d7eb2d1..d41b1c6b 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -858,8 +858,16 @@ class Runner:
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused_parameters)
         else:
+            model_wrapper_type = MODEL_WRAPPERS.get(
+                model_wrapper_cfg.get('type'))  # type: ignore
+            default_args: dict = dict()
+            if issubclass(
+                    model_wrapper_type,  # type: ignore
+                    DistributedDataParallel):
+                default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])]
+            default_args['module'] = model
             model = MODEL_WRAPPERS.build(
-                model_wrapper_cfg, default_args=dict(module=model))
+                model_wrapper_cfg, default_args=default_args)
         return model
 
     def scale_lr(self,
--
GitLab
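Note on the runner-side change above: the new ``else`` branch in ``Runner.wrap_model``
only fills in ``device_ids`` when the configured wrapper type is a subclass of
``DistributedDataParallel``. A minimal sketch of that logic, written outside the
diff for clarity. The helper name ``build_wrapper`` and its signature are
hypothetical, and plain class construction stands in for mmengine's
``MODEL_WRAPPERS`` registry build:

    import os

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel


    def build_wrapper(model: nn.Module, wrapper_cls: type, **wrapper_cfg):
        """Hypothetical helper mirroring the ``default_args`` logic in the patch."""
        default_args = {'module': model}
        if issubclass(wrapper_cls, DistributedDataParallel):
            # DDP-style wrappers need to know which CUDA device they own; the
            # patch reads it from the LOCAL_RANK environment variable set by
            # the distributed launcher (e.g. torchrun) instead of leaving
            # device_ids unset.
            default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])]
        return wrapper_cls(**default_args, **wrapper_cfg)

Wrappers that are not ``DistributedDataParallel`` subclasses still receive only
``module``, which is why ``device_ids`` is gated behind the ``issubclass`` check
rather than passed unconditionally.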