Unverified commit 936c4ebc authored by Mashiro, committed by GitHub

[Fix] Fix missing device ids in wrap_model (#351)

* fix missing device ids in wrap_model

* clean the code

* use default broadcast_buffers

* refine MMSeparateDistributedDataParallel

* rename tmp variable

* refine docstring

* add type hints

* refactor docstring of ddp model

* add  arg in docstring

* minor refine

* better ddp link
parent 792f481e
@@ -47,9 +47,29 @@ class MMDistributedDataParallel(DistributedDataParallel):
             loss.
             Default: False.
-        *args: list arguments passed to ``DistributedDataParallel``
         **kwargs: keyword arguments passed to ``DistributedDataParallel``.
+
+            - device_ids (List[int] or torch.device, optional): CUDA devices
+              for module.
+            - output_device (int or torch.device, optional): Device location of
+              output for single-device CUDA modules.
+            - dim (int): Defaults to 0.
+            - broadcast_buffers (bool): Flag that enables syncing (
+              broadcasting) buffers of the module at beginning of the
+              ``forward`` function. Defaults to True
+            - find_unused_parameters (bool): Whether to find parameters of
+              module, which are not in the forward graph. Defaults to False.
+            - process_group (ProcessGroup, optional): The process group to be
+              used for distributed data all-reduction.
+            - bucket_cap_mb (int): bucket size in MegaBytes (MB). Defaults
+              to 25.
+            - check_reduction (bool): This argument is deprecated. Defaults
+              to False.
+            - gradient_as_bucket_view (bool): Defaults to False.
+            - static_graph (bool): Defaults to False.
+
+            See more information about arguments in `https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel`_ # noqa E501

     Note:
         If model has multiple submodules and each module has
         separate optimization strategies,
@@ -63,8 +83,11 @@ class MMDistributedDataParallel(DistributedDataParallel):
         override the ``train_step`` method.
     """

-    def __init__(self, detect_anomalous_params: bool = False, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self,
+                 module,
+                 detect_anomalous_params: bool = False,
+                 **kwargs):
+        super().__init__(module=module, **kwargs)
         self.detect_anomalous_params = detect_anomalous_params

     def train_step(self, data: List[dict],
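The constructor now takes ``module`` explicitly and forwards every remaining DDP option through ``**kwargs``. Below is a minimal usage sketch, assuming a torchrun-style launch (``LOCAL_RANK`` set) and an already initialized process group; the ``wrap_for_ddp`` helper and the import path are illustrative, not part of this commit.

# Illustrative only: how the fixed signature is typically used.
import os

import torch.distributed as dist
import torch.nn as nn

from mmengine.model import MMDistributedDataParallel  # import path assumed


def wrap_for_ddp(model: nn.Module) -> MMDistributedDataParallel:
    """Hypothetical helper mirroring what Runner.wrap_model now does."""
    assert dist.is_initialized(), 'call dist.init_process_group() first'
    local_rank = int(os.environ['LOCAL_RANK'])
    model = model.cuda(local_rank)
    return MMDistributedDataParallel(
        module=model,
        device_ids=[local_rank],  # forwarded to DistributedDataParallel via **kwargs
        broadcast_buffers=False,
        find_unused_parameters=False,
        detect_anomalous_params=False,
    )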
@@ -36,11 +36,37 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
     Args:
         module (nn.Module): model contain multiple submodules which have
             separately updating strategy.
-        *args: list arguments passed to ``MMDistributedDataParallel``
-        **kwargs: keyword arguments passed to ``MMDistributedDataParallel``.
+        broadcast_buffers (bool): Same as that in
+            ``torch.nn.parallel.distributed.DistributedDataParallel``.
+            Defaults to False.
+        find_unused_parameters (bool): Same as that in
+            ``torch.nn.parallel.distributed.DistributedDataParallel``.
+            Traverse the autograd graph of all tensors contained in returned
+            value of the wrapped module's forward function. Defaults to False.
+        **kwargs: Keyword arguments passed to ``MMDistributedDataParallel``.
+
+            - device_ids (List[int] or torch.device, optional): CUDA devices
+              for module.
+            - output_device (int or torch.device, optional): Device location of
+              output for single-device CUDA modules.
+            - dim (int): Defaults to 0.
+            - process_group (ProcessGroup, optional): The process group to be
+              used for distributed data all-reduction.
+            - bucket_cap_mb (int): bucket size in MegaBytes (MB). Defaults
+              to 25.
+            - check_reduction (bool): This argument is deprecated. Defaults
+              to False.
+            - gradient_as_bucket_view (bool): Defaults to False.
+            - static_graph (bool): Defaults to False.
+
+            See more information about arguments in `https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel`_ # noqa E501
     """

-    def __init__(self, module: nn.Module, *args, **kwargs):
+    def __init__(self,
+                 module: nn.Module,
+                 broadcast_buffers: bool = False,
+                 find_unused_parameters: bool = False,
+                 **kwargs):
         super(DistributedDataParallel, self).__init__()
         self.module = module
         device = get_device()
@@ -54,7 +80,10 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
                 sub_module = sub_module.to(device)
             else:
                 sub_module = MMDistributedDataParallel(
-                    module=sub_module.to(device), *args, **kwargs)
+                    module=sub_module.to(device),
+                    broadcast_buffers=broadcast_buffers,
+                    find_unused_parameters=find_unused_parameters,
+                    **kwargs)
             module._modules[name] = sub_module

     def train_step(self, data: List[dict],
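For context, ``MMSeparateDistributedDataParallel`` wraps each direct child of the given module in its own ``MMDistributedDataParallel`` instance, so submodules with separate optimization strategies keep separate gradient synchronization. The sketch below is a hypothetical usage example, not taken from this commit: the ``ToyGAN`` class, the import path, and the launch assumptions are all illustrative.

# Hypothetical example of a model whose submodules are optimized separately.
import torch.nn as nn

from mmengine.model import MMSeparateDistributedDataParallel  # path assumed


class ToyGAN(nn.Module):
    """Illustrative module with two separately optimized children."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(8, 8)
        self.discriminator = nn.Linear(8, 1)


# In a distributed process (process group initialized, LOCAL_RANK set):
#
#   wrapper = MMSeparateDistributedDataParallel(
#       module=ToyGAN().cuda(),
#       broadcast_buffers=False,
#       find_unused_parameters=False,
#       device_ids=[local_rank],  # reaches each child's DDP via **kwargs
#   )
#
# wrapper.module.generator and wrapper.module.discriminator each end up wrapped
# by their own MMDistributedDataParallel, so gradients are reduced per
# submodule rather than over the whole model at once.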
@@ -858,8 +858,16 @@ class Runner:
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused_parameters)
         else:
+            model_wrapper_type = MODEL_WRAPPERS.get(
+                model_wrapper_cfg.get('type'))  # type: ignore
+            default_args: dict = dict()
+            if issubclass(
+                    model_wrapper_type,  # type: ignore
+                    DistributedDataParallel):
+                default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])]
+            default_args['module'] = model
             model = MODEL_WRAPPERS.build(
-                model_wrapper_cfg, default_args=dict(module=model))
+                model_wrapper_cfg, default_args=default_args)
         return model

     def scale_lr(self,
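From the user side, the practical effect of this branch is that a DDP-based wrapper selected through ``model_wrapper_cfg`` no longer needs ``device_ids`` spelled out by hand; ``wrap_model`` now derives it from ``LOCAL_RANK``. A configuration sketch with illustrative values:

# Illustrative wrapper configuration passed to the Runner.
model_wrapper_cfg = dict(
    type='MMSeparateDistributedDataParallel',
    broadcast_buffers=False,
    find_unused_parameters=True,
    # device_ids is intentionally omitted: for DistributedDataParallel
    # subclasses, wrap_model now injects [int(os.environ['LOCAL_RANK'])]
    # through default_args, together with module=model.
)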