# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from mmengine.data import BaseDataElement
from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.utils import is_list_of

from ..base_module import BaseModule

ForwardResults = Union[Dict[str, torch.Tensor], List[BaseDataElement],
                       Tuple[torch.Tensor], torch.Tensor]


class BaseModel(BaseModule):
"""Base class for all algorithmic models.
BaseModel implements the basic functions of the algorithmic model, such as
weights initialize, batch inputs preprocess(see more information in
:class:`BaseDataPreprocessor`), parse losses, and update model parameters.
Subclasses inherit from BaseModel only need to implement the forward
method, which implements the logic to calculate loss and predictions,
then can be trained in the runner.
Examples:
>>> @MODELS.register_module()
>>> class ToyModel(BaseModel):
>>>
>>> def __init__(self):
>>> super().__init__()
>>> self.backbone = nn.Sequential()
>>> self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
>>> self.backbone.add_module('pool', nn.MaxPool2d(2, 2))
>>> self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
>>> self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
>>> self.backbone.add_module('fc2', nn.Linear(120, 84))
>>> self.backbone.add_module('fc3', nn.Linear(84, 10))
>>>
>>> self.criterion = nn.CrossEntropyLoss()
>>>
>>> def forward(self, batch_inputs, data_samples, mode='tensor'):
>>> data_samples = torch.stack(data_samples)
>>> if mode == 'tensor':
>>> return self.backbone(batch_inputs)
>>> elif mode == 'predict':
>>> feats = self.backbone(batch_inputs)
>>> predictions = torch.argmax(feats, 1)
>>> return predictions
>>> elif mode == 'loss':
>>> feats = self.backbone(batch_inputs)
>>> loss = self.criterion(feats, data_samples)
>>> return dict(loss=loss)
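        >>>
        >>> # A sketch of building the model through the registry; the
        >>> # config dict below is an illustrative assumption, relying on
        >>> # the ``@MODELS.register_module()`` decorator above.
        >>> model = MODELS.build(dict(type='ToyModel'))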

    Args:
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`.
init_cfg (dict, optional): The weight initialized config for
:class:`BaseModule`.

    Attributes:
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
pre-processing data sampled by dataloader to the format accepted by
:meth:`forward`.
init_cfg (dict, optional): Initialization config dict.
"""

    def __init__(self,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
super().__init__(init_cfg)
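        # Accept either a ready-built nn.Module or a config dict that will be
        # constructed through the MODELS registry; fall back to the default
        # ``BaseDataPreprocessor`` when nothing is given.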
if data_preprocessor is None:
data_preprocessor = dict(type='BaseDataPreprocessor')
if isinstance(data_preprocessor, nn.Module):
self.data_preprocessor = data_preprocessor
elif isinstance(data_preprocessor, dict):
self.data_preprocessor = MODELS.build(data_preprocessor)
else:
raise TypeError('data_preprocessor should be a `dict` or '
f'`nn.Module` instance, but got '
f'{type(data_preprocessor)}')

    def train_step(self, data: List[dict],
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """Implements the default model training process, including
        preprocessing, model forward propagation, loss calculation,
        back-propagation, and parameter updating.

        During non-distributed training, if subclasses do not override
        :meth:`train_step`, :class:`EpochBasedTrainLoop` or
        :class:`IterBasedTrainLoop` will call this method to update model
        parameters. The default parameter update process is as follows:

        1. Calls ``self.data_preprocessor(data, training=True)`` to collect
           ``batch_inputs`` and the corresponding ``data_samples`` (labels).
        2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get the
           raw loss.
        3. Calls ``self.parse_losses`` to get the ``parsed_losses`` tensor
           used for the backward pass and a dict of loss tensors used for
           logging.
        4. Calls ``optim_wrapper.update_params(parsed_losses)`` to update the
           model.

        Args:
            data (List[dict]): Data sampled from the dataloader.
            optim_wrapper (OptimWrapper): OptimWrapper instance
                used to update model parameters.

        Returns:
            Dict[str, torch.Tensor]: A ``dict`` of tensors for logging.
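
        Examples:
            >>> # A minimal end-to-end sketch. ``TinyModel``, the optimizer
            >>> # setup, and the ``inputs``/``data_sample`` keys expected by
            >>> # the default ``BaseDataPreprocessor`` are illustrative
            >>> # assumptions, not guarantees of this module.
            >>> class TinyModel(BaseModel):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.linear = nn.Linear(2, 1)
            ...     def forward(self, batch_inputs, data_samples=None,
            ...                 mode='tensor'):
            ...         out = self.linear(batch_inputs)
            ...         if mode == 'loss':
            ...             return dict(loss=out.sum())
            ...         return out
            >>> model = TinyModel()
            >>> optim_wrapper = OptimWrapper(
            ...     torch.optim.SGD(model.parameters(), lr=0.01))
            >>> data = [dict(inputs=torch.rand(2), data_sample=None)]
            >>> log_vars = model.train_step(data, optim_wrapper)
            >>> sorted(log_vars.keys())
            ['loss']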
"""
# enable automatic mixed precision training context.
with optim_wrapper.precision_context():
batch_inputs, data_samples = self.data_preprocessor(data, True)
losses = self(batch_inputs, data_samples, mode='loss')
parsed_losses, log_vars = self.parse_losses(losses)
optim_wrapper.update_params(parsed_losses)
return log_vars

    def val_step(self, data: List[dict]) -> List[BaseDataElement]:
        """Gets the predictions for the given data.

        Calls ``self.data_preprocessor(data, False)`` and
        ``self(inputs, data_sample, mode='predict')`` in order, and returns
        the predictions, which will be passed to the evaluator.

        Args:
            data (List[dict]): Data sampled from the dataloader.

        Returns:
            List[BaseDataElement]: The predictions for the given data.
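
        Examples:
            >>> # A sketch, assuming ``model`` is a concrete ``BaseModel``
            >>> # subclass and ``val_dataloader`` yields batches in the
            >>> # format expected by ``model.data_preprocessor``; both names
            >>> # are illustrative.
            >>> for data in val_dataloader:
            ...     predictions = model.val_step(data)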
"""
inputs, data_sample = self.data_preprocessor(data, False)
return self(inputs, data_sample, mode='predict')

    def test_step(self, data: List[dict]) -> List[BaseDataElement]:
        """``BaseModel`` implements ``test_step`` the same as ``val_step``.

        Args:
            data (List[dict]): Data sampled from the dataloader.

        Returns:
            List[BaseDataElement]: The predictions for the given data.
inputs, data_sample = self.data_preprocessor(data, False)
return self(inputs, data_sample, mode='predict')

    def parse_losses(
        self, losses: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Parses the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: There are two elements. The first is the
            loss tensor passed to the optim_wrapper, which may be a weighted
            sum of all losses, and the second is ``log_vars``, which will be
            sent to the logger.
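
        Examples:
            >>> # A sketch with made-up loss values; keys containing the
            >>> # substring 'loss' are summed into the total, other keys are
            >>> # only logged. ``ToyModel`` is the illustrative subclass from
            >>> # the class-level docstring.
            >>> model = ToyModel()
            >>> losses = dict(
            ...     loss_cls=torch.tensor(1.0),
            ...     loss_bbox=[torch.tensor(0.5), torch.tensor(0.5)],
            ...     accuracy=torch.tensor(0.9))
            >>> parsed_losses, log_vars = model.parse_losses(losses)
            >>> parsed_losses
            tensor(2.)
            >>> log_vars['accuracy']
            tensor(0.9000)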
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif is_list_of(loss_value, torch.Tensor):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
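        # Only entries whose key contains the substring 'loss' contribute to
        # the total that is used for back-propagation; other entries (e.g.
        # accuracies) are logged only.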
loss = sum(value for key, value in log_vars.items() if 'loss' in key)
log_vars['loss'] = loss
return loss, log_vars

    def to(self, device: Optional[Union[int, torch.device]], *args,
           **kwargs) -> nn.Module:
        """Overrides this method to additionally call
        :meth:`BaseDataPreprocessor.to`.

        Args:
            device (int or torch.device, optional): The desired device of the
                parameters and buffers in this module.

        Returns:
            nn.Module: The model itself.
        """
self.data_preprocessor.to(device)
return super().to(device)

    def cuda(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to additionally call
        :meth:`BaseDataPreprocessor.cuda`.

        Returns:
            nn.Module: The model itself.
        """
self.data_preprocessor.cuda()
return super().cuda()

    def cpu(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to additionally call
        :meth:`BaseDataPreprocessor.cpu`.

        Returns:
            nn.Module: The model itself.
        """
self.data_preprocessor.cpu()
return super().cpu()

    @abstractmethod
    def forward(self,
                batch_inputs: torch.Tensor,
                data_samples: Optional[List[BaseDataElement]] = None,
                mode: str = 'tensor') -> ForwardResults:
        """Returns losses or predictions for the training, validation,
        testing, and simple inference processes.

        The ``forward`` method of ``BaseModel`` is abstract; subclasses must
        implement it.

        It accepts ``batch_inputs`` and ``data_samples`` processed by
        :attr:`data_preprocessor`, and returns results according to the
        ``mode`` argument.

        During non-distributed training, validation, and testing,
        ``forward`` will be called by ``BaseModel.train_step``,
        ``BaseModel.val_step`` and ``BaseModel.test_step`` directly.

        During distributed data parallel training,
        ``MMSeparateDistributedDataParallel.train_step`` will first call
        ``DistributedDataParallel.forward`` to enable automatic
        gradient synchronization, and then call ``forward`` to get the
        training loss.

        Args:
            batch_inputs (torch.Tensor): Batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (List[BaseDataElement], optional):
                Data samples collated by :attr:`data_preprocessor`.
            mode (str): Should be one of ``loss``, ``predict`` and
                ``tensor``.

                - ``loss``: Called by ``train_step`` and returns a loss
                  ``dict`` used for logging.
                - ``predict``: Called by ``val_step`` and ``test_step``
                  and returns a list of ``BaseDataElement`` results used
                  for computing metrics.
                - ``tensor``: Called by custom code to get ``Tensor``-type
                  results.

        Returns:
            ForwardResults:

            - If ``mode == loss``, return a ``dict`` of loss tensors used
              for backward and logging.
            - If ``mode == predict``, return a ``list`` of
              :obj:`BaseDataElement` for computing metrics
              and getting inference results.
            - If ``mode == tensor``, return a tensor, a ``tuple`` of
              tensors, or a ``dict`` of tensors for custom use.
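
        Examples:
            >>> # A sketch of the three calling modes, assuming ``model`` is
            >>> # a concrete subclass (e.g. the ``ToyModel`` in the class
            >>> # docstring) and ``batch_inputs``/``data_samples`` come from
            >>> # ``model.data_preprocessor``.
            >>> feats = model(batch_inputs, data_samples, mode='tensor')
            >>> predictions = model(batch_inputs, data_samples,
            ...                     mode='predict')
            >>> losses = model(batch_inputs, data_samples, mode='loss')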
"""