diff --git a/examples/distributed_training.py b/examples/distributed_training.py
index 6910c6dd6945dcddca81a2263ca04bd287234287..7030a9f4d828c0b665a6cf7a8b05cd3705d54394 100644
--- a/examples/distributed_training.py
+++ b/examples/distributed_training.py
@@ -5,7 +5,6 @@ import torch.nn.functional as F
 import torchvision
 import torchvision.transforms as transforms
 from torch.optim import SGD
-from torch.utils.data import DataLoader
 
 from mmengine.evaluator import BaseMetric
 from mmengine.model import BaseModel
@@ -57,29 +56,33 @@ def parse_args():
 def main():
     args = parse_args()
     norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
-    train_dataloader = DataLoader(
+    train_set = torchvision.datasets.CIFAR10(
+        'data/cifar10',
+        train=True,
+        download=True,
+        transform=transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(**norm_cfg)
+        ]))
+    valid_set = torchvision.datasets.CIFAR10(
+        'data/cifar10',
+        train=False,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(),
+             transforms.Normalize(**norm_cfg)]))
+    train_dataloader = dict(
         batch_size=32,
-        shuffle=True,
-        dataset=torchvision.datasets.CIFAR10(
-            'data/cifar10',
-            train=True,
-            download=True,
-            transform=transforms.Compose([
-                transforms.RandomCrop(32, padding=4),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                transforms.Normalize(**norm_cfg)
-            ])))
-    val_dataloader = DataLoader(
+        dataset=train_set,
+        sampler=dict(type='DefaultSampler', shuffle=True),
+        collate_fn=dict(type='default_collate'))
+    val_dataloader = dict(
         batch_size=32,
-        shuffle=False,
-        dataset=torchvision.datasets.CIFAR10(
-            'data/cifar10',
-            train=False,
-            download=True,
-            transform=transforms.Compose(
-                [transforms.ToTensor(),
-                 transforms.Normalize(**norm_cfg)])))
+        dataset=valid_set,
+        sampler=dict(type='DefaultSampler', shuffle=False),
+        collate_fn=dict(type='default_collate'))
     runner = Runner(
         model=MMResNet50(),
         work_dir='./work_dir',