# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import Any, Mapping, Sequence

import numpy as np
import torch
from torch.utils.data._utils.collate import \
    default_collate as torch_default_collate

from mmengine.registry import Registry
from mmengine.structures import BaseDataElement
COLLATE_FUNCTIONS = Registry('Collate Functions')
def worker_init_fn(worker_id: int,
num_workers: int,
rank: int,
seed: int,
disable_subprocess_warning: bool = False) -> None:
"""This function will be called on each worker subprocess after seeding and
before data loading.
Args:
worker_id (int): Worker id in [0, num_workers - 1].
num_workers (int): How many subprocesses to use for data loading.
rank (int): Rank of process in distributed environment. If in
non-distributed environment, it is a constant number `0`.
seed (int): Random seed.
"""
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
if disable_subprocess_warning and worker_id != 0:
warnings.simplefilter('ignore')
@COLLATE_FUNCTIONS.register_module()
def pseudo_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of which
    type consistent with the type of each data_itement in ``data_batch``.

    The default behavior of dataloader is to merge a list of samples to form
    a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate``
    will not stack tensors to batch tensors, and convert int, float, ndarray to
    tensors.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Args:
        data_batch (Sequence): Batch of data from dataloader.

    Returns:
        Any: Transversed Data in the same format as the data_itement of
        ``data_batch``.
    """
    data_item = data_batch[0]
    data_item_type = type(data_item)
    if isinstance(data_item, (str, bytes)):
        # Strings/bytes are sequences too, but must stay whole.
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named tuple: rebuild the same namedtuple type field by field.
        return data_item_type(*(pseudo_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the data_itements in batch have
        # consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(data_item) == data_item_size for data_item in it):
            raise RuntimeError(
                'each data_itement in list of batch should be of equal size')
        transposed = list(zip(*data_batch))

        if isinstance(data_item, tuple):
            return [pseudo_collate(samples)
                    for samples in transposed]  # Compat with Pytorch.
        else:
            try:
                # Preserve the original sequence type (e.g. list subclass).
                return data_item_type(
                    [pseudo_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)`
                # (e.g., `range`).
                return [pseudo_collate(samples) for samples in transposed]
    elif isinstance(data_item, Mapping):
        # Collate each key independently; keys come from the first item.
        return data_item_type({
            key: pseudo_collate([d[key] for d in data_batch])
            for key in data_item
        })
    else:
        # Scalars, tensors, arbitrary objects: returned untouched as a batch.
        return data_batch
@COLLATE_FUNCTIONS.register_module()
def default_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of which
    type consistent with the type of each data_itement in ``data_batch``.

    Different from :func:`pseudo_collate`, ``default_collate`` will stack
    tensor contained in ``data_batch`` into a batched tensor with the
    first dimension batch size, and then move input tensor to the target
    device.

    Different from ``default_collate`` in pytorch, ``default_collate`` will
    not process ``BaseDataElement``.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Note:
        ``default_collate`` only accept input tensor with the same shape.

    Args:
        data_batch (Sequence): Data sampled from dataset.

    Returns:
        Any: Data in the same format as the data_itement of ``data_batch``,
        of which tensors have been stacked, and ndarray, int, float have
        been converted to tensors.
    """
    data_item = data_batch[0]
    data_item_type = type(data_item)

    if isinstance(data_item, (BaseDataElement, str, bytes)):
        # BaseDataElement and text types are kept as an un-stacked batch.
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named_tuple: rebuild the same namedtuple type field by field.
        return data_item_type(*(default_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the data_itements in batch have
        # consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(data_item) == data_item_size for data_item in it):
            raise RuntimeError(
                'each data_itement in list of batch should be of equal size')
        transposed = list(zip(*data_batch))

        if isinstance(data_item, tuple):
            return [default_collate(samples)
                    for samples in transposed]  # Compat with Pytorch.
        else:
            try:
                # Preserve the original sequence type (e.g. list subclass).
                return data_item_type(
                    [default_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)`
                # (e.g., `range`).
                return [default_collate(samples) for samples in transposed]
    elif isinstance(data_item, Mapping):
        # Collate each key independently; keys come from the first item.
        return data_item_type({
            key: default_collate([d[key] for d in data_batch])
            for key in data_item
        })
    else:
        # Everything else (tensors, ndarrays, numbers) is delegated to
        # PyTorch's default_collate, which stacks / converts to tensors.
        return torch_default_collate(data_batch)