Unverified Commit 9bbbd7dc authored by Mashiro, committed by GitHub
[Fix] Fix error argument sequence in fsdp (#520)

parent a6f52977
@@ -97,6 +97,9 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
             computation overlapping.
             Pros and cons of each algorithm is explained in class
             ``BackwardPrefetch``.
+        **kwargs: Keyword arguments passed to
+            :class:`FullyShardedDataParallel`.
     """

     def __init__(
@@ -106,6 +109,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
         cpu_offload: Optional[Union[bool, CPUOffload]] = None,
         fsdp_auto_wrap_policy: Optional[Union[str, Callable]] = None,
         backward_prefetch: Optional[Union[str, BackwardPrefetch]] = None,
+        **kwargs,
     ):
         if cpu_offload is not None:
@@ -150,8 +154,13 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
                 'or `BackwardPrefetch`, but has type '
                 f'{type(backward_prefetch)}')
-        super().__init__(module, process_group, cpu_offload,
-                         fsdp_auto_wrap_policy, backward_prefetch)
+        super().__init__(
+            module=module,
+            process_group=process_group,
+            auto_wrap_policy=fsdp_auto_wrap_policy,
+            cpu_offload=cpu_offload,
+            backward_prefetch=backward_prefetch,
+            **kwargs)

     def train_step(self, data: dict,
                    optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
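
The point of the change: the old call forwarded the arguments to torch's FullyShardedDataParallel positionally, but FSDP's constructor does not take them in that order (in recent PyTorch versions the third positional slot after module and process_group is sharding_strategy, not cpu_offload), so cpu_offload, the auto-wrap policy and backward_prefetch could be bound to the wrong parameters. The fixed call names every argument and also forwards **kwargs so extra FSDP options reach the parent class. Below is a minimal illustrative sketch (not part of the commit) of wrapping a toy module the same keyword-argument way; it assumes a recent PyTorch with torch.distributed.fsdp and an already-initialized process group (e.g. launched via torchrun).

    # Sketch only: mirrors the corrected keyword-argument call style.
    # Assumes torch.distributed has been initialized before this runs.
    import torch.nn as nn
    from torch.distributed.fsdp import BackwardPrefetch, CPUOffload
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

    # Keyword arguments make the binding explicit; passing these values
    # positionally would map them onto FSDP's parameters in declaration
    # order, which is exactly the mix-up this commit fixes.
    fsdp_model = FSDP(
        module=model,
        process_group=None,  # use the default process group
        cpu_offload=CPUOffload(offload_params=False),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    )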