From 936c4ebc581f7383f4e4d156554ab8f257778e57 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Fri, 8 Jul 2022 15:01:47 +0800
Subject: [PATCH] [Fix] Fix missing device ids in wrap_model (#351)

* fix missing device ids in wrap_model

* clean the code

* use default broadcast_buffers

* refine MMSeparateDistributedDataParallel

* rename tmp variable

* refine docstring

* add type hints

* refactor docstring of ddp model

* add arg in docstring

* minor refine

* better ddp link
---
 mmengine/model/wrappers/distributed.py        | 29 +++++++++++++--
 .../model/wrappers/seperate_distributed.py    | 37 +++++++++++++++++--
 mmengine/runner/runner.py                     | 10 ++++-
 3 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py
index 813f42c6..3ac9dd3e 100644
--- a/mmengine/model/wrappers/distributed.py
+++ b/mmengine/model/wrappers/distributed.py
@@ -47,9 +47,29 @@ class MMDistributedDataParallel(DistributedDataParallel):
                   loss.
             Default: False.
 
-        *args: list arguments passed to ``DistributedDataParallel``
         **kwargs: keyword arguments passed to ``DistributedDataParallel``.
 
+            - device_ids (List[int] or torch.device, optional): CUDA devices
+              for module.
+            - output_device (int or torch.device, optional): Device location of
+              output for single-device CUDA modules.
+            - dim (int): Defaults to 0.
+            - broadcast_buffers (bool): Flag that enables syncing
+              (broadcasting) buffers of the module at the beginning of the
+              ``forward`` function. Defaults to True.
+            - find_unused_parameters (bool): Whether to find parameters of
+              the module that are not in the forward graph. Defaults to False.
+            - process_group (ProcessGroup, optional): The process group to be
+              used for distributed data all-reduction.
+            - bucket_cap_mb (int): Bucket size in megabytes (MB). Defaults
+              to 25.
+            - check_reduction (bool): This argument is deprecated. Defaults
+              to False.
+            - gradient_as_bucket_view (bool): Defaults to False.
+            - static_graph (bool): Defaults to False.
+
+    See more information about arguments in `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_.  # noqa: E501
+
     Note:
         If model has multiple submodules and each module has
         separate optimization strategies,
@@ -63,8 +83,11 @@ class MMDistributedDataParallel(DistributedDataParallel):
         override the ``train_step`` method.
     """
 
-    def __init__(self, detect_anomalous_params: bool = False, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self,
+                 module,
+                 detect_anomalous_params: bool = False,
+                 **kwargs):
+        super().__init__(module=module, **kwargs)
         self.detect_anomalous_params = detect_anomalous_params
 
     def train_step(self, data: List[dict],
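
For reference, a minimal usage sketch of the updated constructor, assuming an
already initialized process group and one visible GPU per process; ``ToyModel``
is a hypothetical module and not part of this patch:

    import os

    import torch.nn as nn

    from mmengine.model.wrappers import MMDistributedDataParallel


    class ToyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)

        def forward(self, x):
            return self.linear(x)


    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    model = ToyModel().cuda(local_rank)
    # ``module`` is now an explicit argument, and the remaining keyword
    # arguments (device_ids, broadcast_buffers, ...) are forwarded unchanged
    # to torch.nn.parallel.DistributedDataParallel.
    ddp_model = MMDistributedDataParallel(
        module=model,
        device_ids=[local_rank],
        broadcast_buffers=False,
        find_unused_parameters=False,
        detect_anomalous_params=False)
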
diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py
index d1e2caa2..dbe133f9 100644
--- a/mmengine/model/wrappers/seperate_distributed.py
+++ b/mmengine/model/wrappers/seperate_distributed.py
@@ -36,11 +36,37 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
     Args:
         module (nn.Module): model which contains multiple submodules that
             use separate optimization strategies.
-        *args: list arguments passed to ``MMDistributedDataParallel``
-        **kwargs: keyword arguments passed to ``MMDistributedDataParallel``.
+        broadcast_buffers (bool): Same as that in
+            ``torch.nn.parallel.distributed.DistributedDataParallel``.
+            Defaults to False.
+        find_unused_parameters (bool): Same as that in
+            ``torch.nn.parallel.distributed.DistributedDataParallel``.
+            Traverse the autograd graph of all tensors in the returned value
+            of the wrapped module's forward function. Defaults to False.
+        **kwargs: Keyword arguments passed to ``MMDistributedDataParallel``.
+
+            - device_ids (List[int] or torch.device, optional): CUDA devices
+              for module.
+            - output_device (int or torch.device, optional): Device location of
+              output for single-device CUDA modules.
+            - dim (int): Defaults to 0.
+            - process_group (ProcessGroup, optional): The process group to be
+              used for distributed data all-reduction.
+            - bucket_cap_mb (int): Bucket size in megabytes (MB). Defaults
+              to 25.
+            - check_reduction (bool): This argument is deprecated. Defaults
+              to False.
+            - gradient_as_bucket_view (bool): Defaults to False.
+            - static_graph (bool): Defaults to False.
+
+    See more information about arguments in `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_.  # noqa: E501
     """
 
-    def __init__(self, module: nn.Module, *args, **kwargs):
+    def __init__(self,
+                 module: nn.Module,
+                 broadcast_buffers: bool = False,
+                 find_unused_parameters: bool = False,
+                 **kwargs):
         super(DistributedDataParallel, self).__init__()
         self.module = module
         device = get_device()
@@ -54,7 +80,10 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
                 sub_module = sub_module.to(device)
             else:
                 sub_module = MMDistributedDataParallel(
-                    module=sub_module.to(device), *args, **kwargs)
+                    module=sub_module.to(device),
+                    broadcast_buffers=broadcast_buffers,
+                    find_unused_parameters=find_unused_parameters,
+                    **kwargs)
             module._modules[name] = sub_module
 
     def train_step(self, data: List[dict],
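
To illustrate the refined signature, a hedged sketch of wrapping a model whose
submodules are optimized separately; ``GenDisc`` and its attribute names are
hypothetical, and the snippet assumes a launcher that initializes the process
group and sets ``LOCAL_RANK``:

    import os

    import torch.nn as nn

    from mmengine.model.wrappers import MMSeparateDistributedDataParallel


    class GenDisc(nn.Module):
        """Toy model whose submodules use separate optimization strategies."""

        def __init__(self):
            super().__init__()
            self.generator = nn.Linear(2, 2)
            self.discriminator = nn.Linear(2, 1)


    # ``broadcast_buffers`` and ``find_unused_parameters`` are now explicit
    # arguments; everything else (e.g. device_ids) is forwarded to the
    # MMDistributedDataParallel wrapper built around every submodule.
    wrapped = MMSeparateDistributedDataParallel(
        module=GenDisc(),
        broadcast_buffers=False,
        find_unused_parameters=False,
        device_ids=[int(os.environ['LOCAL_RANK'])])
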
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 5d7eb2d1..d41b1c6b 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -858,8 +858,16 @@ class Runner:
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused_parameters)
         else:
+            model_wrapper_type = MODEL_WRAPPERS.get(
+                model_wrapper_cfg.get('type'))  # type: ignore
+            default_args: dict = dict()
+            if issubclass(
+                    model_wrapper_type,  # type: ignore
+                    DistributedDataParallel):
+                default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])]
+            default_args['module'] = model
             model = MODEL_WRAPPERS.build(
-                model_wrapper_cfg, default_args=dict(module=model))
+                model_wrapper_cfg, default_args=default_args)
         return model
 
     def scale_lr(self,
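
Read in isolation, the runner change amounts to roughly the helper below;
``build_wrapper`` is a hypothetical name, the registry calls mirror the patch,
and the snippet assumes ``LOCAL_RANK`` is set by the launcher:

    import os

    from torch.nn.parallel import DistributedDataParallel

    from mmengine.registry import MODEL_WRAPPERS


    def build_wrapper(model, model_wrapper_cfg: dict):
        wrapper_cls = MODEL_WRAPPERS.get(model_wrapper_cfg.get('type'))
        default_args = dict(module=model)
        # Only DDP-style wrappers accept ``device_ids``; other registered
        # wrappers would reject the unexpected keyword argument.
        if issubclass(wrapper_cls, DistributedDataParallel):
            default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])]
        return MODEL_WRAPPERS.build(
            model_wrapper_cfg, default_args=default_args)
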
-- 
GitLab