Newer
Older
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Sequence, Union
import numpy as np
import torch
from .base_data_element import BaseDataElement
class PixelData(BaseDataElement):
    """Data structure for pixel-level annotations or predictions.

    All data items in ``data_fields`` of ``PixelData`` meet the following
    requirements:

    - They all have 3 dimensions in orders of channel, height, and width.
    - They should have the same height and width.

    Examples:
        >>> import random
        >>> metainfo = dict(
        ...     img_id=random.randint(0, 100),
        ...     img_shape=(random.randint(400, 600), random.randint(400, 600)))
        >>> image = np.random.randint(0, 255, (4, 20, 40))
        >>> featmap = torch.randint(0, 255, (10, 20, 40))
        >>> pixel_data = PixelData(metainfo=metainfo,
        ...                        image=image,
        ...                        featmap=featmap)
        >>> print(pixel_data.shape)
        (20, 40)
        >>> # slice
        >>> slice_data = pixel_data[10:20, 20:40]
        >>> assert slice_data.shape == (10, 20)
        >>> slice_data = pixel_data[10, 20]
        >>> assert slice_data.shape == (1, 1)
        >>> # set (a 2-D value gets a leading channel dimension)
        >>> pixel_data.map3 = torch.randint(0, 255, (20, 40))
        >>> assert tuple(pixel_data.map3.shape) == (1, 20, 40)
        >>> # The dimension must be 3 or 2
        >>> pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40))
        Traceback (most recent call last):
        ...
        AssertionError: The dim of value must be 2 or 3, but got 4
    """

    def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray]):
        """Set attributes of ``PixelData``.

        If the dimension of value is 2 and its shape meets the requirements,
        a leading channel dimension is added automatically so the stored
        value is always 3-dimensional.

        Args:
            name (str): The key to access the value, stored in `PixelData`.
            value (Union[torch.Tensor, np.ndarray]): The value to store in.
                The type of value must be `torch.Tensor` or `np.ndarray`,
                and its shape must meet the requirements of `PixelData`.

        Raises:
            AttributeError: If ``name`` is ``_metainfo_fields`` or
                ``_data_fields`` and that attribute already exists.
            AssertionError: If ``value`` is not a tensor/array, is not 2-D
                or 3-D, or its height and width differ from :attr:`shape`.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            # The private field registries may only be assigned once.
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f'{name} has been used as a '
                    f'private attribute, which is immutable. ')
        else:
            assert isinstance(value, (torch.Tensor, np.ndarray)), \
                f'Can not set {type(value)}, only support' \
                f' {(torch.Tensor, np.ndarray)}'
            # Check the rank before the spatial size so that a 1-D or 4-D
            # value reports the real problem instead of a size mismatch.
            assert value.ndim in (2, 3), \
                f'The dim of value must be 2 or 3, but got {value.ndim}'
            if self.shape is not None:
                assert tuple(value.shape[-2:]) == self.shape, (
                    'The height and width of '
                    f'values {tuple(value.shape[-2:])} is '
                    'not consistent with '
                    'the shape of this '
                    ':obj:`PixelData` '
                    f'{self.shape}')
            if value.ndim == 2:
                value = value[None]
                warnings.warn('The shape of value will convert from '
                              f'{value.shape[-2:]} to {value.shape}')
            super().__setattr__(name, value)

    # TODO torch.Long/bool
    def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
        """Slice every data item along height and width.

        Args:
            item (Sequence[Union[int, slice]]): A ``(height, width)`` tuple of
                indices. An ``int`` selects a single row/column but keeps
                that dimension, so every sliced value stays 3-dimensional.

        Returns:
            obj:`PixelData`: A new ``PixelData`` built with this instance's
            metainfo, holding the sliced data items.

        Raises:
            TypeError: If ``item`` is not a tuple, or an element of it is
                neither an ``int`` nor a ``slice``.
            AssertionError: If ``item`` does not have exactly 2 elements.
            IndexError: If an ``int`` index is out of range for a non-empty
                ``PixelData``.
        """
        if not isinstance(item, tuple):
            raise TypeError(
                f'Unsupported type {type(item)} for slicing PixelData')
        assert len(item) == 2, 'Only support slice height and width'

        shape = self.shape
        # The channel dimension is always kept whole.
        index: List[slice] = [slice(None, None, None)]
        for dim, single_item in enumerate(item):
            if isinstance(single_item, int):
                if shape is not None and not (
                        -shape[dim] <= single_item < shape[dim]):
                    raise IndexError(
                        f'index {single_item} is out of bounds for '
                        f'dimension {dim} with size {shape[dim]}')
                # A length-1 slice keeps the dimension (plain int indexing
                # would drop it). ``or None`` makes index -1 map to [-1:]
                # rather than the empty slice [-1:0].
                index.append(slice(single_item, single_item + 1 or None))
            elif isinstance(single_item, slice):
                index.append(single_item)
            else:
                raise TypeError(
                    'The type of element in input must be int or slice, '
                    f'but got {type(single_item)}')

        key = tuple(index)
        new_data = self.__class__(metainfo=self.metainfo)
        for k, v in self.items():
            setattr(new_data, k, v[key])
        return new_data

    @property
    def shape(self):
        """The shape of pixel data.

        Returns:
            tuple or None: The last two dimensions (height, width) of the
            first data item, or ``None`` when no data item has been set.
        """
        if len(self._data_fields) > 0:
            return tuple(self.values()[0].shape[-2:])
        else:
            return None

    # TODO padding, resize