Skip to content
Snippets Groups Projects
Unverified Commit 46209b8c authored by GPH's avatar GPH Committed by GitHub
Browse files

[Fix] Fix examples/distributed_training.py does not work in DDP (#700)

* Update distributed_training.py

Better example for DDP training

* Update distributed_training.py

* Update distributed_training.py

update according to reviwer's suggesstions.

* Update distributed_training.py

* Update distributed_training.py

The previous update copy data from main branch, its a mistake.
This update fix this mistake and the code is tested.
parent b35196ac
No related branches found
No related tags found
No related merge requests found
...@@ -5,7 +5,6 @@ import torch.nn.functional as F ...@@ -5,7 +5,6 @@ import torch.nn.functional as F
import torchvision import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
from torch.optim import SGD from torch.optim import SGD
from torch.utils.data import DataLoader
from mmengine.evaluator import BaseMetric from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel from mmengine.model import BaseModel
...@@ -57,29 +56,33 @@ def parse_args(): ...@@ -57,29 +56,33 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201]) norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader( train_set = torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
]))
valid_set = torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(**norm_cfg)]))
train_dataloader = dict(
batch_size=32, batch_size=32,
shuffle=True, dataset=train_set,
dataset=torchvision.datasets.CIFAR10( sampler=dict(type='DefaultSampler', shuffle=True),
'data/cifar10', collate_fn=dict(type='default_collate'))
train=True, val_dataloader = dict(
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
val_dataloader = DataLoader(
batch_size=32, batch_size=32,
shuffle=False, dataset=valid_set,
dataset=torchvision.datasets.CIFAR10( sampler=dict(type='DefaultSampler', shuffle=False),
'data/cifar10', collate_fn=dict(type='default_collate'))
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(**norm_cfg)])))
runner = Runner( runner = Runner(
model=MMResNet50(), model=MMResNet50(),
work_dir='./work_dir', work_dir='./work_dir',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment