diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py
index f1aedde257853fec2c47fe60db78a34329561f45..8b8203fbf921026cd163a33d50ead3a7f2a78e9a 100644
--- a/mmengine/model/__init__.py
+++ b/mmengine/model/__init__.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.utils.parrots_wrapper import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
 from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                              StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
@@ -15,3 +17,7 @@ __all__ = [
     'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
     'Sequential'
 ]
+
+if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
+    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
+    __all__.append('MMFullyShardedDataParallel')
diff --git a/mmengine/model/wrappers/__init__.py b/mmengine/model/wrappers/__init__.py
index d6ece71384bda27e06d13c43d7b69d889dc1328a..78a3ced16b08de25731abb8c8c95a5681410b553 100644
--- a/mmengine/model/wrappers/__init__.py
+++ b/mmengine/model/wrappers/__init__.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.utils.parrots_wrapper import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
 from .distributed import MMDistributedDataParallel
 from .seperate_distributed import MMSeparateDistributedDataParallel
 from .utils import is_model_wrapper
@@ -7,3 +9,8 @@ __all__ = [
     'MMDistributedDataParallel', 'is_model_wrapper',
     'MMSeparateDistributedDataParallel'
 ]
+
+if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
+    from .fully_sharded_distributed import \
+        MMFullyShardedDataParallel  # noqa:F401
+    __all__.append('MMFullyShardedDataParallel')
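Both packages gate the export behind the same PyTorch version check, so downstream code that optionally uses FSDP can guard its own import in the same way. A minimal sketch of that pattern, not part of this diff; the helper name `supports_fsdp` is illustrative:

    from mmengine.utils.parrots_wrapper import TORCH_VERSION
    from mmengine.utils.version_utils import digit_version


    def supports_fsdp() -> bool:
        # MMFullyShardedDataParallel is only exported on PyTorch >= 1.11.0.
        return digit_version(TORCH_VERSION) >= digit_version('1.11.0')


    if supports_fsdp():
        from mmengine.model import MMFullyShardedDataParallel  # noqa: F401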
+ """ + + def __init__( + self, + module: nn.Module, + process_group: Optional[ProcessGroup] = None, + cpu_offload: Optional[Union[bool, CPUOffload]] = None, + fsdp_auto_wrap_policy: Optional[Union[str, Callable]] = None, + backward_prefetch: Optional[Union[str, BackwardPrefetch]] = None, + ): + + if cpu_offload is not None: + if isinstance(cpu_offload, bool): + cpu_offload = CPUOffload(offload_params=cpu_offload) + elif not isinstance(cpu_offload, CPUOffload): + raise TypeError( + '`cpu_offload` should be `None`, `bool`' + f'or `CPUOffload`, but has type {type(cpu_offload)}') + + if fsdp_auto_wrap_policy is not None: + if isinstance(fsdp_auto_wrap_policy, str): + assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICYS, \ + '`FSDP_WRAP_POLICYS` has no ' \ + f'function {fsdp_auto_wrap_policy}' + fsdp_auto_wrap_policy = FSDP_WRAP_POLICYS.get( # type: ignore + fsdp_auto_wrap_policy) + if not isinstance(fsdp_auto_wrap_policy, + Callable): # type: ignore + raise TypeError( + 'Registered `fsdp_auto_wrap_policy` needs to be ' + '`Callable`, but has type ' + f'{type(fsdp_auto_wrap_policy)}') + elif not isinstance(fsdp_auto_wrap_policy, + Callable): # type: ignore + raise TypeError( + '`fsdp_auto_wrap_policy` should be `None`, `str` ' + 'or `Callable`, but has type ' + f'{type(fsdp_auto_wrap_policy)}') + + if backward_prefetch is not None: + if isinstance(backward_prefetch, str): + assert backward_prefetch in ['pre', 'post'], \ + '`backward_prefetch` should be either `pre` or `post`,' \ + f' but get {backward_prefetch}' + if backward_prefetch == 'pre': + backward_prefetch = BackwardPrefetch.BACKWARD_PRE + else: + backward_prefetch = BackwardPrefetch.BACKWARD_POST + elif not isinstance(backward_prefetch, BackwardPrefetch): + raise TypeError('`backward_prefetch` should be `None`, `str` ' + 'or `BackwardPrefetch`, but has type ' + f'{type(backward_prefetch)}') + + super().__init__(module, process_group, cpu_offload, + fsdp_auto_wrap_policy, backward_prefetch) + + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + """Interface for model forward, backward and parameters updating during + training process. + + :meth:`train_step` will perform the following steps in order: + + - If :attr:`module` defines the preprocess method, + call ``module.preprocess`` to pre-processing data. + - Call ``module.forward(**data)`` and get losses. + - Parse losses. + - Call ``optim_wrapper.optimizer_step`` to update parameters. + - Return log messages of losses. + + Args: + data (List[dict]): Data sampled by dataloader. + optim_wrapper (OptimWrapper): A wrapper of optimizer to + update parameters. + + Returns: + Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. + """ + # enable automatic mixed precision training context. + with optim_wrapper.optim_context(self): + batch_inputs, data_samples = self.module.data_preprocessor( + data, training=True) + losses = self(batch_inputs, data_samples, mode='loss') + parsed_loss, log_vars = self.module.parse_losses(losses) + optim_wrapper.update_params(parsed_loss) + return log_vars + + def val_step(self, data: List[dict]) -> List[BaseDataElement]: + """Gets the prediction of module during validation process. + + Args: + data (List[dict]): Data sampled by dataloader. + + Returns: + List[BaseDataElement] or dict: The predictions of given data. 
+ """ + inputs, data_sample = self.module.data_preprocessor(data, False) + return self(inputs, data_sample, mode='predict') + + def test_step(self, data: List[dict]) -> List[BaseDataElement]: + """Gets the predictions of module during testing process. + + Args: + data: Data sampled by dataloader. + + Returns: + List[BaseDataElement]: The predictions of given data. + """ + inputs, data_sample = self.module.data_preprocessor(data, False) + return self(inputs, data_sample, mode='predict') diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index 0efe338be40e1fcd0b4486ef60d28f61aecc649c..90ae3643778ad20b09a3ca3f19041527f1d4dd0c 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -14,6 +14,11 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel, from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict from mmengine.testing import assert_allclose from mmengine.testing._internal import MultiProcessTestCase +from mmengine.utils.parrots_wrapper import TORCH_VERSION +from mmengine.utils.version_utils import digit_version + +if digit_version(TORCH_VERSION) >= digit_version('1.11.0'): + from mmengine.model import MMFullyShardedDataParallel class ToyModel(BaseModel): @@ -177,3 +182,58 @@ class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel): os.environ['RANK'] = str(rank) torch_dist.init_process_group( backend='gloo', rank=rank, world_size=world_size) + + +@unittest.skipIf( + torch.cuda.device_count() < 2, reason='need 2 gpu to test fsdp') +@unittest.skipIf( + digit_version(TORCH_VERSION) < digit_version('1.11.0'), + reason='fsdp needs Pytorch 1.11 or higher') +class TestMMFullyShardedDataParallel(MultiProcessTestCase): + + def _init_dist_env(self, rank, world_size): + """Initialize the distributed environment.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29520' + os.environ['RANK'] = str(rank) + + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + torch_dist.init_process_group( + backend='nccl', rank=rank, world_size=world_size) + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def test_train_step(self): + self._init_dist_env(self.rank, self.world_size) + # Test `optim_wrapper` is a instance of `OptimWrapper` + model = ToyModel() + fsdp_model = MMFullyShardedDataParallel(module=model.cuda()) + optimizer = SGD(fsdp_model.parameters(), lr=0) + optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1) + inputs = torch.randn(3, 1, 1) * self.rank * 255 + data = dict(inputs=inputs, data_sample=MagicMock()) + fsdp_model.train() + self.assertTrue(fsdp_model.training) + fsdp_model.train_step([data], optim_wrapper=optim_wrapper) + + def test_val_step(self): + self._init_dist_env(self.rank, self.world_size) + model = ToyModel() + fsdp_model = MMFullyShardedDataParallel(module=model.cuda()) + inputs = torch.randn(3, 1, 1) * self.rank * 255 + data = dict(inputs=inputs, data_sample=MagicMock()) + # Test get predictions. 
+ predictions = fsdp_model.val_step([data]) + self.assertIsInstance(predictions, torch.Tensor) + + def test_test_step(self): + self._init_dist_env(self.rank, self.world_size) + model = ToyModel() + fsdp_model = MMFullyShardedDataParallel(module=model.cuda()) + inputs = torch.randn(3, 1, 1) * self.rank * 255 + data = dict(inputs=inputs, data_sample=MagicMock()) + predictions = fsdp_model.test_step([data]) + self.assertIsInstance(predictions, torch.Tensor)
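Because the wrapper registers itself in ``MODEL_WRAPPERS``, it can also be built from a config dict instead of being constructed directly, which is how a runner would typically select it. A minimal sketch, not part of this diff, assuming PyTorch >= 1.11, a CUDA device and an already initialized process group; the toy module, config values and ``default_args`` usage shown are illustrative:

    import torch.nn as nn

    from mmengine.registry import MODEL_WRAPPERS

    model = nn.Linear(8, 8).cuda()
    wrapper_cfg = dict(
        type='MMFullyShardedDataParallel',
        cpu_offload=True,          # bool is wrapped into CPUOffload
        backward_prefetch='post')  # str is mapped to BackwardPrefetch.BACKWARD_POST
    # ``module`` is supplied at build time, mirroring how a runner wraps the
    # model it has just built.
    fsdp_model = MODEL_WRAPPERS.build(
        wrapper_cfg, default_args=dict(module=model))
    # fsdp_model now exposes train_step / val_step / test_step for the
    # runner's train, val and test loops.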