diff --git a/examples/distributed_training.py b/examples/distributed_training.py index 6910c6dd6945dcddca81a2263ca04bd287234287..7030a9f4d828c0b665a6cf7a8b05cd3705d54394 100644 --- a/examples/distributed_training.py +++ b/examples/distributed_training.py @@ -5,7 +5,6 @@ import torch.nn.functional as F import torchvision import torchvision.transforms as transforms from torch.optim import SGD -from torch.utils.data import DataLoader from mmengine.evaluator import BaseMetric from mmengine.model import BaseModel @@ -57,29 +56,33 @@ def parse_args(): def main(): args = parse_args() norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201]) - train_dataloader = DataLoader( + train_set = torchvision.datasets.CIFAR10( + 'data/cifar10', + train=True, + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(**norm_cfg) + ])) + valid_set = torchvision.datasets.CIFAR10( + 'data/cifar10', + train=False, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize(**norm_cfg)])) + train_dataloader = dict( batch_size=32, - shuffle=True, - dataset=torchvision.datasets.CIFAR10( - 'data/cifar10', - train=True, - download=True, - transform=transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(**norm_cfg) - ]))) - val_dataloader = DataLoader( + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate')) + val_dataloader = dict( batch_size=32, - shuffle=False, - dataset=torchvision.datasets.CIFAR10( - 'data/cifar10', - train=False, - download=True, - transform=transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize(**norm_cfg)]))) + dataset=valid_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=dict(type='default_collate')) runner = Runner( model=MMResNet50(), work_dir='./work_dir',