diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py
index f1aedde257853fec2c47fe60db78a34329561f45..8b8203fbf921026cd163a33d50ead3a7f2a78e9a 100644
--- a/mmengine/model/__init__.py
+++ b/mmengine/model/__init__.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.utils.parrots_wrapper import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
 from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                              StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
@@ -15,3 +17,7 @@ __all__ = [
     'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
     'Sequential'
 ]
+
+if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
+    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
+    __all__.append('MMFullyShardedDataParallel')
diff --git a/mmengine/model/wrappers/__init__.py b/mmengine/model/wrappers/__init__.py
index d6ece71384bda27e06d13c43d7b69d889dc1328a..78a3ced16b08de25731abb8c8c95a5681410b553 100644
--- a/mmengine/model/wrappers/__init__.py
+++ b/mmengine/model/wrappers/__init__.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.utils.parrots_wrapper import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
 from .distributed import MMDistributedDataParallel
 from .seperate_distributed import MMSeparateDistributedDataParallel
 from .utils import is_model_wrapper
@@ -7,3 +9,8 @@ __all__ = [
     'MMDistributedDataParallel', 'is_model_wrapper',
     'MMSeparateDistributedDataParallel'
 ]
+
+if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
+    from .fully_sharded_distributed import \
+        MMFullyShardedDataParallel  # noqa:F401
+    __all__.append('MMFullyShardedDataParallel')
diff --git a/mmengine/model/wrappers/fully_sharded_distributed.py b/mmengine/model/wrappers/fully_sharded_distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..74e4330396afdb2fbafd165db3acc162b219ec99
--- /dev/null
+++ b/mmengine/model/wrappers/fully_sharded_distributed.py
@@ -0,0 +1,209 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch, CPUOffload, FullyShardedDataParallel)
+
+from mmengine.data import BaseDataElement
+from mmengine.optim import OptimWrapper
+from mmengine.registry import MODEL_WRAPPERS, Registry
+
+# Support customized FSDP wrap policies.
+FSDP_WRAP_POLICYS = Registry('fsdp wrap policy')
+
+
+@MODEL_WRAPPERS.register_module()
+class MMFullyShardedDataParallel(FullyShardedDataParallel):
+    """A wrapper for sharding Module parameters across data parallel workers.
+
+    Different from FullyShardedDataParallel, MMFullyShardedDataParallel
+    implements three methods :meth:`train_step`, :meth:`val_step` and
+    :meth:`test_step`, which will be called by ``train_loop``, ``val_loop``
+    and ``test_loop``.
+
+    - ``train_step``: Called by ``runner.train_loop``. It implements the
+      default model forward, gradient back-propagation and parameter
+      updating logic.
+
+    - ``val_step``: Called by ``runner.val_loop`` to get the inference
+      results. Note that since MMFullyShardedDataParallel wraps the model
+      recursively, simply reusing ``BaseModel.val_step`` here could be
+      problematic. To avoid that, ``val_step`` first calls the methods of
+      :obj:`BaseModel` to pre-process the data, and then uses
+      ``FullyShardedDataParallel.forward`` to get the results.
+
+    - ``test_step``: Called by ``runner.test_loop`` to get the inference
+      results. Its logic is equivalent to that of ``val_step``.
+
+    Args:
+        module (nn.Module): Module to be wrapped with FSDP.
+        process_group (Optional[ProcessGroup]): Process group for sharding.
+        cpu_offload (Optional[Union[bool, CPUOffload]]):
+            CPU offloading config.
+            Different from FullyShardedDataParallel, since it can be set by
+            users' pre-defined config in MMEngine, its type is expected to
+            be `None`, `bool` or `CPUOffload`.
+
+            Currently, only parameter and gradient CPU offload is supported.
+            It can be enabled via passing in
+            ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
+            currently implicitly enables gradient offloading to CPU in order
+            for params and grads to be on the same device to work with the
+            optimizer. This API is subject to change. Default is ``None``,
+            in which case there will be no offloading.
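+
+            For instance (a minimal sketch; ``model`` stands for any
+            ``BaseModel`` instance and is only illustrative here), passing
+            ``cpu_offload=True`` is converted internally to
+            ``CPUOffload(offload_params=True)``::
+
+                >>> fsdp_model = MMFullyShardedDataParallel(
+                ...     module=model.cuda(), cpu_offload=True)
+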
+        fsdp_auto_wrap_policy (Optional[Union[str, Callable]]):
+            A policy to recursively wrap layers with FSDP.
+            Different from FullyShardedDataParallel, since it can be set by
+            users' pre-defined config in MMEngine, its type is expected to
+            be `None`, `str` or `Callable`. If it is a `str`,
+            MMFullyShardedDataParallel looks up the policy of that name in
+            the ``FSDP_WRAP_POLICYS`` registry, and the resulting callable
+            is passed to FullyShardedDataParallel to initialize the model.
+
+            Note that this policy currently only applies to child modules of
+            the passed-in module. The remaining modules are always wrapped
+            in the returned FSDP root instance.
+            ``default_auto_wrap_policy`` in ``torch.distributed.fsdp.wrap``
+            is an example of an ``fsdp_auto_wrap_policy`` callable; it wraps
+            layers with more than 100M parameters. Users can supply a
+            customized ``fsdp_auto_wrap_policy`` callable that accepts the
+            following arguments: ``module: nn.Module``, ``recurse: bool``
+            and ``unwrapped_params: int``. Extra customized arguments can be
+            added to the customized callable as well.
+
+            Example::
+
+                >>> def custom_auto_wrap_policy(
+                >>>     module: nn.Module,
+                >>>     recurse: bool,
+                >>>     unwrapped_params: int,
+                >>>     # These are customizable for this policy function.
+                >>>     min_num_params: int = int(1e8),
+                >>> ) -> bool:
+                >>>     return unwrapped_params >= min_num_params
+
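+            A customized policy can also be registered in the
+            ``FSDP_WRAP_POLICYS`` registry and then referenced by its name.
+            A minimal sketch, reusing the purely illustrative
+            ``custom_auto_wrap_policy`` above (``model`` is illustrative,
+            as above)::
+
+                >>> from mmengine.model.wrappers.fully_sharded_distributed \
+                ...     import FSDP_WRAP_POLICYS
+                >>> FSDP_WRAP_POLICYS.register_module(
+                ...     module=custom_auto_wrap_policy)
+                >>> fsdp_model = MMFullyShardedDataParallel(
+                ...     module=model.cuda(),
+                ...     fsdp_auto_wrap_policy='custom_auto_wrap_policy')
+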
+        backward_prefetch (Optional[Union[str, BackwardPrefetch]]):
+            Backward prefetch config.
+            Different from FullyShardedDataParallel, since it can be set by
+            users' pre-defined config in MMEngine, its type is expected to
+            be `None`, `str` or `BackwardPrefetch`. If it is a `str`, it
+            must be either ``'pre'`` or ``'post'``, which map to
+            ``BackwardPrefetch.BACKWARD_PRE`` and
+            ``BackwardPrefetch.BACKWARD_POST`` respectively.
+
+            This is an experimental feature that is subject to change in
+            the near future. It allows users to enable two different
+            backward prefetch algorithms to help overlap backward
+            communication and computation.
+            The pros and cons of each algorithm are explained in the
+            ``BackwardPrefetch`` class.
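+
+    Example:
+        A minimal usage sketch mirroring the unit tests. It assumes a
+        distributed process group has already been initialized and that
+        ``ToyModel`` is a user-defined ``BaseModel`` subclass whose
+        ``data_preprocessor`` accepts the data format below::
+
+            >>> import torch
+            >>> from torch.optim import SGD
+            >>> from mmengine.model import MMFullyShardedDataParallel
+            >>> from mmengine.optim import OptimWrapper
+            >>> fsdp_model = MMFullyShardedDataParallel(
+            ...     module=ToyModel().cuda())
+            >>> optim_wrapper = OptimWrapper(
+            ...     SGD(fsdp_model.parameters(), lr=0.01))
+            >>> data = [dict(inputs=torch.randn(3, 1, 1), data_sample=None)]
+            >>> log_vars = fsdp_model.train_step(data, optim_wrapper)
+            >>> predictions = fsdp_model.val_step(data)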
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        process_group: Optional[ProcessGroup] = None,
+        cpu_offload: Optional[Union[bool, CPUOffload]] = None,
+        fsdp_auto_wrap_policy: Optional[Union[str, Callable]] = None,
+        backward_prefetch: Optional[Union[str, BackwardPrefetch]] = None,
+    ):
+
+        if cpu_offload is not None:
+            if isinstance(cpu_offload, bool):
+                cpu_offload = CPUOffload(offload_params=cpu_offload)
+            elif not isinstance(cpu_offload, CPUOffload):
+                raise TypeError(
+                    '`cpu_offload` should be `None`, `bool`'
+                    f'or `CPUOffload`, but has type {type(cpu_offload)}')
+
+        if fsdp_auto_wrap_policy is not None:
+            if isinstance(fsdp_auto_wrap_policy, str):
+                assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICYS, \
+                    '`FSDP_WRAP_POLICYS` has no ' \
+                    f'function {fsdp_auto_wrap_policy}'
+                fsdp_auto_wrap_policy = FSDP_WRAP_POLICYS.get(  # type: ignore
+                    fsdp_auto_wrap_policy)
+                if not isinstance(fsdp_auto_wrap_policy,
+                                  Callable):  # type: ignore
+                    raise TypeError(
+                        'Registered `fsdp_auto_wrap_policy` needs to be '
+                        '`Callable`, but has type '
+                        f'{type(fsdp_auto_wrap_policy)}')
+            elif not isinstance(fsdp_auto_wrap_policy,
+                                Callable):  # type: ignore
+                raise TypeError(
+                    '`fsdp_auto_wrap_policy` should be `None`, `str` '
+                    'or `Callable`, but has type '
+                    f'{type(fsdp_auto_wrap_policy)}')
+
+        if backward_prefetch is not None:
+            if isinstance(backward_prefetch, str):
+                assert backward_prefetch in ['pre', 'post'], \
+                    '`backward_prefetch` should be either `pre` or `post`,' \
+                    f' but got {backward_prefetch}'
+                if backward_prefetch == 'pre':
+                    backward_prefetch = BackwardPrefetch.BACKWARD_PRE
+                else:
+                    backward_prefetch = BackwardPrefetch.BACKWARD_POST
+            elif not isinstance(backward_prefetch, BackwardPrefetch):
+                raise TypeError('`backward_prefetch` should be `None`, `str` '
+                                'or `BackwardPrefetch`, but has type '
+                                f'{type(backward_prefetch)}')
+
+        super().__init__(module, process_group, cpu_offload,
+                         fsdp_auto_wrap_policy, backward_prefetch)
+
+    def train_step(self, data: List[dict],
+                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
+        """Interface for model forward, backward and parameters updating during
+        training process.
+
+        :meth:`train_step` will perform the following steps in order:
+
+        - If :attr:`module` defines the preprocess method,
+            call ``module.preprocess`` to pre-processing data.
+        - Call ``module.forward(**data)`` and get losses.
+        - Parse losses.
+        - Call ``optim_wrapper.optimizer_step`` to update parameters.
+        - Return log messages of losses.
+
+        Args:
+            data (List[dict]): Data sampled by dataloader.
+            optim_wrapper (OptimWrapper): A wrapper of optimizer to
+                update parameters.
+
+        Returns:
+            Dict[str, torch.Tensor]: A ``dict`` of tensors for logging.
+        """
+        # enable automatic mixed precision training context.
+        with optim_wrapper.optim_context(self):
+            batch_inputs, data_samples = self.module.data_preprocessor(
+                data, training=True)
+            losses = self(batch_inputs, data_samples, mode='loss')
+        parsed_loss, log_vars = self.module.parse_losses(losses)
+        optim_wrapper.update_params(parsed_loss)
+        return log_vars
+
+    def val_step(self, data: List[dict]) -> List[BaseDataElement]:
+        """Gets the prediction of module during validation process.
+
+        Args:
+            data (List[dict]): Data sampled by dataloader.
+
+        Returns:
+            List[BaseDataElement]: The predictions of the given data.
+        """
+        inputs, data_sample = self.module.data_preprocessor(data, False)
+        return self(inputs, data_sample, mode='predict')
+
+    def test_step(self, data: List[dict]) -> List[BaseDataElement]:
+        """Gets the predictions of module during testing process.
+
+        Args:
+            data (List[dict]): Data sampled by dataloader.
+
+        Returns:
+            List[BaseDataElement]: The predictions of the given data.
+        """
+        inputs, data_sample = self.module.data_preprocessor(data, False)
+        return self(inputs, data_sample, mode='predict')
diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py
index 0efe338be40e1fcd0b4486ef60d28f61aecc649c..90ae3643778ad20b09a3ca3f19041527f1d4dd0c 100644
--- a/tests/test_model/test_wrappers/test_model_wrapper.py
+++ b/tests/test_model/test_wrappers/test_model_wrapper.py
@@ -14,6 +14,11 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
 from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
+from mmengine.utils.parrots_wrapper import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
+
+if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
+    from mmengine.model import MMFullyShardedDataParallel
 
 
 class ToyModel(BaseModel):
@@ -177,3 +182,58 @@ class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
         os.environ['RANK'] = str(rank)
         torch_dist.init_process_group(
             backend='gloo', rank=rank, world_size=world_size)
+
+
+@unittest.skipIf(
+    torch.cuda.device_count() < 2, reason='need 2 GPUs to test FSDP')
+@unittest.skipIf(
+    digit_version(TORCH_VERSION) < digit_version('1.11.0'),
+    reason='FSDP requires PyTorch 1.11 or higher')
+class TestMMFullyShardedDataParallel(MultiProcessTestCase):
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29520'
+        os.environ['RANK'] = str(rank)
+
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        torch_dist.init_process_group(
+            backend='nccl', rank=rank, world_size=world_size)
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def test_train_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        # Test when `optim_wrapper` is an instance of `OptimWrapper`.
+        model = ToyModel()
+        fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
+        optimizer = SGD(fsdp_model.parameters(), lr=0)
+        optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
+        inputs = torch.randn(3, 1, 1) * self.rank * 255
+        data = dict(inputs=inputs, data_sample=MagicMock())
+        fsdp_model.train()
+        self.assertTrue(fsdp_model.training)
+        fsdp_model.train_step([data], optim_wrapper=optim_wrapper)
+
+    def test_val_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ToyModel()
+        fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
+        inputs = torch.randn(3, 1, 1) * self.rank * 255
+        data = dict(inputs=inputs, data_sample=MagicMock())
+        # Test get predictions.
+        predictions = fsdp_model.val_step([data])
+        self.assertIsInstance(predictions, torch.Tensor)
+
+    def test_test_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ToyModel()
+        fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
+        inputs = torch.randn(3, 1, 1) * self.rank * 255
+        data = dict(inputs=inputs, data_sample=MagicMock())
+        predictions = fsdp_model.test_step([data])
+        self.assertIsInstance(predictions, torch.Tensor)