From 46209b8cbf9a9e263020474c0ae0c6a0e99e58b3 Mon Sep 17 00:00:00 2001
From: GPH <gphsmail@163.com>
Date: Wed, 9 Nov 2022 22:04:08 +0800
Subject: [PATCH] [Fix] Fix examples/distributed_training.py does not work in
 DDP (#700)

* Update distributed_training.py

Better example for DDP training

* Update distributed_training.py

* Update distributed_training.py

update according to reviwer's suggesstions.

* Update distributed_training.py

* Update distributed_training.py

The previous update copy data from main branch, its a mistake.
This update fix this mistake and the code is tested.
---
 examples/distributed_training.py | 47 +++++++++++++++++---------------
 1 file changed, 25 insertions(+), 22 deletions(-)

diff --git a/examples/distributed_training.py b/examples/distributed_training.py
index 6910c6dd..7030a9f4 100644
--- a/examples/distributed_training.py
+++ b/examples/distributed_training.py
@@ -5,7 +5,6 @@ import torch.nn.functional as F
 import torchvision
 import torchvision.transforms as transforms
 from torch.optim import SGD
-from torch.utils.data import DataLoader
 
 from mmengine.evaluator import BaseMetric
 from mmengine.model import BaseModel
@@ -57,29 +56,33 @@ def parse_args():
 def main():
     args = parse_args()
     norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
-    train_dataloader = DataLoader(
+    train_set = torchvision.datasets.CIFAR10(
+        'data/cifar10',
+        train=True,
+        download=True,
+        transform=transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(**norm_cfg)
+        ]))
+    valid_set = torchvision.datasets.CIFAR10(
+        'data/cifar10',
+        train=False,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(),
+             transforms.Normalize(**norm_cfg)]))
+    train_dataloader = dict(
         batch_size=32,
-        shuffle=True,
-        dataset=torchvision.datasets.CIFAR10(
-            'data/cifar10',
-            train=True,
-            download=True,
-            transform=transforms.Compose([
-                transforms.RandomCrop(32, padding=4),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                transforms.Normalize(**norm_cfg)
-            ])))
-    val_dataloader = DataLoader(
+        dataset=train_set,
+        sampler=dict(type='DefaultSampler', shuffle=True),
+        collate_fn=dict(type='default_collate'))
+    val_dataloader = dict(
         batch_size=32,
-        shuffle=False,
-        dataset=torchvision.datasets.CIFAR10(
-            'data/cifar10',
-            train=False,
-            download=True,
-            transform=transforms.Compose(
-                [transforms.ToTensor(),
-                 transforms.Normalize(**norm_cfg)])))
+        dataset=valid_set,
+        sampler=dict(type='DefaultSampler', shuffle=False),
+        collate_fn=dict(type='default_collate'))
     runner = Runner(
         model=MMResNet50(),
         work_dir='./work_dir',
-- 
GitLab