diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py
index 813f42c6aa62d0e1853929af719c43f962f53b1d..3ac9dd3ed2d78c1860056f05f4b919cf0abf0fc0 100644
--- a/mmengine/model/wrappers/distributed.py
+++ b/mmengine/model/wrappers/distributed.py
@@ -47,9 +47,29 @@ class MMDistributedDataParallel(DistributedDataParallel):
                   loss.
             Default: False.
 
-        *args: list arguments passed to ``DistributedDataParallel``
         **kwargs: keyword arguments passed to ``DistributedDataParallel``.
 
+            - device_ids (List[int] or torch.device, optional): CUDA devices
+              for module.
+            - output_device (int or torch.device, optional): Device location of
+              output for single-device CUDA modules.
+            - dim (int): Defaults to 0.
+            - broadcast_buffers (bool): Flag that enables syncing (
+              broadcasting) buffers of the module at beginning of the
+              ``forward`` function. Defaults to True.
+            - find_unused_parameters (bool): Whether to find parameters of
+              module, which are not in the forward graph. Defaults to False.
+            - process_group (ProcessGroup, optional): The process group to be
+              used for distributed data all-reduction.
+            - bucket_cap_mb (int): bucket size in MegaBytes (MB). Defaults
+              to 25.
+            - check_reduction (bool): This argument is deprecated. Defaults
+              to False.
+            - gradient_as_bucket_view (bool): Defaults to False.
+            - static_graph (bool): Defaults to False.
+
+    See more information about arguments in `https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel`_ # noqa E501
+
     Note:
         If model has multiple submodules and each module has
         separate optimization strategies,
@@ -63,8 +83,11 @@ class MMDistributedDataParallel(DistributedDataParallel):
         override the ``train_step`` method.
     """

-    def __init__(self, detect_anomalous_params: bool = False, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self,
+                 module,
+                 detect_anomalous_params: bool = False,
+                 **kwargs):
+        super().__init__(module=module, **kwargs)
         self.detect_anomalous_params = detect_anomalous_params

     def train_step(self, data: List[dict],
diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py
index d1e2caa2b328d45809fc2cbd6144a605438324ee..dbe133f93074d60d25506594c211f54a7b200a38 100644
--- a/mmengine/model/wrappers/seperate_distributed.py
+++ b/mmengine/model/wrappers/seperate_distributed.py
@@ -36,11 +36,37 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
     Args:
         module (nn.Module): model contain multiple submodules which have
             separately updating strategy.
-        *args: list arguments passed to ``MMDistributedDataParallel``
-        **kwargs: keyword arguments passed to ``MMDistributedDataParallel``.
+        broadcast_buffers (bool): Same as that in
+            ``torch.nn.parallel.distributed.DistributedDataParallel``.
+            Defaults to False.
+        find_unused_parameters (bool): Same as that in
+            ``torch.nn.parallel.distributed.DistributedDataParallel``.
+            Traverse the autograd graph of all tensors contained in returned
+            value of the wrapped module's forward function. Defaults to False.
+        **kwargs: Keyword arguments passed to ``MMDistributedDataParallel``.
+
+            - device_ids (List[int] or torch.device, optional): CUDA devices
+              for module.
+            - output_device (int or torch.device, optional): Device location of
+              output for single-device CUDA modules.
+            - dim (int): Defaults to 0.
+            - process_group (ProcessGroup, optional): The process group to be
+              used for distributed data all-reduction.
+            - bucket_cap_mb (int): bucket size in MegaBytes (MB). Defaults
+              to 25.
+            - check_reduction (bool): This argument is deprecated. Defaults
+              to False.
+            - gradient_as_bucket_view (bool): Defaults to False.
+            - static_graph (bool): Defaults to False.
+
+    See more information about arguments in `https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel`_ # noqa E501
     """

-    def __init__(self, module: nn.Module, *args, **kwargs):
+    def __init__(self,
+                 module: nn.Module,
+                 broadcast_buffers: bool = False,
+                 find_unused_parameters: bool = False,
+                 **kwargs):
         super(DistributedDataParallel, self).__init__()
         self.module = module
         device = get_device()
@@ -54,7 +80,10 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
                 sub_module = sub_module.to(device)
             else:
                 sub_module = MMDistributedDataParallel(
-                    module=sub_module.to(device), *args, **kwargs)
+                    module=sub_module.to(device),
+                    broadcast_buffers=broadcast_buffers,
+                    find_unused_parameters=find_unused_parameters,
+                    **kwargs)
             module._modules[name] = sub_module

     def train_step(self, data: List[dict],
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 5d7eb2d1eaf6e9453929a482c0d23ff150de7a52..d41b1c6b4b3f7442e0b37515275c2187636ca24c 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -858,8 +858,16 @@ class Runner:
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused_parameters)
         else:
+            model_wrapper_type = MODEL_WRAPPERS.get(
+                model_wrapper_cfg.get('type'))  # type: ignore
+            default_args: dict = dict()
+            if issubclass(
+                    model_wrapper_type,  # type: ignore
+                    DistributedDataParallel):
+                default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])]
+            default_args['module'] = model
             model = MODEL_WRAPPERS.build(
-                model_wrapper_cfg, default_args=default_args)
+                model_wrapper_cfg, default_args=default_args)
         return model

     def scale_lr(self,
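
A minimal sketch of how the reworked constructor is intended to be called after this change: ``module`` becomes an explicit keyword argument of ``MMDistributedDataParallel`` and every remaining DDP option travels through ``**kwargs``, which is what lets ``Runner.wrap_model`` above inject ``device_ids`` only for DDP-based wrappers. ``ToyModel`` and the torchrun-style environment variables below are assumptions for illustration, not part of this patch.

import os

import torch.distributed as dist
import torch.nn as nn

from mmengine.model.wrappers import MMDistributedDataParallel


class ToyModel(nn.Module):
    """Hypothetical stand-in for a BaseModel subclass."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)


if __name__ == '__main__':
    # Assumes a torchrun launch that sets LOCAL_RANK, mirroring the
    # default_args['device_ids'] logic added to Runner.wrap_model.
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    model = ToyModel().cuda(local_rank)
    wrapped = MMDistributedDataParallel(
        module=model,             # explicit keyword argument after this patch
        device_ids=[local_rank],  # forwarded to DistributedDataParallel
        broadcast_buffers=False,
        find_unused_parameters=False)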