diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index ecc72e6a8e5531e26534f516b4f0647bf41e2547..fe326332e2045ee6b9e9329b5eddb3073cf684aa 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -6,7 +6,6 @@ from .hook import Hook from .iter_timer_hook import IterTimerHook from .logger_hook import LoggerHook from .naive_visualization_hook import NaiveVisualizationHook -from .optimizer_hook import OptimizerHook from .param_scheduler_hook import ParamSchedulerHook from .runtime_info_hook import RuntimeInfoHook from .sampler_seed_hook import DistSamplerSeedHook @@ -14,6 +13,6 @@ from .sync_buffer_hook import SyncBuffersHook __all__ = [ 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook', - 'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', - 'LoggerHook', 'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook' + 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook', + 'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook' ] diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py deleted file mode 100644 index c00d9deafc0671b2071b65f9ac724869cc605d89..0000000000000000000000000000000000000000 --- a/mmengine/hooks/optimizer_hook.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -from typing import List, Optional, Sequence - -import torch -from torch.nn.parameter import Parameter -from torch.nn.utils import clip_grad - -from mmengine.registry import HOOKS -from .hook import Hook - -DATA_BATCH = Optional[Sequence[dict]] - - -@HOOKS.register_module() -class OptimizerHook(Hook): - """A hook contains custom operations for the optimizer. - - Args: - grad_clip (dict, optional): A config dict to control the clip_grad. - Defaults to None. - detect_anomalous_params (bool): This option is only used for - debugging which will slow down the training speed. - Detect anomalous parameters that are not included in - the computational graph with ``loss`` as the root. - There are two cases - - Parameters were not used during - forward pass. - - Parameters were not used to produce - loss. - Defaults to False. - """ - - priority = 'HIGH' - - def __init__(self, - grad_clip: Optional[dict] = None, - detect_anomalous_params: bool = False) -> None: - self.grad_clip = grad_clip - self.detect_anomalous_params = detect_anomalous_params - - def clip_grads(self, params: List[Parameter]) -> Optional[torch.Tensor]: - """Clip the gradients of parameters. - - Args: - params (list[Parameter]): Model's parameters. - - Returns: - Optional[torch.Tensor]: Total norm of the parameters if there is - at least one param requiring gradient, else None. - """ - params = list( - filter(lambda p: p.requires_grad and p.grad is not None, params)) - if len(params) > 0: - return clip_grad.clip_grad_norm_(params, **self.grad_clip) - return None - - def after_train_iter(self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[dict] = None) -> None: - """All operations need to be finished after each training iteration. - - This function will finish following 3 operations: - - - Detect any anomalous parameters which are not included in the - training graph. (optional) - - - Compute the gradient of model parameters. - - - Clip the gradients of each parameter. (optional) - - - Update model parameters with gradients. - - Args: - runner (Runner): The runner of the training process. - batch_idx (int): The index of the current batch in the train loop. 
- data_batch (Sequence[dict], optional): Data from dataloader. - In order to keep this interface consistent with other hooks, - we keep ``data_batch`` here. Defaults to None. - outputs (dict, optional): Outputs from model. - In order to keep this interface consistent with other hooks, - we keep ``outputs`` here. Defaults to None. - """ - runner.optim_wrapper.zero_grad() - if self.detect_anomalous_params: - self.detect_anomalous_parameters(runner.outputs['loss'], runner) - runner.outputs['loss'].backward() - - if self.grad_clip is not None: - grad_norm = self.clip_grads(runner.model.parameters()) - if grad_norm is not None: - # Add grad norm to the logger - runner.message_hub.update_scalar('train/grad_norm', - float(grad_norm)) - runner.optim_wrapper.step() - - def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None: - """Detect anomalous parameters that are not included in the graph. - - Args: - loss (torch.Tensor): The loss of current iteration. - runner (Runner): The runner of the training process. - """ - logger = runner.logger - parameters_in_graph = set() - visited = set() - - def traverse(grad_fn): - if grad_fn is None: - return - if grad_fn not in visited: - visited.add(grad_fn) - if hasattr(grad_fn, 'variable'): - parameters_in_graph.add(grad_fn.variable) - parents = grad_fn.next_functions - if parents is not None: - for parent in parents: - grad_fn = parent[0] - traverse(grad_fn) - - traverse(loss.grad_fn) - for n, p in runner.model.named_parameters(): - if p not in parameters_in_graph and p.requires_grad: - logger.log( - level=logging.ERROR, - msg=f'{n} with shape {p.size()} is not ' - f'in the computational graph \n') diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py index 56186ff5adf2cc85bd47d0f2ba1ea0a6ba048224..091ced5075494798a258850952abbfe8d6946c52 100644 --- a/mmengine/hooks/runtime_info_hook.py +++ b/mmengine/hooks/runtime_info_hook.py @@ -59,7 +59,7 @@ class RuntimeInfoHook(Hook): outputs: Optional[dict] = None) -> None: """Update ``log_vars`` in model outputs every iteration.""" if outputs is not None: - for key, value in outputs['log_vars'].items(): + for key, value in outputs.items(): runner.message_hub.update_scalar(f'train/{key}', value) def after_val_epoch(self, diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py index 082f91317064c29ac649eb29a8a5faaeac104fd1..0b7f08e7d75be808db9ec13137febd011e1a639a 100644 --- a/mmengine/model/__init__.py +++ b/mmengine/model/__init__.py @@ -1,11 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. 
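With ``OptimizerHook`` removed, gradient clipping and the backward/step logic are handled by the optimizer wrapper and ``train_step`` rather than a hook. A minimal config sketch (the ``clip_grad`` key is an assumption about ``OptimWrapper``; the optimizer part mirrors the ``build_optim_wrapper`` example later in this patch):

```python
# Previously configured as a default hook:
#   optimizer=dict(type='OptimizerHook', grad_clip=dict(max_norm=2))
# Sketch of the equivalent optimizer-wrapper config (clip_grad is assumed to be
# forwarded to torch.nn.utils.clip_grad_norm_):
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01),
    clip_grad=dict(max_norm=2),
)
```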
from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA, StochasticWeightAverage) -from .wrappers import (MMDataParallel, MMDistributedDataParallel, - is_model_wrapper) +from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor +from .base_module import BaseModule +from .utils import detect_anomalous_params, merge_dict, stach_batch_imgs +from .wrappers import (MMDistributedDataParallel, + MMSeparateDistributedDataParallel, is_model_wrapper) __all__ = [ - 'MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper', - 'StochasticWeightAverage', 'ExponentialMovingAverage', - 'MomentumAnnealingEMA' + 'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage', + 'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel', + 'BaseDataPreprocessor', 'ImgDataPreprocessor', + 'MMSeparateDistributedDataParallel', 'BaseModule', 'stach_batch_imgs', + 'merge_dict', 'detect_anomalous_params' ] diff --git a/mmengine/model/base_model/__init__.py b/mmengine/model/base_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..696c83adb730a7e2fec404950229103a4c4581de --- /dev/null +++ b/mmengine/model/base_model/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_model import BaseModel +from .data_preprocessor import (BaseDataElement, BaseDataPreprocessor, + ImgDataPreprocessor) + +__all__ = [ + 'BaseModel', 'BaseDataElement', 'ImgDataPreprocessor', + 'BaseDataPreprocessor' +] diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ede27b7a745a09ed816263b6c0863cfa37368f9a --- /dev/null +++ b/mmengine/model/base_model/base_model.py @@ -0,0 +1,256 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from mmengine.data import BaseDataElement +from mmengine.optim import OptimWrapper +from mmengine.registry import MODELS +from mmengine.utils import is_list_of +from ..base_module import BaseModule + +ForwardResults = Union[Dict[str, torch.Tensor], List[BaseDataElement], + Tuple[torch.Tensor], torch.Tensor] + + +class BaseModel(BaseModule): + """Base class for all algorithmic models. + + BaseModel implements the basic functions of the algorithmic model, such as + weights initialize, batch inputs preprocess(see more information in + :class:`BaseDataPreprocessor`), parse losses, and update model parameters. + + Subclasses inherit from BaseModel only need to implement the forward + method, which implements the logic to calculate loss and predictions, + then can be trained in the runner. 
+ + Examples: + >>> @MODELS.register_module() + >>> class ToyModel(BaseModel): + >>> + >>> def __init__(self): + >>> super().__init__() + >>> self.backbone = nn.Sequential() + >>> self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5)) + >>> self.backbone.add_module('pool', nn.MaxPool2d(2, 2)) + >>> self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5)) + >>> self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120)) + >>> self.backbone.add_module('fc2', nn.Linear(120, 84)) + >>> self.backbone.add_module('fc3', nn.Linear(84, 10)) + >>> + >>> self.criterion = nn.CrossEntropyLoss() + >>> + >>> def forward(self, batch_inputs, data_samples, mode='tensor'): + >>> data_samples = torch.stack(data_samples) + >>> if mode == 'tensor': + >>> return self.backbone(batch_inputs) + >>> elif mode == 'predict': + >>> feats = self.backbone(batch_inputs) + >>> predictions = torch.argmax(feats, 1) + >>> return predictions + >>> elif mode == 'loss': + >>> feats = self.backbone(batch_inputs) + >>> loss = self.criterion(feats, data_samples) + >>> return dict(loss=loss) + + Args: + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + + Attributes: + init_cfg (dict, optional): Initialization config dict. + data_preprocessor (:obj:`BaseDataPreprocessor`): Used for + pre-processing data sampled by dataloader to the format accepted by + :meth:`forward`. + """ + + def __init__(self, + init_cfg: Optional[dict] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None): + super().__init__(init_cfg) + if data_preprocessor is None: + data_preprocessor = dict(type='BaseDataPreprocessor') + if isinstance(data_preprocessor, nn.Module): + self.data_preprocessor = data_preprocessor + elif isinstance(data_preprocessor, dict): + self.data_preprocessor = MODELS.build(data_preprocessor) + else: + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') + + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + """Implements the default model training process including + preprocessing, model forward propagation, loss calculation, + optimization, and back-propagation. + + During non-distributed training. If subclasses do not override the + :meth:`train_step`, :class:`EpochBasedTrainLoop` or + :class:`IterBasedTrainLoop` will call this method to update model + parameters. The default parameter update process is as follows: + + 1. Calls ``self.data_processor(data, training=False) to collext + batch_inputs and corresponding data_samples(labels). + 2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw + loss + 3. Calls ``self.parse_losses`` to get ``parsed_losses`` tensor used to + backward and dict of loss tensor used to log messages. + 4. Calls ``optim_wrapper.update_params(loss)`` to update model. + + Args: + data (List[dict]): Data sampled from dataloader. + optim_wrapper (OptimWrapper): OptimWrapper instance + used to update model parameters. + + Returns: + Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. + """ + # enable automatic mixed precision training context. 
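+        # With the plain ``OptimWrapper`` this context is expected to be a
+        # no-op; an AMP-capable optimizer wrapper can enable autocast for the
+        # forward pass below.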
+ with optim_wrapper.precision_context(): + batch_inputs, data_samples = self.data_preprocessor(data, True) + losses = self(batch_inputs, data_samples, mode='loss') + parsed_losses, log_vars = self.parse_losses(losses) + optim_wrapper.update_params(parsed_losses) + return log_vars + + def val_step(self, data: List[dict]) -> List[BaseDataElement]: + """Gets the predictions of given data. + + Calls ``self.data_preprocessor(data, False)`` and + ``self(inputs, data_sample, mode='predict')`` in order. Return the + predictions which will be passed to evaluator. + + Args: + data (List[dict]): Data sampled from dataloader. + + Returns: + List[BaseDataElement]: The predictions of given data. + """ + inputs, data_sample = self.data_preprocessor(data, False) + return self(inputs, data_sample, mode='predict') + + def test_step(self, data: List[dict]) -> List[BaseDataElement]: + """``BaseModel`` implements ``test_step`` the same as ``val_step``. + + Args: + data (List[dict]): Data sampled from dataloader. + + Returns: + List[BaseDataElement]: The predictions of given data. + """ + inputs, data_sample = self.data_preprocessor(data, False) + return self(inputs, data_sample, mode='predict') + + def parse_losses( + self, losses: Dict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """Parses the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: There are two elements. The first is the + loss tensor passed to optim_wrapper which may be a weighted sum of + all losses, and the second is log_vars which will be sent to the + logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif is_list_of(loss_value, torch.Tensor): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(value for key, value in log_vars.items() if 'loss' in key) + log_vars['loss'] = loss + + return loss, log_vars + + def to(self, device: Optional[Union[int, torch.device]], *args, + **kwargs) -> nn.Module: + """Overrides this method to set the ``device`` attribute of + :obj:`BaseDataPreprocessor` additionally + + Args: + device (int or torch.device, optional): the desired device of the + parameters and buffers in this module. + + Returns: + nn.Module: The model itself. + """ + self.data_preprocessor.device = torch.device(device) + return super().to(device) + + def cuda(self, *args, **kwargs) -> nn.Module: + """Overrides this method to set the ``device`` attribute of + :obj:`BaseDataPreprocessor` additionally + + Returns: + nn.Module: The model itself. + """ + self.data_preprocessor.device = torch.cuda.current_device() + return super().cuda() + + @abstractmethod + def forward(self, + batch_inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'tensor') -> ForwardResults: + """Returns losses or predictions of training, validation, testing, and + simple inference process. + + ``forward`` method of BaseModel is an abstract method, its subclasses + must implement this method. + + Accepts ``batch_inputs`` and ``data_samples`` processed by + :attr:`data_preprocessor`, and returns results according to mode + arguments. 
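To make the ``parse_losses`` helper above concrete, a small sketch (``model`` is assumed to be any ``BaseModel`` subclass, e.g. the ``ToyModel`` from the class docstring):

```python
import torch

losses = dict(
    loss_cls=torch.tensor(1.0),
    loss_bbox=[torch.tensor(0.2), torch.tensor(0.3)],  # list entries are summed
    acc=torch.tensor(0.9),  # logged, but not added to the total loss
)
parsed_loss, log_vars = model.parse_losses(losses)
# parsed_loss -> tensor(1.5), the sum of every entry whose key contains 'loss'
# log_vars    -> {'loss_cls': 1.0, 'loss_bbox': 0.5, 'acc': 0.9, 'loss': 1.5}
```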
+ + During non-distributed training, validation, and testing process, + ``forward`` will be called by ``BaseModel.train_step``, + ``BaseModel.val_step`` and ``BaseModel.val_step`` directly. + + During distributed data parallel training process, + ``MMSeparateDistributedDataParallel.train_step`` will first call + ``DistributedDataParallel.forward`` to enable automatic + gradient synchronization, and then call ``forward`` to get training + loss. + + Args: + batch_inputs (torch.Tensor): batch input tensor collated by + :attr:`data_preprocessor`. + data_samples (List[BaseDataElement], optional): + data samples collated by :attr:`data_preprocessor`. + mode (str): mode should be one of ``loss``, ``predict`` and + ``tensor`` + + - ``loss``: Called by ``train_step`` and return loss ``dict`` + used for logging + - ``predict``: Called by ``val_step`` and ``test_step`` + and return list of ``BaseDataElement`` results used for + computing metric. + - ``tensor``: Called by custom use to get ``Tensor`` type + results. + + Returns: + ForwardResults: + + - If ``mode == loss``, return a ``dict`` of loss tensor used + for backward and logging. + - If ``mode == predict``, return a ``list`` of + :obj:`BaseDataElement` for computing metric + and getting inference result. + - If ``mode == tensor``, return a tensor or ``tuple`` of tensor + or ``dict of tensor for custom use. + """ diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9d2cb3eefa2f51a782bd8d7c243208fb26917d --- /dev/null +++ b/mmengine/model/base_model/data_preprocessor.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn + +from mmengine.data import BaseDataElement +from mmengine.registry import MODELS +from ..utils import stach_batch_imgs + + +@MODELS.register_module() +class BaseDataPreprocessor(nn.Module): + """Base data pre-processor used for collating and copying data to the + target device. + + ``BaseDataPreprocessor`` performs data pre-processing according to the + following steps: + + - Collates the data sampled from dataloader. + - Copies data to the target device. + - Stacks the input tensor at the first dimension. + + Subclasses inherit from ``BaseDataPreprocessor`` could override the + forward method to implement custom data pre-processing, such as + batch-resize, MixUp, or CutMix. + + Args: + device (int or torch.device): Target device. + + Warnings: + Each item of data sampled from dataloader must be a dict and at least + contain the ``inputs`` key. Furthermore, the value of ``inputs`` + must be a ``Tensor`` with the same shape. + """ + + def __init__(self, device: Union[int, torch.device] = 'cpu'): + super().__init__() + self.device = device + + def collate_data( + self, + data: Sequence[dict]) -> Tuple[List[torch.Tensor], Optional[list]]: + """Collating and copying data to the target device. + + Collates the data sampled from dataloader into a list of tensor and + list of labels, and then copies tensor to the target device. + + Subclasses could override it to be compatible with the custom format + data sampled from custom dataloader. + + Args: + data (Sequence[dict]): Data sampled from dataloader. + + Returns: + Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input + tensor and list of labels at target device. 
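For reference, a sketch of the per-item format that ``collate_data`` expects from the dataloader (the ``data_sample`` key is optional and all names are illustrative):

```python
import torch
from mmengine.data import BaseDataElement
from mmengine.model import BaseDataPreprocessor

data = [
    dict(inputs=torch.randint(0, 256, (3, 224, 224)),
         data_sample=BaseDataElement(metainfo=dict(img_id=0))),
    dict(inputs=torch.randint(0, 256, (3, 224, 224))),  # item without a label
]
preprocessor = BaseDataPreprocessor()  # defaults to device='cpu'
inputs, data_samples = preprocessor.collate_data(data)
```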
+ """ + inputs = [_data['inputs'].to(self.device) for _data in data] + batch_data_samples: List[BaseDataElement] = [] + # Model can get predictions without any data samples. + for _data in data: + if 'data_sample' in _data: + batch_data_samples.append(_data['data_sample']) + # Move data from CPU to corresponding device. + batch_data_samples = [ + data_sample.to(self.device) for data_sample in batch_data_samples + ] + + if not batch_data_samples: + batch_data_samples = None # type: ignore + + return inputs, batch_data_samples + + def forward(self, + data: Sequence[dict], + training: bool = False) -> Tuple[torch.Tensor, Optional[list]]: + """Preprocesses the data into the model input format. + + After the data pre-processing of :meth:`collate_data`, ``forward`` + will stack the input tensor list to a batch tensor at the first + dimension. + + Args: + data (Sequence[dict]): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. + """ + inputs, batch_data_samples = self.collate_data(data) + batch_inputs = torch.stack(inputs, dim=0) + return batch_inputs, batch_data_samples + + def to(self, device: Optional[Union[int, torch.device]], *args, + **kwargs) -> nn.Module: + """Overrides this method to set the :attr:`device` + + Args: + device (int or torch.device, optional): The desired device of the + parameters and buffers in this module. + + Returns: + nn.Module: The model itself. + """ + self.device = torch.device(device) + return super().to(device) + + def cuda(self, *args, **kwargs) -> nn.Module: + """Overrides this method to set the :attr:`device` + + Returns: + nn.Module: The model itself. + """ + self.device = torch.cuda.current_device() + return super().cuda() + + +@MODELS.register_module() +class ImgDataPreprocessor(BaseDataPreprocessor): + """Image pre-processor for normalization and bgr to rgb conversion. + + Accepts the data sampled by the dataloader, and preprocesses it into the + format of the model input. ``ImgDataPreprocessor`` provides the + basic data pre-processing as follows + + - Collates and moves data to the target device. + - Converts inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalizes image with defined std and mean. + - Pads inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + + For ``ImgDataPreprocessor``, the dimension of the single inputs must be + (3, H, W). + + Note: + ``ImgDataPreprocessor`` and its subclass is built in the + constructor of :class:`BaseDataset`. + + Args: + mean (Sequence[float or int]): The pixel mean of image channels. If + ``bgr_to_rgb=True`` it means the mean value of R, G, B channels. + If ``mean`` and ``std`` are not specified, ``ImgDataPreprocessor`` + will normalize images to [-1, 1]. Defaults to (127.5, 127.5, + 127.5). + std (Sequence[float or int]): The pixel standard deviation of image + channels. If ``bgr_to_rgb=True`` it means the standard deviation of + R, G, B channels. If ``mean`` and ``std`` are not specified, + ImgDataPreprocessor will normalize images to [-1, 1]. Defaults + to (127.5, 127.5, 127.5). + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. + bgr_to_rgb (bool): whether to convert image from BGR to RGB. 
+ Defaults to False. + rgb_to_bgr (bool): whether to convert image from RGB to RGB. + Defaults to False. + device (int or torch.device): Target device. + """ + + def __init__(self, + mean: Sequence[Union[float, int]] = (127.5, 127.5, 127.5), + std: Sequence[Union[float, int]] = (127.5, 127.5, 127.5), + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + device: Union[int, torch.device] = 'cpu'): + super().__init__(device) + assert len(mean) == 3 or len(mean) == 1, ( + 'The length of mean should be 1 or 3 to be compatible with RGB ' + f'or gray image, but got {len(mean)}') + assert len(std) == 3 or len(std) == 1, ( + 'The length of std should be 1 or 3 to be compatible with RGB ' + f'or gray image, but got {len(std)}') + assert not (bgr_to_rgb and rgb_to_bgr), ( + '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time') + self.channel_conversion = rgb_to_bgr or bgr_to_rgb + self.register_buffer('mean', torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', torch.tensor(std).view(-1, 1, 1), False) + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + + def forward(self, + data: Sequence[dict], + training: bool = False) -> Tuple[torch.Tensor, Optional[list]]: + """Performs normalizationã€padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (Sequence[dict]): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. + """ + inputs, batch_data_samples = self.collate_data(data) + # channel transform + if self.channel_conversion: + inputs = [_input[[2, 1, 0], ...] for _input in inputs] + # Normalization. + inputs = [(_input - self.mean) / self.std for _input in inputs] + # Pad and stack Tensor. + batch_inputs = stach_batch_imgs(inputs, self.pad_size_divisor, + self.pad_value) + return batch_inputs, batch_data_samples diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py index 0a38bbf5292686598c5c3b98a901d8aa3aa96b0b..89f140dccfba742c57b4ae1427d50d5a5d468166 100644 --- a/mmengine/model/base_module.py +++ b/mmengine/model/base_module.py @@ -91,7 +91,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta): logger = MMLogger.get_current_instance() logger_name = logger.instance_name - from .utils.weight_init import initialize, update_init_info + from .utils import initialize, update_init_info module_name = self.__class__.__name__ if not self._is_init: if self.init_cfg: diff --git a/mmengine/model/utils/weight_init.py b/mmengine/model/utils.py similarity index 84% rename from mmengine/model/utils/weight_init.py rename to mmengine/model/utils.py index 1289d7f117f9dbdcb7d20e44c4b59b466fc989e0..29cc779b9f250e8d6293f2ff73492cddb17752d8 100644 --- a/mmengine/model/utils/weight_init.py +++ b/mmengine/model/utils.py @@ -1,11 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import logging import math import warnings +from typing import List, Union import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from torch import Tensor from mmengine.logging.logger import MMLogger, print_log @@ -668,3 +671,131 @@ def trunc_normal_(tensor: Tensor, b (float): the maximum cutoff value. 
""" return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def stach_batch_imgs(tensor_list: List[torch.Tensor], + pad_size_divisor: int = 1, + pad_value: Union[int, float] = 0) -> torch.Tensor: + """Stack multiple tensors to form a batch and pad the images to the max + shape use the right bottom padding mode in these images. If + ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is + divisible by ``pad_size_divisor``. + + Args: + tensor_list (List[Tensor]): A list of tensors with the same dim. + pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding + to ensure the shape of each dim is divisible by + ``pad_size_divisor``. This depends on the model, and many + models need to be divisible by 32. Defaults to 1 + pad_value (int, float): The padding value. Defaults to 0. + + Returns: + Tensor: The 4D-tensor. + """ + assert isinstance( + tensor_list, + list), (f'Expected input type to be list, but got {type(tensor_list)}') + assert tensor_list, '`tensor_list` could not be an empty list' + assert len({ + tensor.ndim + for tensor in tensor_list + }) == 1, (f'Expected the dimensions of all tensors must be the same, ' + f'but got {[tensor.ndim for tensor in tensor_list]}') + + dim = tensor_list[0].dim() + num_img = len(tensor_list) + all_sizes: torch.Tensor = torch.Tensor( + [tensor.shape for tensor in tensor_list]) + max_sizes = torch.ceil( + torch.max(all_sizes, dim=0)[0] / pad_size_divisor) * pad_size_divisor + padded_sizes = max_sizes - all_sizes + # The first dim normally means channel, which should not be padded. + padded_sizes[:, 0] = 0 + if padded_sizes.sum() == 0: + return torch.stack(tensor_list) + # `pad` is the second arguments of `F.pad`. If pad is (1, 2, 3, 4), + # it means that padding the last dim with 1(left) 2(right), padding the + # penultimate dim to 3(top) 4(bottom). The order of `pad` is opposite of + # the `padded_sizes`. Therefore, the `padded_sizes` needs to be reversed, + # and only odd index of pad should be assigned to keep padding "right" and + # "bottom". + pad = torch.zeros(num_img, 2 * dim, dtype=torch.int) + pad[:, 1::2] = padded_sizes[:, range(dim - 1, -1, -1)] + batch_tensor = [] + for idx, tensor in enumerate(tensor_list): + batch_tensor.append( + F.pad(tensor, tuple(pad[idx].tolist()), value=pad_value)) + return torch.stack(batch_tensor) + + +def detect_anomalous_params(loss: torch.Tensor, model) -> None: + parameters_in_graph = set() + visited = set() + + def traverse(grad_fn): + if grad_fn is None: + return + if grad_fn not in visited: + visited.add(grad_fn) + if hasattr(grad_fn, 'variable'): + parameters_in_graph.add(grad_fn.variable) + parents = grad_fn.next_functions + if parents is not None: + for parent in parents: + grad_fn = parent[0] + traverse(grad_fn) + + traverse(loss.grad_fn) + from mmengine import MMLogger + logger = MMLogger.get_current_instance() + for n, p in model.named_parameters(): + if p not in parameters_in_graph and p.requires_grad: + logger.log( + level=logging.ERROR, + msg=f'{n} with shape {p.size()} is not ' + f'in the computational graph \n') + + +def merge_dict(*args): + """Merge all dictionaries into one dictionary. + + If pytorch version >= 1.8, ``merge_dict`` will be wrapped + by ``torch.fx.wrap``, which will make ``torch.fx.symbolic_trace`` skip + trace ``merge_dict``. + + Note: + If a function needs to be traced by ``torch.fx.symbolic_trace``, + but inevitably needs to use ``update`` method of ``dict``(``update`` + is not traceable). It should use ``merge_dict`` to replace + ``xxx.update``. 
+ + Args: + *args: dictionary needs to be merged. + + Returns: + dict: Merged dict from args + """ + output = dict() + for item in args: + assert isinstance( + item, + dict), (f'all arguments of merge_dict should be a dict, but got ' + f'{type(item)}') + output.update(item) + return output + + +# torch.fx is only available when pytorch version >= 1.8. +# If the subclass of `BaseModel` has multiple submodules, and each module +# will return a loss dict during training process, i.e., `TwoStageDetector` +# in mmdet. It should use `merge_dict` to get the total loss, rather than +# `loss.update` to keep model traceable. +try: + import torch.fx + + # make torch.fx skip trace `merge_dict`. + merge_dict = torch.fx.wrap(merge_dict) + +except ImportError: + warnings.warn('Cannot import torch.fx, `merge_dict` is a simple function ' + 'to merge multiple dicts') diff --git a/mmengine/model/wrappers/__init__.py b/mmengine/model/wrappers/__init__.py index 1cab521decd4ac9cac834efc01c62113c17fafcd..d6ece71384bda27e06d13c43d7b69d889dc1328a 100644 --- a/mmengine/model/wrappers/__init__.py +++ b/mmengine/model/wrappers/__init__.py @@ -1,5 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .data_parallel import MMDataParallel, MMDistributedDataParallel +from .distributed import MMDistributedDataParallel +from .seperate_distributed import MMSeparateDistributedDataParallel from .utils import is_model_wrapper -__all__ = ['MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper'] +__all__ = [ + 'MMDistributedDataParallel', 'is_model_wrapper', + 'MMSeparateDistributedDataParallel' +] diff --git a/mmengine/model/wrappers/data_parallel.py b/mmengine/model/wrappers/data_parallel.py deleted file mode 100644 index d31b009c3352baa15028d594b4114170563f8a83..0000000000000000000000000000000000000000 --- a/mmengine/model/wrappers/data_parallel.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from itertools import chain - -import torch -from torch.nn.parallel import DataParallel -from torch.nn.parallel.distributed import (DistributedDataParallel, - _find_tensors) - -from mmengine.registry import MODEL_WRAPPERS -from mmengine.utils import TORCH_VERSION, digit_version - -MODEL_WRAPPERS.register_module(module=DataParallel) -MODEL_WRAPPERS.register_module(module=DistributedDataParallel) - - -@MODEL_WRAPPERS.register_module() -class MMDataParallel(DataParallel): - """There is no difference between MMDataParallel and pytorch's - DataParallel, "train_step" and "val_step" are added just to avoid bc - breaking. - - Warning: - MMDataParallel only supports single GPU training, if you - need to train with multiple GPUs, please use MMDistributedDataParallel - instead. If you have multiple GPUs and you just want to use - MMDataParallel, you can set the environment variable - ``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with - ``device_ids=[0]``. 
- """ - - def train_step(self, *inputs, **kwargs): - assert len(self.device_ids) == 1, \ - ('MMDataParallel only supports single GPU training, if you need to' - ' train with multiple GPUs, please use MMDistributedDataParallel' - ' instead.') - assert hasattr(self.module, 'train_step') - for t in chain(self.module.parameters(), self.module.buffers()): - if t.device != self.src_device_obj: - raise RuntimeError( - 'module must have its parameters and buffers ' - f'on device {self.src_device_obj} (device_ids[0]) but ' - f'found one of them on device: {t.device}') - return self.module.train_step(*inputs, **kwargs) - - def val_step(self, *inputs, **kwargs): - assert len(self.device_ids) == 1, \ - ('MMDataParallel only supports single GPU training, if you need to' - ' train with multiple GPUs, please use MMDistributedDataParallel' - ' instead.') - assert hasattr(self.module, 'val_step') - for t in chain(self.module.parameters(), self.module.buffers()): - if t.device != self.src_device_obj: - raise RuntimeError( - 'module must have its parameters and buffers ' - f'on device {self.src_device_obj} (device_ids[0]) but ' - f'found one of them on device: {t.device}') - return self.module.val_step(*inputs, **kwargs) - - -@MODEL_WRAPPERS.register_module() -class MMDistributedDataParallel(DistributedDataParallel): - """There is no difference between MMDistributedDataParallel and pytorch's - DistributedDataParallel, "train_step" and "val_step" are added just to - avoid bc breaking.""" - - def train_step(self, *inputs, **kwargs): - """train_step() API for module wrapped by DistributedDataParallel. - - This method is basically the same as - ``DistributedDataParallel.forward()``, while replacing - ``self.module.forward()`` with ``self.module.train_step()``. - It is compatible with PyTorch 1.1 - 1.5. - """ - - # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the - # end of backward to the beginning of forward. - if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.7') - and self.reducer._rebuild_buckets()): - # TODO: replace with logger - print('Reducer buckets have been rebuilt in this iteration.') - - if getattr(self, 'require_forward_param_sync', True): - self._sync_params() - - if self.device_ids: - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) - if len(self.device_ids) == 1: - output = self.module.train_step(*inputs[0], **kwargs[0]) - else: - outputs = self.parallel_apply( - self._module_copies[:len(inputs)], inputs, kwargs) - output = self.gather(outputs, self.output_device) - else: - output = self.module.train_step(*inputs, **kwargs) - - if torch.is_grad_enabled() and getattr( - self, 'require_backward_grad_sync', True): - if self.find_unused_parameters: - self.reducer.prepare_for_backward(list(_find_tensors(output))) - else: - self.reducer.prepare_for_backward([]) - else: - if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) > digit_version('1.2')): - self.require_forward_param_sync = False - return output - - def val_step(self, *inputs, **kwargs): - """val_step() API for module wrapped by DistributedDataParallel. - - This method is basically the same as - ``DistributedDataParallel.forward()``, while replacing - ``self.module.forward()`` with ``self.module.val_step()``. - It is compatible with PyTorch 1.1 - 1.5. - """ - - # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the - # end of backward to the beginning of forward. 
- if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.7') - and self.reducer._rebuild_buckets()): - # TODO: replace with logger - print('Reducer buckets have been rebuilt in this iteration.') - - if getattr(self, 'require_forward_param_sync', True): - self._sync_params() - if self.device_ids: - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) - if len(self.device_ids) == 1: - output = self.module.val_step(*inputs[0], **kwargs[0]) - else: - outputs = self.parallel_apply( - self._module_copies[:len(inputs)], inputs, kwargs) - output = self.gather(outputs, self.output_device) - else: - output = self.module.val_step(*inputs, **kwargs) - - if torch.is_grad_enabled() and getattr( - self, 'require_backward_grad_sync', True): - if self.find_unused_parameters: - self.reducer.prepare_for_backward(list(_find_tensors(output))) - else: - self.reducer.prepare_for_backward([]) - else: - if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) > digit_version('1.2')): - self.require_forward_param_sync = False - return output diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..4084dde75e1e8428d296a1f435d7fbb2b1f9f0dc --- /dev/null +++ b/mmengine/model/wrappers/distributed.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch +from torch.nn.parallel.distributed import DistributedDataParallel + +from mmengine.data import BaseDataElement +from mmengine.optim import OptimWrapper +from mmengine.registry import MODEL_WRAPPERS +from ..utils import detect_anomalous_params + + +@MODEL_WRAPPERS.register_module() +class MMDistributedDataParallel(DistributedDataParallel): + """A distributed model wrapper used for training,testing and validation in + loop. + + Different from DistributedDataParallel, MMDistributedDataParallel + implements three methods :meth:`train_step`, :meth:`val_step` and + :meth:`test_step`, which will be called by ``train_loop``, ``val_loop`` + and ``test_loop``. + + - ``train_step``: Called by ``runner.train_loop``, and implement + default model forward, gradient back propagation, parameter updating + logic. To take advantage of DistributedDataParallel's automatic gradient + synchronization, ``train_step`` calls ``DistributedDataParallel.forward`` + to calculate the losses, and call other methods of :obj:`BaseModel` to + pre-process data and parse losses. Finally, update model parameters by + :obj:``OptimWrapper`` and return the loss dictionary used for logging. + + - ``val_step``: Called by ``runner.val_loop`` and get the inference + results. Since there is no gradient synchronization requirement, + this procedure is equivalent to ``BaseModel.val_step`` + + - ``test_step``: Called by ``runner.test_loop``, equivalent ``val_step``. + + Args: + detect_anomalous_params (bool): This option is only used for + debugging which will slow down the training speed. + Detect anomalous parameters that are not included in + the computational graph with `loss` as the root. + There are two cases + + - Parameters were not used during + forward pass. + - Parameters were not used to produce + loss. + Default: False. + + *args: list arguments passed to ``DistributedDataParallel`` + **kwargs: keyword arguments passed to ``DistributedDataParallel``. 
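A construction sketch that mirrors how ``Runner.wrap_model`` builds this wrapper later in the patch (``model`` is assumed to be a ``BaseModel`` subclass, ``data_batch`` and ``optim_wrapper`` come from the dataloader and ``build_optim_wrapper``, and the process group is assumed to be initialized already):

```python
import torch

ddp_model = MMDistributedDataParallel(
    module=model.cuda(),
    device_ids=[torch.cuda.current_device()],
    broadcast_buffers=False,
    find_unused_parameters=False)
log_vars = ddp_model.train_step(data_batch, optim_wrapper=optim_wrapper)  # loss dict for logging
predictions = ddp_model.val_step(data_batch)  # list of BaseDataElement
```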
+ + Note: + If model has multiple submodules and each module has + separate optimization strategies, + :class:`MMSeparateDistributedDataParallel` should be used to wrap + the model. + + Note: + If model itself has custom optimization strategy, rather than + simply forward model and update model. A custom model wrapper + inherit from ``MMDistributedDataParallel`` should be defined and + override the ``train_step`` method. + """ + + def __init__(self, detect_anomalous_params: bool = False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.detect_anomalous_params = detect_anomalous_params + + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + """Interface for model forward, backward and parameters updating during + training process. + + :meth:`train_step` will perform the following steps in order: + + - If :attr:`module` defines the preprocess method, + call ``module.preprocess`` to pre-processing data. + - Call ``module.forward(**data)`` and get losses. + - Parse losses. + - Call ``optim_wrapper.optimizer_step`` to update parameters. + - Return log messages of losses. + + Args: + data (List[dict]): Data sampled by dataloader. + optim_wrapper (OptimWrapper): A wrapper of optimizer to + update parameters. + + Returns: + Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. + """ + # enable automatic mixed precision training context. + with optim_wrapper.precision_context(): + batch_inputs, data_samples = self.module.data_preprocessor( + data, training=True) + losses = self(batch_inputs, data_samples, mode='loss') + if self.detect_anomalous_params: + detect_anomalous_params(losses, model=self) + parsed_loss, log_vars = self.module.parse_losses(losses) + optim_wrapper.update_params(parsed_loss) + return log_vars + + def val_step(self, data: List[dict]) -> List[BaseDataElement]: + """Gets the prediction of module during validation process. + + Args: + data (List[dict]): Data sampled by dataloader. + + Returns: + List[BaseDataElement] or dict: The predictions of given data. + """ + return self.module.val_step(data) + + def test_step(self, data: List[dict]) -> List[BaseDataElement]: + """Gets the predictions of module during testing process. + + Args: + data: Data sampled by dataloader. + + Returns: + List[BaseDataElement]: The predictions of given data. + """ + return self.module.test_step(data) diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..a369be06167a7fc6678fa536a278e3a4681a4c17 --- /dev/null +++ b/mmengine/model/wrappers/seperate_distributed.py @@ -0,0 +1,124 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from contextlib import ExitStack, contextmanager +from typing import Dict, List + +import torch +import torch.nn as nn +from torch.nn.parallel.distributed import DistributedDataParallel + +from mmengine.data import BaseDataElement +from mmengine.optim import OptimWrapperDict +from mmengine.registry import MODEL_WRAPPERS +from .distributed import MMDistributedDataParallel + + +@MODEL_WRAPPERS.register_module() +class MMSeparateDistributedDataParallel(DistributedDataParallel): + """A DistributedDataParallel wrapper for models in MMGeneration. + + In MMedting and MMGeneration there is a need to wrap different modules in + the models with separate DistributedDataParallel. Otherwise, it will cause + errors for GAN training. 
For example, the GAN model, usually has two + submodules: generator and discriminator. If we wrap both of them in one + standard DistributedDataParallel, it will cause errors during training, + because when we update the parameters of the generator (or discriminator), + the parameters of the discriminator (or generator) is not updated, which is + not allowed for DistributedDataParallel. So we design this wrapper to + separately wrap DistributedDataParallel for generator and discriminator. + In this wrapper, we perform two operations: + + 1. Wraps each module in the models with separate MMDistributedDataParallel. + Note that only modules with parameters will be wrapped. + 2. Calls ``train_step``, ``val_step`` and ``test_step`` of submodules to + get losses and predictions. + + Args: + module (nn.Module): model contain multiple submodules which have + separately updating strategy. + *args: list arguments passed to ``MMDistributedDataParallel`` + **kwargs: keyword arguments passed to ``MMDistributedDataParallel``. + """ + + def __init__(self, module: nn.Module, *args, **kwargs): + super(DistributedDataParallel, self).__init__() + self.module = module + # Wrap the submodule with parameters of `self.module` to + # `MMDistributedDataParallel` + for name, _module in module._modules.items(): + # module without parameters. + if next(_module.parameters(), None) is None: + _module = _module.cuda() + elif all(not p.requires_grad for p in module.parameters()): + _module = _module.cuda() + else: + _module = MMDistributedDataParallel( + module=_module.cuda(), *args, **kwargs) + module._modules[name] = _module + + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: + """Interface for model forward, backward and parameters updating during + training process. + + Args: + data: Data sampled by dataloader. + optim_wrapper (OptimWrapperDict): A wrapper of optimizer to + update parameters. + + Returns: + Dict[str, torch.Tensor]: A dict of tensor for logging. + """ + return self.module.train_step(data, optim_wrapper) + + def val_step(self, data) -> List[BaseDataElement]: + """Gets the prediction of module during validation process. + + Args: + data (List[dict]): Data sampled by dataloader. + + Returns: + List[BaseDataElement]: The predictions of given data. + """ + return self.module.val_step(data) + + def test_step(self, data: List[dict]) -> List[BaseDataElement]: + """Gets the predictions of module during testing process. + + Args: + data: Data sampled by dataloader. + + Returns: + ForwardResults: The predictions of given data. + """ + return self.module.test_step(data) + + @contextmanager + def no_sync(self): + """Enables ``no_sync`` context of all sub ``MMDistributedDataParallel`` + modules.""" + with ExitStack() as stack: + for sub_ddp_model in self.module._modules.values(): + stack.enter_context(sub_ddp_model.no_sync()) + yield + + def train(self, mode: bool = True) -> 'MMSeparateDistributedDataParallel': + """Sets the module in training mode. + + In order to make the ddp wrapper inheritance hierarchy more uniform, + ``MMSeparateDistributedDataParallel`` inherits from + ``DistributedDataParallel``, but will not call its constructor. + Since the attributes of ``DistributedDataParallel`` have not been + initialized, call the ``train`` method of ``DistributedDataParallel`` + will raise an error if pytorch version <= 1.9. Therefore, override + this method to call the ``train`` method of submodules. 
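To opt into this wrapper from a config instead of the default ``MMDistributedDataParallel``, something like the sketch below is expected to work, since ``Runner.wrap_model`` builds ``model_wrapper_cfg`` with ``default_args=dict(module=model)``; the model name is illustrative and per-submodule optimizer wrappers (an ``OptimWrapperDict``) are assumed:

```python
cfg = dict(
    model=dict(type='ToyGAN'),  # illustrative: generator + discriminator submodules
    model_wrapper_cfg=dict(
        type='MMSeparateDistributedDataParallel',
        broadcast_buffers=False,
        find_unused_parameters=False),
)
```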
+ + Args: + mode (bool): whether to set training mode (``True``) or evaluation + mode (``False``). Default: ``True``. + + Returns: + Module: self. + """ + self.training = mode + self.module.train(mode) + return self diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 482207dd0f6c57377482be8f1d4b75e58955a49a..6e18685568e2ee1465ec002de0a2bae8d41719f5 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -99,15 +99,20 @@ class EpochBasedTrainLoop(BaseLoop): """ self.runner.call_hook( 'before_train_iter', batch_idx=idx, data_batch=data_batch) - # outputs should be a dict containing one or multiple loss tensors - self.runner.outputs = self.runner.model(data_batch, return_loss=True) + # Enable gradient accumulation mode and avoid unnecessary gradient + # synchronization during gradient accumulation process. + with self.runner.optim_wrapper.accumulate_grad(self.runner.model, + self._iter, + self._max_iters): + # outputs should be a dict of loss. + outputs = self.runner.model.train_step( + data_batch, optim_wrapper=self.runner.optim_wrapper) self.runner.call_hook( 'after_train_iter', batch_idx=idx, data_batch=data_batch, - outputs=self.runner.outputs) - + outputs=outputs) self._iter += 1 @@ -197,14 +202,21 @@ class IterBasedTrainLoop(BaseLoop): """ self.runner.call_hook( 'before_train_iter', batch_idx=self._iter, data_batch=data_batch) - # outputs should be a dict containing loss tensor - self.runner.outputs = self.runner.model(data_batch, return_loss=True) + # Enable gradient accumulation mode and avoid unnecessary gradient + # synchronization during gradient accumulation process. + with self.runner.optim_wrapper.accumulate_grad(self.runner.model, + self._iter, + self._max_iters): + # train_logs should be a dict of loss. 
+ train_logs = self.runner.model.train_step( + data_batch, optim_wrapper=self.runner.optim_wrapper) + self.runner.message_hub.update_info('train_logs', train_logs) self.runner.call_hook( 'after_train_iter', batch_idx=self._iter, data_batch=data_batch, - outputs=self.runner.outputs) + outputs=train_logs) self._iter += 1 @@ -247,7 +259,6 @@ class ValLoop(BaseLoop): # compute metrics metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - self.runner.call_hook('after_val_epoch', metrics=metrics) self.runner.call_hook('after_val') @@ -262,7 +273,7 @@ class ValLoop(BaseLoop): self.runner.call_hook( 'before_val_iter', batch_idx=idx, data_batch=data_batch) # outputs should be sequence of BaseDataElement - outputs = self.runner.model(data_batch) + outputs = self.runner.model.val_step(data_batch) self.evaluator.process(data_batch, outputs) self.runner.call_hook( 'after_val_iter', @@ -310,7 +321,6 @@ class TestLoop(BaseLoop): # compute metrics metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - self.runner.call_hook('after_test_epoch', metrics=metrics) self.runner.call_hook('after_test') @@ -324,7 +334,7 @@ class TestLoop(BaseLoop): self.runner.call_hook( 'before_test_iter', batch_idx=idx, data_batch=data_batch) # predictions should be sequence of BaseDataElement - predictions = self.runner.model(data_batch) + predictions = self.runner.model.test_step(data_batch) self.evaluator.process(data_batch, predictions) self.runner.call_hook( 'after_test_iter', diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index aff5c6ba55fd81203b91b356d07f8303c7040d36..38355db50c60d52914cddceb73e6380d2a3df3ba 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -13,7 +13,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Union import numpy as np import torch import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -25,7 +25,8 @@ from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, from mmengine.evaluator import Evaluator from mmengine.hooks import Hook from mmengine.logging import LogProcessor, MessageHub, MMLogger -from mmengine.model import is_model_wrapper +from mmengine.model import (BaseModel, MMDistributedDataParallel, + is_model_wrapper) from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler, build_optim_wrapper) from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, @@ -743,7 +744,7 @@ class Runner: 'visualizer should be Visualizer object, a dict or None, ' f'but got {visualizer}') - def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module: + def build_model(self, model: Union[BaseModel, Dict]) -> BaseModel: """Build model. If ``model`` is a dict, it will be used to build a nn.Module object @@ -755,28 +756,30 @@ class Runner: model = dict(type='ResNet') Args: - model (nn.Module or dict): A nn.Module object or a dict to build + model (BaseModel or dict): A nn.Module object or a dict to build nn.Module object. If ``model`` is a nn.Module object, just returns itself. Returns: nn.Module: Model build from ``model``. 
""" - if isinstance(model, nn.Module): + if isinstance(model, BaseModel): return model elif isinstance(model, dict): model = MODELS.build(model) # init weights - if hasattr(model, 'init_weights'): - model.init_weights() - return model + if hasattr(model, 'init_weights'): # type: ignore + model.init_weights() # type: ignore + return model # type: ignore else: raise TypeError('model should be a nn.Module object or dict, ' f'but got {model}') - def wrap_model(self, model_wrapper_cfg: Optional[Dict], - model: nn.Module) -> nn.Module: - """Wrap model. + def wrap_model( + self, model_wrapper_cfg: Optional[Dict], + model: BaseModel) -> Union[DistributedDataParallel, BaseModel]: + """Wrap the model to :obj:``MMDistributedDataParallel`` or other custom + distributed data-parallel module wrappers. An example of ``model_wrapper_cfg``:: @@ -789,10 +792,11 @@ class Runner: model_wrapper_cfg (dict, optional): Config to wrap model. If not specified, ``DistributedDataParallel`` will be used in distributed environment. Defaults to None. - model (nn.Module): Model to be wrapped. + model (BaseModel): Model to be wrapped. Returns: - nn.Module: Wrapped model. + BaseModel or DistributedDataParallel: BaseModel or subclass of + ``DistributedDataParallel``. """ if is_model_wrapper(model): if model_wrapper_cfg is not None: @@ -802,25 +806,26 @@ class Runner: return model + # Set `export CUDA_VISIBLE_DEVICES=-1` to enable CPU training. + if torch.cuda.is_available(): + model = model.cuda() + + if not self.distributed: + return model + if model_wrapper_cfg is None: - if self.distributed: - find_unused_parameters = self.cfg.get('find_unused_parameters', - False) - # Sets the `find_unused_parameters` parameter in - # torch.nn.parallel.DistributedDataParallel - model = DistributedDataParallel( - self.model.cuda(), - device_ids=[torch.cuda.current_device()], - broadcast_buffers=False, - find_unused_parameters=find_unused_parameters) - else: - # Set `export CUDA_VISIBLE_DEVICES=-1` can enable CPU training. - if torch.cuda.is_available(): - model = model.cuda() + find_unused_parameters = self.cfg.get('find_unused_parameters', + False) + # Sets the `find_unused_parameters` parameter in + # torch.nn.parallel.DistributedDataParallel + model = MMDistributedDataParallel( + module=model, + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) else: model = MODEL_WRAPPERS.build( - model_wrapper_cfg, default_args=dict(model=self.model)) - + model_wrapper_cfg, default_args=dict(module=model)) return model def scale_lr(self, @@ -908,10 +913,17 @@ class Runner: dict to build OptimWrapper objects. If ``optim_wrapper`` is an OptimWrapper, just return an ``OptimizeWrapper`` instance. + Note: + For single optimizer training, if `optim_wrapper` is a config + dict, `type` is optional(defaults to :obj:`OptimWrapper`) and it + must contain `optimizer` to build the corresponding optimizer. + Examples: >>> # build an optimizer >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( ... type='SGD', lr=0.01)) + >>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01)) + >>> # is also valid. 
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) >>> optim_wrapper Type: OptimWrapper @@ -1648,8 +1660,6 @@ class Runner: +======================+=========================+ | RuntimeInfoHook | VERY_HIGH (10) | +----------------------+-------------------------+ - | OptimizerHook | HIGH (30) | - +----------------------+-------------------------+ | IterTimerHook | NORMAL (40) | +----------------------+-------------------------+ | DistSamplerSeedHook | NORMAL (40) | @@ -1666,7 +1676,6 @@ class Runner: default_hooks = dict( runtime_info=dict(type='RuntimeInfoHook'), - optimizer=dict(type='OptimizerHook', grad_clip=None), timer=dict(type='IterTimerHook'), sampler_seed=dict(type='DistSamplerSeedHook'), logger=dict(type='LoggerHook'), @@ -1689,7 +1698,6 @@ class Runner: """ default_hooks: dict = dict( runtime_info=dict(type='RuntimeInfoHook'), - optimizer=dict(type='OptimizerHook', grad_clip=None), timer=dict(type='IterTimerHook'), logger=dict(type='LoggerHook'), param_scheduler=dict(type='ParamSchedulerHook'), diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index 13b8e3220b80c4d5a7a5eee07ac314219895181d..5e151561e418c2a2db4f3d486d8b96be1f88e990 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -227,7 +227,7 @@ def check_prerequisites( prerequisites, checker, msg_tmpl='Prerequisites "{}" are required in method "{}" but not ' - 'found, please install them first.'): # yapf: disable + 'found, please install them first.'): # yapf: disable """A decorator factory to check if prerequisites are satisfied. Args: @@ -341,7 +341,6 @@ def deprecated_api_warning(name_dict: dict, if kwargs: for src_arg_name, dst_arg_name in name_dict.items(): if src_arg_name in kwargs: - assert dst_arg_name not in kwargs, ( f'The expected behavior is to replace ' f'the deprecated key `{src_arg_name}` to ' @@ -471,7 +470,7 @@ def tensor2imgs(tensor: torch.Tensor, if std is None: std = (1, ) * channels assert (channels == len(mean) == len(std) == 3) or \ - (channels == len(mean) == len(std) == 1 and not to_bgr) + (channels == len(mean) == len(std) == 1 and not to_bgr) mean = tensor.new_tensor(mean).view(1, -1) std = tensor.new_tensor(std).view(1, -1) tensor = tensor.permute(0, 2, 3, 1) * std + mean diff --git a/tests/test_hook/test_ema_hook.py b/tests/test_hook/test_ema_hook.py index 4cae6c837c77805936ed9533d72ef40c047b41d4..0da87f1a650a4342168cba3e0a255eb1b1086039 100644 --- a/tests/test_hook/test_ema_hook.py +++ b/tests/test_hook/test_ema_hook.py @@ -9,7 +9,7 @@ import torch.nn as nn from torch.utils.data import Dataset from mmengine.hooks import EMAHook -from mmengine.model import ExponentialMovingAverage +from mmengine.model import BaseModel, ExponentialMovingAverage from mmengine.optim import OptimWrapper from mmengine.registry import DATASETS, MODEL_WRAPPERS from mmengine.runner import Runner @@ -21,25 +21,28 @@ class ToyModel(nn.Module): super().__init__() self.linear = nn.Linear(2, 1) - def forward(self, data_batch, return_loss=False): - inputs, labels = [], [] - for x in data_batch: - inputs.append(x['inputs']) - labels.append(x['data_sample']) - - device = 'cuda:0' if torch.cuda.is_available() else 'cpu' - inputs = torch.stack(inputs).to(device) - labels = torch.stack(labels).to(device) - outputs = self.linear(inputs) - if return_loss: + def forward(self, batch_inputs, labels, mode='tensor'): + labels = torch.stack(labels) + outputs = self.linear(batch_inputs) + if mode == 'tensor': + return outputs + elif mode == 'loss': loss = (labels - outputs).sum() - outputs = 
dict(loss=loss, log_vars=dict(loss=loss.item())) + outputs = dict(loss=loss) return outputs else: - outputs = dict(log_vars=dict(a=1, b=0.5)) return outputs +class ToyModel1(BaseModel, ToyModel): + + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return super(BaseModel, self).forward(*args, **kwargs) + + @DATASETS.register_module() class DummyDataset(Dataset): METAINFO = dict() # type: ignore @@ -67,7 +70,7 @@ class TestEMAHook(TestCase): def test_ema_hook(self): device = 'cuda:0' if torch.cuda.is_available() else 'cpu' - model = ToyModel().to(device) + model = ToyModel1().to(device) evaluator = Mock() evaluator.evaluate = Mock(return_value=dict(acc=0.5)) runner = Runner( @@ -121,7 +124,7 @@ class TestEMAHook(TestCase): runner.test() @MODEL_WRAPPERS.register_module() - class DummyWrapper(nn.Module): + class DummyWrapper(BaseModel): def __init__(self, model): super().__init__() @@ -132,7 +135,7 @@ class TestEMAHook(TestCase): # with model wrapper runner = Runner( - model=DummyWrapper(model), + model=DummyWrapper(ToyModel()), test_dataloader=dict( dataset=dict(type='DummyDataset'), sampler=dict(type='DefaultSampler', shuffle=True), diff --git a/tests/test_hook/test_optimizer_hook.py b/tests/test_hook/test_optimizer_hook.py deleted file mode 100644 index 1e665a93edd5301cdd68b5e163e622aa48fb4d52..0000000000000000000000000000000000000000 --- a/tests/test_hook/test_optimizer_hook.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest.mock import MagicMock, Mock - -import torch -from torch import nn - -from mmengine.hooks import OptimizerHook - - -class TestOptimizerHook: - - def test_after_train_iter(self): - - class Model(nn.Module): - - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d( - in_channels=1, - out_channels=2, - kernel_size=3, - stride=1, - padding=1, - dilation=1) - self.conv2 = nn.Conv2d( - in_channels=2, - out_channels=2, - kernel_size=3, - stride=1, - padding=1, - dilation=1) - self.conv3 = nn.Conv2d( - in_channels=1, - out_channels=2, - kernel_size=3, - stride=1, - padding=1, - dilation=1) - - def forward(self, x): - x1 = self.conv1(x) - x2 = self.conv2(x1) - return x1, x2 - - model = Model() - x = torch.rand(1, 1, 3, 3) - - dummy_runner = MagicMock() - dummy_runner.optim_wrapper.zero_grad = Mock(return_value=None) - dummy_runner.optim_wrapper.step = Mock(return_value=None) - dummy_runner.model = model - dummy_runner.outputs = dict() - - dummy_runner.outputs['num_samples'] = 0 - - class DummyLogger(): - - def __init__(self): - self.msg = '' - - def log(self, msg=None, **kwargs): - self.msg += msg - - dummy_runner.logger = DummyLogger() - optimizer_hook = OptimizerHook( - dict(max_norm=2), detect_anomalous_params=True) - - dummy_runner.outputs['loss'] = model(x)[0].sum() - - dummy_runner.outputs['loss'].backward = Mock( - wraps=dummy_runner.outputs['loss'].backward) - optimizer_hook.detect_anomalous_parameters = Mock( - wraps=optimizer_hook.detect_anomalous_parameters) - optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads) - - optimizer_hook.after_train_iter(dummy_runner, 0) - # assert the parameters of conv2 and conv3 are not in the - # computational graph which is with x1.sum() as root. 
- assert 'conv2.weight' in dummy_runner.logger.msg - assert 'conv2.bias' in dummy_runner.logger.msg - assert 'conv3.weight' in dummy_runner.logger.msg - assert 'conv3.bias' in dummy_runner.logger.msg - assert 'conv1.weight' not in dummy_runner.logger.msg - assert 'conv1.bias' not in dummy_runner.logger.msg - dummy_runner.optim_wrapper.step.assert_called() - dummy_runner.outputs['loss'].backward.assert_called() - optimizer_hook.clip_grads.assert_called() - optimizer_hook.detect_anomalous_parameters.assert_called() - - dummy_runner.outputs['loss'] = model(x)[1].sum() - dummy_runner.logger.msg = '' - optimizer_hook.after_train_iter(dummy_runner, 0) - # assert the parameters of conv3 are not in the computational graph - assert 'conv3.weight' in dummy_runner.logger.msg - assert 'conv3.bias' in dummy_runner.logger.msg - assert 'conv2.weight' not in dummy_runner.logger.msg - assert 'conv2.bias' not in dummy_runner.logger.msg - assert 'conv1.weight' not in dummy_runner.logger.msg - assert 'conv1.bias' not in dummy_runner.logger.msg - - # grad_clip is None and detect_anomalous_parameters is False - optimizer_hook = OptimizerHook(detect_anomalous_params=False) - optimizer_hook.detect_anomalous_parameters = Mock( - wraps=optimizer_hook.detect_anomalous_parameters) - optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads) - dummy_runner.outputs['loss'] = model(x)[0].sum() - dummy_runner.outputs['loss'].backward = Mock( - wraps=dummy_runner.outputs['loss'].backward) - - optimizer_hook.after_train_iter(dummy_runner, 0) - - dummy_runner.optim_wrapper.step.assert_called() - dummy_runner.outputs['loss'].backward.assert_called() - optimizer_hook.clip_grads.assert_not_called() - optimizer_hook.detect_anomalous_parameters.assert_not_called() diff --git a/tests/test_hook/test_runtime_info_hook.py b/tests/test_hook/test_runtime_info_hook.py index 2eb651c5e4e21ae654fee5076d0f73ef8fac623d..b57e26eb529de73a701e42dda819bce467057189 100644 --- a/tests/test_hook/test_runtime_info_hook.py +++ b/tests/test_hook/test_runtime_info_hook.py @@ -102,12 +102,7 @@ class TestRuntimeInfoHook(TestCase): runner.message_hub = message_hub hook = RuntimeInfoHook() hook.after_train_iter( - runner, - batch_idx=2, - data_batch=None, - outputs={'log_vars': { - 'loss_cls': 1.111 - }}) + runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111}) self.assertEqual( message_hub.get_scalar('train/loss_cls').current(), 1.111) diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index c36d239df17ca0c0128c40b4448c237d9eaeb087..e64cb88ac05aee2166ba28d5a74c2bf0e4fbc1cc 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -101,7 +101,7 @@ class TestMessageHub: message_hub.update_scalar('lr', 0.1, resumed=False) # update runtime information message_hub.update_info('iter', 1, resumed=True) - message_hub.update_info('feat', [1, 2, 3], resumed=False) + message_hub.update_info('tensor', [1, 2, 3], resumed=False) obj = pickle.dumps(message_hub) instance = pickle.loads(obj) diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..280fcead669cd8ead5361d2eff486e0aa89c6a67 --- /dev/null +++ b/tests/test_model/test_base_model/test_base_model.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import unittest +from unittest import TestCase + +import torch +import torch.nn as nn +from torch.optim import SGD + +from mmengine.model import BaseDataPreprocessor, BaseModel +from mmengine.optim import OptimWrapper +from mmengine.registry import MODELS +from mmengine.testing import assert_allclose + + +@MODELS.register_module() +class CustomDataPreprocessor(BaseDataPreprocessor): + + def forward(self, data, training=False): + if training: + return 1 + else: + return 2 + + +class ToyModel(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(None, data_preprocessor=data_preprocessor) + self.conv = nn.Conv2d(3, 1, 1) + + def forward(self, batch_inputs, data_samples=None, mode='tensor'): + if mode == 'loss': + out = self.conv(batch_inputs) + return dict(loss=out) + elif mode == 'predict': + out = self.conv(batch_inputs) + return out + elif mode == 'tensor': + out = self.conv(batch_inputs) + return out + + +class TestBaseModel(TestCase): + + def test_init(self): + # initiate model without `preprocess_cfg` + model = ToyModel() + self.assertIsInstance(model.data_preprocessor, BaseDataPreprocessor) + data_preprocessor = dict(type='CustomDataPreprocessor') + model = ToyModel(data_preprocessor=data_preprocessor) + self.assertIsInstance(model.data_preprocessor, CustomDataPreprocessor) + self.assertEqual(model.data_preprocessor(1, training=True), 1) + self.assertEqual(model.data_preprocessor(1, training=False), 2) + + # initiate model with built `data_preprocessor`. + data_preprocessor = CustomDataPreprocessor() + model = ToyModel(data_preprocessor=data_preprocessor) + self.assertIs(model.data_preprocessor, data_preprocessor) + + # initiate model with error type `data_preprocessor`. + with self.assertRaisesRegex(TypeError, 'data_preprocessor should be'): + ToyModel(data_preprocessor=[data_preprocessor]) + + def test_parse_losses(self): + model = ToyModel() + loss_cls = torch.tensor(1, dtype=torch.float32) + loss_list = [ + torch.tensor(2, dtype=torch.float32), + torch.tensor(3, dtype=torch.float32) + ] + losses = dict(loss_cls=loss_cls, loss_list=loss_list) + target_parsed_losses = torch.tensor(6, dtype=torch.float32) + targe_log_vars = dict( + loss=torch.tensor(6, dtype=torch.float32), + loss_cls=torch.tensor(1, dtype=torch.float32), + loss_list=torch.tensor(5, dtype=torch.float32)) + parse_losses, log_vars = model.parse_losses(losses) + assert_allclose(parse_losses, target_parsed_losses) + for key in log_vars: + self.assertIn(key, targe_log_vars) + assert_allclose(log_vars[key], targe_log_vars[key]) + + with self.assertRaises(TypeError): + losses['error_key'] = dict() + model.parse_losses(losses) + + def test_train_step(self): + model = ToyModel() + optimizer = SGD(model.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer) + inputs = torch.randn(3, 1, 1) + data = dict(inputs=inputs) + # initiate grad. 
+        # model.conv.weight.grad = torch.randn(1, 3, 1, 1)
+        log_vars = model.train_step([data], optim_wrapper)
+        self.assertIsNotNone(model.conv.weight.grad)
+        self.assertIsInstance(log_vars['loss'], torch.Tensor)
+
+    def test_val_step(self):
+        inputs = torch.randn(3, 1, 1)
+        data = dict(inputs=inputs)
+        model = ToyModel()
+        out = model.val_step([data])
+        self.assertIsInstance(out, torch.Tensor)
+
+    def test_test_step(self):
+        inputs = torch.randn(3, 1, 1)
+        data = dict(inputs=inputs)
+        model = ToyModel()
+        out = model.test_step([data])
+        self.assertIsInstance(out, torch.Tensor)
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
+    def test_cuda(self):
+        inputs = torch.randn(3, 1, 1).cuda()
+        data = dict(inputs=inputs)
+        model = ToyModel().cuda()
+        model.val_step([data])
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
+    def test_to(self):
+        inputs = torch.randn(3, 1, 1).cuda()
+        data = dict(inputs=inputs)
+        model = ToyModel().to(torch.cuda.current_device())
+        model.val_step([data])
diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..146ed35e209aeab31802c2b3d31e142086f2f0bc
--- /dev/null
+++ b/tests/test_model/test_base_model/test_data_preprocessor.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+import torch.nn.functional as F
+
+from mmengine import InstanceData
+from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
+from mmengine.testing import assert_allclose
+
+
+class TestBaseDataPreprocessor(TestCase):
+
+    def test_init(self):
+        base_data_preprocessor = BaseDataPreprocessor()
+        self.assertEqual(base_data_preprocessor.device, 'cpu')
+
+    def test_forward(self):
+        base_data_preprocessor = BaseDataPreprocessor()
+        input1 = torch.randn(1, 3, 5)
+        input2 = torch.randn(1, 3, 5)
+        label1 = torch.randn(1)
+        label2 = torch.randn(1)
+
+        data = [
+            dict(inputs=input1, data_sample=label1),
+            dict(inputs=input2, data_sample=label2)
+        ]
+
+        batch_inputs, batch_labels = base_data_preprocessor(data)
+        self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
+
+        assert_allclose(input1, batch_inputs[0])
+        assert_allclose(input2, batch_inputs[1])
+        assert_allclose(label1, batch_labels[0])
+        assert_allclose(label2, batch_labels[1])
+
+
+class TestImageDataPreprocessor(TestBaseDataPreprocessor):
+
+    def test_init(self):
+        # initialize the data preprocessor with default arguments
+        data_processor = ImgDataPreprocessor()
+        self.assertFalse(data_processor.channel_conversion)
+        assert_allclose(data_processor.mean,
+                        torch.tensor([127.5, 127.5, 127.5]).view(-1, 1, 1))
+        assert_allclose(data_processor.std,
+                        torch.tensor([127.5, 127.5, 127.5]).view(-1, 1, 1))
+        self.assertEqual(data_processor.pad_size_divisor, 1)
+        assert_allclose(data_processor.pad_value, torch.tensor(0))
+        # initialize the data preprocessor with customized arguments
+        data_processor = ImgDataPreprocessor(
+            bgr_to_rgb=True,
+            mean=[0, 0, 0],
+            std=[255, 255, 255],
+            pad_size_divisor=16,
+            pad_value=10)
+        self.assertTrue(data_processor.channel_conversion)
+        assert_allclose(data_processor.mean,
+                        torch.tensor([0, 0, 0]).view(-1, 1, 1))
+        assert_allclose(data_processor.std,
+                        torch.tensor([255, 255, 255]).view(-1, 1, 1))
+        assert_allclose(data_processor.pad_value, torch.tensor(10))
+        self.assertEqual(data_processor.pad_size_divisor, 16)
+
+        with self.assertRaisesRegex(AssertionError,
'The length of mean'): + ImgDataPreprocessor(mean=(1, 2)) + + with self.assertRaisesRegex(AssertionError, 'The length of std'): + ImgDataPreprocessor(std=(1, 2)) + + with self.assertRaisesRegex(AssertionError, '`bgr2rgb` and `rgb2bgr`'): + ImgDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True) + + def test_forward(self): + # Test `pad_value`, `to_rgb`, `pad_size_divisor`. + data_preprocessor = ImgDataPreprocessor( + mean=[127.5], + std=[1, 2, 3], + pad_size_divisor=16, + pad_value=10, + rgb_to_bgr=True, + ) + inputs1 = torch.randn(3, 10, 10) + inputs2 = torch.randn(3, 15, 15) + data_sample1 = InstanceData(bboxes=torch.randn(5, 4)) + data_sample2 = InstanceData(bboxes=torch.randn(5, 4)) + data = [ + dict(inputs=inputs1.clone(), data_sample=data_sample1.clone()), + dict(inputs=inputs2.clone(), data_sample=data_sample2.clone()) + ] + + std = torch.tensor([1, 2, 3]).view(-1, 1, 1) + inputs1 = (inputs1[[2, 1, 0], ...] - 127.5) / std + inputs2 = (inputs2[[2, 1, 0], ...] - 127.5) / std + inputs1 = F.pad(inputs1, (0, 6, 0, 6), value=10) + inputs2 = F.pad(inputs2, (0, 1, 0, 1), value=10) + + target_inputs = [inputs1, inputs2] + inputs, data_samples = data_preprocessor(data, True) + + target_data_samples = [data_sample1, data_sample2] + for input_, data_sample, target_input, target_data_sample in zip( + inputs, data_samples, target_inputs, target_data_samples): + assert_allclose(input_, target_input) + assert_allclose(data_sample.bboxes, target_data_sample.bboxes) + + # Test empty `data_sample` + data = [dict(inputs=inputs1.clone()), dict(inputs=inputs2.clone())] + data_preprocessor(data, True) diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py deleted file mode 100644 index c1f96ac476e7ae4eeae5a6a795012a1b48fb70b1..0000000000000000000000000000000000000000 --- a/tests/test_model/test_wrappers/test_data_parallel.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from unittest import TestCase -from unittest.mock import MagicMock, patch - -import pytest -import torch -import torch.nn as nn -from torch.nn.parallel import DataParallel -from torch.nn.parallel.distributed import DistributedDataParallel - -from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel, - is_model_wrapper) -from mmengine.registry import MODEL_WRAPPERS - - -def mock(*args, **kwargs): - pass - - -@patch('torch.distributed._broadcast_coalesced', mock) -@patch('torch.distributed.broadcast', mock) -@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock) -def test_is_model_wrapper(): - - class Model(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(2, 2, 1) - - def forward(self, x): - return self.conv(x) - - # _verify_model_across_ranks is added in torch1.9.0 so we should check - # whether _verify_model_across_ranks is the member of torch.distributed - # before mocking - if hasattr(torch.distributed, '_verify_model_across_ranks'): - torch.distributed._verify_model_across_ranks = mock - - # _verify_model_across_ranks is added in torch1.11.0 so we should check - # whether _verify_params_across_processes is the member of - # torch.distributed before mocking - if hasattr(torch.distributed, '_verify_params_across_processes'): - torch.distributed._verify_params_across_processes = mock - - model = Model() - assert not is_model_wrapper(model) - - mmdp = MMDataParallel(model) - assert is_model_wrapper(mmdp) - - mmddp = MMDistributedDataParallel(model, process_group=MagicMock()) - assert is_model_wrapper(mmddp) - - torch_dp = DataParallel(model) - assert is_model_wrapper(torch_dp) - - torch_ddp = DistributedDataParallel(model, process_group=MagicMock()) - assert is_model_wrapper(torch_ddp) - - # test model wrapper registry - @MODEL_WRAPPERS.register_module() - class ModelWrapper: - - def __init__(self, module): - self.module = module - - def forward(self, *args, **kwargs): - return self.module(*args, **kwargs) - - model_wrapper = ModelWrapper(model) - assert is_model_wrapper(model_wrapper) - - -class TestMMDataParallel(TestCase): - - def setUp(self): - """Setup the demo image in every test method. 
- - TestCase calls functions in this order: setUp() -> testMethod() -> - tearDown() -> cleanUp() - """ - - class Model(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 2, 1) - - def forward(self, x): - return self.conv(x) - - def train_step(self, x): - return self.forward(x) - - def val_step(self, x): - return self.forward(x) - - self.model = Model() - - def test_train_step(self): - - class Model(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 2, 1) - - def forward(self, x): - return self.conv(x) - - model = Model() - mmdp = MMDataParallel(model) - - # test without train_step attribute - with pytest.raises(AssertionError): - mmdp.train_step(torch.zeros([1, 1, 3, 3])) - - out = self.model.train_step(torch.zeros([1, 1, 3, 3])) - assert out.shape == (1, 2, 3, 3) - - def test_val_step(self): - - class Model(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 2, 1) - - def forward(self, x): - return self.conv(x) - - model = Model() - mmdp = MMDataParallel(model) - - # test without val_step attribute - with pytest.raises(AssertionError): - mmdp.val_step(torch.zeros([1, 1, 3, 3])) - - out = self.model.val_step(torch.zeros([1, 1, 3, 3])) - assert out.shape == (1, 2, 3, 3) diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7e74bee16062148511d85724f0bc6ef4362d77 --- /dev/null +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import unittest +from unittest.mock import MagicMock + +import torch +import torch.distributed as torch_dist +import torch.nn as nn +from torch.optim import SGD + +from mmengine.model import (BaseModel, MMDistributedDataParallel, + MMSeparateDistributedDataParallel) +from mmengine.optim import OptimWrapper, OptimWrapperDict +from mmengine.testing import assert_allclose +from mmengine.testing._internal import MultiProcessTestCase + + +class ToyModel(BaseModel): + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 1, 1) + self.conv2 = nn.Conv2d(1, 1, 1) + + def forward(self, x, data_samples=None, mode='tensor'): + if mode == 'loss': + x = self.conv1(x) + x = self.conv2(x) + return dict(loss=x) + elif mode == 'predict': + return x + else: + return x + + +class ComplexModel(BaseModel): + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 1, 1) + self.conv2 = nn.Conv2d(3, 1, 1) + + def train_step(self, data, optim_wrapper): + batch_inputs, _ = self.data_preprocessor(data) + loss1 = self.conv1(batch_inputs) + optim_wrapper['optim_wrapper1'].update_params(loss1) + loss2 = self.conv2(batch_inputs) + optim_wrapper['optim_wrapper2'].update_params(loss2) + return dict(loss1=loss1, loss2=loss2) + + def val_step(self, data): + return 1 + + def test_step(self, data): + return 2 + + def forward(self): + pass + + +class TestModelWrapper(MultiProcessTestCase): + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def test_train_step(self): + self._init_dist_env(self.rank, self.world_size) + # Test `optim_wrapper` is a instance of `OptimWrapper` + model = ToyModel() + ddp_model = MMDistributedDataParallel(module=model) + optimizer = SGD(ddp_model.parameters(), lr=0) + optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1) + inputs = torch.randn(3, 1, 1) * self.rank * 255 + data = dict(inputs=inputs, 
data_sample=MagicMock())
+        ddp_model.train_step([data], optim_wrapper=optim_wrapper)
+        grad = ddp_model.module.conv1.weight.grad
+        assert_allclose(grad, torch.zeros_like(grad))
+
+    def test_val_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ToyModel()
+        ddp_model = MMDistributedDataParallel(module=model)
+        inputs = torch.randn(3, 1, 1) * self.rank * 255
+        data = dict(inputs=inputs, data_sample=MagicMock())
+        # Test get predictions.
+        predictions = ddp_model.val_step([data])
+        self.assertIsInstance(predictions, torch.Tensor)
+
+    def test_test_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ToyModel()
+        ddp_model = MMDistributedDataParallel(module=model)
+        inputs = torch.randn(3, 1, 1) * self.rank * 255
+        data = dict(inputs=inputs, data_sample=MagicMock())
+        predictions = ddp_model.test_step([data])
+        self.assertIsInstance(predictions, torch.Tensor)
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29510'
+        os.environ['RANK'] = str(rank)
+        torch_dist.init_process_group(
+            backend='gloo', rank=rank, world_size=world_size)
+
+
+@unittest.skipIf(
+    not torch.cuda.is_available(), reason='cuda should be available')
+class TestMMSeparateDistributedDataParallel(TestModelWrapper):
+
+    def test_train_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        # Test `optim_wrapper` is a dict. In this case,
+        # there will be two independently updated `DistributedDataParallel`
+        # submodules.
+        model = ComplexModel()
+        ddp_model = MMSeparateDistributedDataParallel(model.cuda())
+        optimizer1 = SGD(model.conv1.parameters(), lr=0.1)
+        optimizer2 = SGD(model.conv2.parameters(), lr=0.2)
+        optim_wrapper1 = OptimWrapper(optimizer1, 1)
+        optim_wrapper2 = OptimWrapper(optimizer2, 1)
+        optim_wrapper_dict = OptimWrapperDict(
+            optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2)
+        inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
+        data = dict(inputs=inputs)
+        # Automatically sync grads of `optim_wrapper1` since
+        # `accumulative_iters` = 1
+        ddp_model.train()
+        self.assertTrue(ddp_model.training)
+        ddp_model.train_step([data], optim_wrapper=optim_wrapper_dict)
+
+    def test_val_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ComplexModel()
+        ddp_model = MMSeparateDistributedDataParallel(model)
+        data = torch.randn(3, 1, 1)
+        # Test get predictions.
+        ddp_model.eval()
+        self.assertFalse(ddp_model.training)
+        predictions = ddp_model.val_step([data])
+        self.assertEqual(predictions, 1)
+
+    def test_test_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ComplexModel()
+        ddp_model = MMSeparateDistributedDataParallel(model)
+        data = torch.randn(3, 1, 1)
+        # Test get predictions.
+        ddp_model.eval()
+        self.assertFalse(ddp_model.training)
+        predictions = ddp_model.test_step(data)
+        self.assertEqual(predictions, 2)
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29515'
+        os.environ['RANK'] = str(rank)
+        torch_dist.init_process_group(
+            backend='gloo', rank=rank, world_size=world_size)
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 438b840e46e4b1d59d48f40fbe5dc5989b09057a..d88a0e5b894c2f47c769ba5341960958043ebe9a 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab.
All rights reserved. import copy +import os import os.path as osp import shutil import tempfile @@ -19,6 +20,7 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook, IterTimerHook, LoggerHook, ParamSchedulerHook, RuntimeInfoHook) from mmengine.logging import LogProcessor, MessageHub, MMLogger +from mmengine.model import BaseModel from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR, OptimWrapper, OptimWrapperDict, StepLR) from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS, @@ -33,29 +35,25 @@ from mmengine.visualization import Visualizer @MODELS.register_module() -class ToyModel(nn.Module): +class ToyModel(BaseModel): def __init__(self): super().__init__() self.linear1 = nn.Linear(2, 2) self.linear2 = nn.Linear(2, 1) - def forward(self, data_batch, return_loss=False): - inputs, labels = [], [] - for x in data_batch: - inputs.append(x['inputs']) - labels.append(x['data_sample']) - - device = 'cuda:0' if torch.cuda.is_available() else 'cpu' - inputs = torch.stack(inputs).to(device) - labels = torch.stack(labels).to(device) - outputs = self.linear1(inputs) + def forward(self, batch_inputs, labels, mode='tensor'): + labels = torch.stack(labels) + outputs = self.linear1(batch_inputs) outputs = self.linear2(outputs) - if return_loss: + + if mode == 'tensor': + return outputs + elif mode == 'loss': loss = (labels - outputs).sum() - outputs = dict(loss=loss, log_vars=dict(loss=loss.item())) + outputs = dict(loss=loss) return outputs - else: + elif mode == 'predict': outputs = dict(log_vars=dict(a=1, b=0.5)) return outputs @@ -67,12 +65,43 @@ class ToyModel1(ToyModel): super().__init__() +@MODELS.register_module() +class TopGANModel(BaseModel): + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(2, 1) + self.linear2 = nn.Linear(2, 1) + + def forward(self, batch_inputs, labels, mode='tensor'): + labels = torch.stack(labels) + output1 = self.linear1(batch_inputs) + output2 = self.linear2(batch_inputs) + + if mode == 'tensor': + return output1, output2 + elif mode == 'loss': + loss1 = (labels - output1).sum() + loss2 = (labels - output2).sum() + outputs = dict(linear1=loss1, linear2=loss2) + return outputs + elif mode == 'predict': + return output1, output2 + + def train_step(self, data, optim_wrapper): + batch_inputs, batch_labels = self.data_preprocessor(data) + loss = self(batch_inputs, batch_labels, mode='loss') + optim_wrapper['linear1'].update_params(loss['linear1']) + optim_wrapper['linear2'].update_params(loss['linear2']) + return loss + + @MODEL_WRAPPERS.register_module() class CustomModelWrapper(nn.Module): - def __init__(self, model): + def __init__(self, module): super().__init__() - self.model = model + self.model = module @OPTIM_WRAPPER_CONSTRUCTORS.register_module() @@ -294,7 +323,6 @@ class TestRunner(TestCase): custom_hooks=[], default_hooks=dict( runtime_info=dict(type='RuntimeInfoHook'), - optimizer=dict(type='OptimizerHook', grad_clip=None), timer=dict(type='IterTimerHook'), logger=dict(type='LoggerHook'), param_scheduler=dict(type='ParamSchedulerHook'), @@ -314,7 +342,6 @@ class TestRunner(TestCase): self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12) self.iter_based_cfg.default_hooks = dict( runtime_info=dict(type='RuntimeInfoHook'), - optimizer=dict(type='OptimizerHook', grad_clip=None), timer=dict(type='IterTimerHook'), logger=dict(type='LoggerHook'), param_scheduler=dict(type='ParamSchedulerHook'), @@ -649,13 +676,21 @@ class TestRunner(TestCase): 
self.assertFalse(model.initiailzed) def test_wrap_model(self): - # TODO: test on distributed environment # custom model wrapper cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_wrap_model' cfg.model_wrapper_cfg = dict(type='CustomModelWrapper') runner = Runner.from_cfg(cfg) - self.assertIsInstance(runner.model, CustomModelWrapper) + self.assertIsInstance(runner.model, BaseModel) + if torch.cuda.is_available(): + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29515' + os.environ['RANK'] = str(0) + os.environ['WORLD_SIZE'] = str(1) + cfg.launcher = 'pytorch' + cfg.experiment_name = 'test_wrap_model1' + runner = Runner.from_cfg(cfg) + self.assertIsInstance(runner.model, CustomModelWrapper) def test_scale_lr(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -1270,35 +1305,35 @@ class TestRunner(TestCase): # register 7 hooks by default runner.register_default_hooks() - self.assertEqual(len(runner._hooks), 7) + self.assertEqual(len(runner._hooks), 6) # the third registered hook should be `DistSamplerSeedHook` - self.assertTrue(isinstance(runner._hooks[3], DistSamplerSeedHook)) + self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook)) # the fifth registered hook should be `ParamSchedulerHook` - self.assertTrue(isinstance(runner._hooks[5], ParamSchedulerHook)) + self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook)) runner._hooks = [] # remove `ParamSchedulerHook` from default hooks runner.register_default_hooks(hooks=dict(timer=None)) - self.assertEqual(len(runner._hooks), 6) + self.assertEqual(len(runner._hooks), 5) # `ParamSchedulerHook` was popped so the fifth is `CheckpointHook` - self.assertTrue(isinstance(runner._hooks[5], CheckpointHook)) + self.assertTrue(isinstance(runner._hooks[4], CheckpointHook)) # add a new default hook runner._hooks = [] runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook'))) - self.assertEqual(len(runner._hooks), 8) - self.assertTrue(isinstance(runner._hooks[7], ToyHook)) + self.assertEqual(len(runner._hooks), 7) + self.assertTrue(isinstance(runner._hooks[6], ToyHook)) def test_custom_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_custom_hooks' runner = Runner.from_cfg(cfg) - self.assertEqual(len(runner._hooks), 7) + self.assertEqual(len(runner._hooks), 6) custom_hooks = [dict(type='ToyHook')] runner.register_custom_hooks(custom_hooks) - self.assertEqual(len(runner._hooks), 8) - self.assertTrue(isinstance(runner._hooks[7], ToyHook)) + self.assertEqual(len(runner._hooks), 7) + self.assertTrue(isinstance(runner._hooks[6], ToyHook)) def test_register_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -1309,8 +1344,8 @@ class TestRunner(TestCase): custom_hooks = [dict(type='ToyHook')] runner.register_hooks(custom_hooks=custom_hooks) # six default hooks + custom hook (ToyHook) - self.assertEqual(len(runner._hooks), 8) - self.assertTrue(isinstance(runner._hooks[7], ToyHook)) + self.assertEqual(len(runner._hooks), 7) + self.assertTrue(isinstance(runner._hooks[6], ToyHook)) def test_custom_loop(self): # test custom loop with additional hook @@ -1346,12 +1381,11 @@ class TestRunner(TestCase): def warmup_iter(self, data_batch): self.runner.call_hook( 'before_warmup_iter', data_batch=data_batch) - self.runner.outputs = self.runner.model( - data_batch, return_loss=True) + train_logs = self.runner.model.train_step( + data_batch, self.runner.optim_wrapper) + self.runner.message_hub.update_info('train_logs', train_logs) self.runner.call_hook( - 
'after_warmup_iter', - data_batch=data_batch, - outputs=self.runner.outputs) + 'after_warmup_iter', data_batch=data_batch) before_warmup_iter_results = [] after_warmup_iter_results = [] @@ -1513,8 +1547,8 @@ class TestRunner(TestCase): linear2=dict( type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), constructor='ToyMultipleOptimizerConstructor') + cfg.model = dict(type='TopGANModel') # disable OptimizerHook because it only works with one optimizer - cfg.default_hooks = dict(optimizer=None) runner = Runner.from_cfg(cfg) runner.train() path = osp.join(self.temp_dir, 'epoch_3.pth') @@ -1533,8 +1567,8 @@ class TestRunner(TestCase): linear2=dict( type='OptimWrapper', optimizer=dict(type='Adam', lr=0.03)), constructor='ToyMultipleOptimizerConstructor') + cfg.model = dict(type='TopGANModel') cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3]) - cfg.default_hooks = dict(optimizer=None) runner = Runner.from_cfg(cfg) runner.resume(path) self.assertIsInstance(runner.optim_wrapper, OptimWrapperDict)