From 59cc08e3ac87a1d99ed1153c9eddba7053150fb5 Mon Sep 17 00:00:00 2001 From: RangiLyu <lyuchqi@gmail.com> Date: Fri, 8 Apr 2022 15:57:10 +0800 Subject: [PATCH] [Refactor] Refactor data_batch type and remove cur_dataloader in runner. (#171) * [Refactor] Refactor data_batch type. * fix sampler * [Refactor] Remove cur_dataloader in runner. * fix set_epoch --- mmengine/data/sampler.py | 7 +- mmengine/data/utils.py | 10 +-- mmengine/evaluator/evaluator.py | 18 +++-- mmengine/evaluator/metric.py | 7 +- mmengine/hooks/checkpoint_hook.py | 9 +-- mmengine/hooks/empty_cache_hook.py | 8 +- mmengine/hooks/hook.py | 46 +++++------ mmengine/hooks/iter_timer_hook.py | 12 +-- mmengine/hooks/logger_hook.py | 19 ++--- mmengine/hooks/naive_visualization_hook.py | 19 +++-- mmengine/hooks/optimizer_hook.py | 12 ++- mmengine/hooks/param_scheduler_hook.py | 12 ++- mmengine/hooks/sampler_seed_hook.py | 10 ++- mmengine/runner/base_loop.py | 3 - mmengine/runner/loops.py | 45 ++++++----- mmengine/runner/runner.py | 4 + tests/test_evaluator/test_evaluator.py | 20 +++-- tests/test_hook/test_hook.py | 9 +-- tests/test_hook/test_logger_hook.py | 6 +- .../test_naive_visualization_hook.py | 78 ++++++++----------- tests/test_hook/test_sampler_seed_hook.py | 21 ++--- tests/test_runner/test_runner.py | 7 +- 22 files changed, 186 insertions(+), 196 deletions(-) diff --git a/mmengine/data/sampler.py b/mmengine/data/sampler.py index 47b2c3b4..ff1d13ec 100644 --- a/mmengine/data/sampler.py +++ b/mmengine/data/sampler.py @@ -2,18 +2,13 @@ import itertools import math from typing import Iterator, Optional, Sized -# from mmengine.dist import get_dist_info, sync_random_seed -from unittest.mock import MagicMock import torch from torch.utils.data import Sampler +from mmengine.dist import get_dist_info, sync_random_seed from mmengine.registry import DATA_SAMPLERS -# TODO, need to remove those lines after implementing dist module -get_dist_info = MagicMock(return_value=(0, 1)) -sync_random_seed = MagicMock(return_value=0) - @DATA_SAMPLERS.register_module() class DefaultSampler(Sampler): diff --git a/mmengine/data/utils.py b/mmengine/data/utils.py index 0f569d39..c284a336 100644 --- a/mmengine/data/utils.py +++ b/mmengine/data/utils.py @@ -1,13 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import random -from typing import Any, Sequence, Tuple +from typing import Sequence import numpy as np import torch -from .base_data_element import BaseDataElement - -DATA_BATCH = Sequence[Tuple[Any, BaseDataElement]] +DATA_BATCH = Sequence[dict] def worker_init_fn(worker_id: int, num_workers: int, rank: int, @@ -36,10 +34,10 @@ def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH: nothing just returns ``data_batch``. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data from + data_batch (Sequence[dict]): Batch of data from dataloader. Returns: - Sequence[Tuple[Any, BaseDataElement]]: Return input ``data_batch``. + Sequence[dict]: Return input ``data_batch``. """ return data_batch diff --git a/mmengine/evaluator/evaluator.py b/mmengine/evaluator/evaluator.py index c653fb56..34bb02e5 100644 --- a/mmengine/evaluator/evaluator.py +++ b/mmengine/evaluator/evaluator.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Iterator, List, Optional, Sequence, Union from mmengine.data import BaseDataElement from ..registry.root import METRICS @@ -37,23 +37,25 @@ class Evaluator: for metric in self.metrics: metric.dataset_meta = dataset_meta - def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], + def process(self, data_batch: Sequence[dict], predictions: Sequence[BaseDataElement]): """Convert ``BaseDataSample`` to dict and invoke process method of each metric. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data - from the dataloader. + data_batch (Sequence[dict]): A batch of data from the dataloader. predictions (Sequence[BaseDataElement]): A batch of outputs from the model. """ _data_batch = [] - for input, data in data_batch: - if isinstance(data, BaseDataElement): - _data_batch.append((input, data.to_dict())) + for data in data_batch: + if isinstance(data['data_sample'], BaseDataElement): + _data_batch.append( + dict( + inputs=data['inputs'], + data_sample=data['data_sample'].to_dict())) else: - _data_batch.append((input, data)) + _data_batch.append(data) _predictions = [] for pred in predictions: if isinstance(pred, BaseDataElement): diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py index e8a71488..4bcf163b 100644 --- a/mmengine/evaluator/metric.py +++ b/mmengine/evaluator/metric.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings from abc import ABCMeta, abstractmethod -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Union from mmengine.dist import (broadcast_object_list, collect_results, is_main_process) @@ -50,15 +50,14 @@ class BaseMetric(metaclass=ABCMeta): self._dataset_meta = dataset_meta @abstractmethod - def process(self, data_batch: Sequence[Tuple[Any, dict]], + def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]) -> None: """Process one batch of data samples and predictions. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: - data_batch (Sequence[Tuple[Any, dict]]): A batch of data - from the dataloader. + data_batch (Sequence[dict]): A batch of data from the dataloader. predictions (Sequence[dict]): A batch of outputs from the model. """ diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index a373bbb5..017784b9 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -2,15 +2,14 @@ import os.path as osp import warnings from pathlib import Path -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union -from mmengine.data import BaseDataElement from mmengine.dist import master_only from mmengine.fileio import FileClient from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -185,8 +184,8 @@ class CheckpointHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index c793f01b..be6b5c2c 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union import torch @@ -7,7 +7,7 @@ from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -46,8 +46,8 @@ class EmptyCacheHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 84060334..49995334 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union from mmengine.data import BaseDataElement -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] class Hook: @@ -174,8 +174,8 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. """ self._before_iter( runner, batch_idx=batch_idx, data_batch=data_batch, mode='train') @@ -190,8 +190,8 @@ class Hook: Args: runner (Runner): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. """ self._before_iter( runner, batch_idx=batch_idx, data_batch=data_batch, mode='val') @@ -206,8 +206,8 @@ class Hook: Args: runner (Runner): The runner of the testing process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. """ self._before_iter( runner, batch_idx=batch_idx, data_batch=data_batch, mode='test') @@ -223,8 +223,8 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ @@ -247,8 +247,8 @@ class Hook: Args: runner (Runner): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. """ @@ -271,8 +271,8 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ @@ -317,8 +317,8 @@ class Hook: runner (Runner): The runner of the training, validation or testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. """ pass @@ -337,8 +337,8 @@ class Hook: runner (Runner): The runner of the training, validation or testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (Sequence[BaseDataElement], optional): Outputs from model. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. @@ -387,19 +387,19 @@ class Hook: """ return (runner.iter + 1) % n == 0 if n > 0 else False - def end_of_epoch(self, runner, batch_idx: int) -> bool: + def end_of_epoch(self, dataloader, batch_idx: int) -> bool: """Check whether the current iteration reaches the last iteration of - current dataloader. + the dataloader. Args: - runner (Runner): The runner of the training, validation or testing - process. + dataloader (Dataloader): The dataloader of the training, + validation or testing process. batch_idx (int): The index of the current batch in the loop. Returns: bool: Whether reaches the end of current epoch or not. """ - return batch_idx + 1 == len(runner.cur_dataloader) + return batch_idx + 1 == len(dataloader) def is_last_train_epoch(self, runner) -> bool: """Test whether current epoch is the last train epoch. diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index bf123cae..26abf17e 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import time -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -37,8 +37,8 @@ class IterTimerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. """ # TODO: update for new logging system @@ -57,8 +57,8 @@ class IterTimerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index 786fd311..16adedaa 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -5,18 +5,17 @@ import os import os.path as osp from collections import OrderedDict from pathlib import Path -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union import torch -from mmengine.data import BaseDataElement from mmengine.dist import master_only from mmengine.fileio import FileClient from mmengine.hooks import Hook from mmengine.registry import HOOKS from mmengine.utils import is_tuple_of, scandir -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -183,15 +182,16 @@ class LoggerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[BaseDataElement], optional): Data from - dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ self._inner_iter = batch_idx if runner.meta is not None and 'exp_name' in runner.meta: if (self.every_n_iters(runner, self.interval_exp_name)) or ( - self.by_epoch and self.end_of_epoch(runner, batch_idx)): + self.by_epoch and self.end_of_epoch( + runner.train_loop.dataloader, batch_idx)): exp_info = f'Exp name: {runner.meta["exp_name"]}' runner.logger.info(exp_info) if self.by_epoch and self.every_n_inner_iters(batch_idx, @@ -199,7 +199,8 @@ class LoggerHook(Hook): self._log_train(runner) elif not self.by_epoch and self.every_n_iters(runner, self.interval): self._log_train(runner) - elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last: + elif self.end_of_epoch(runner.train_loop.dataloader, + batch_idx) and not self.ignore_last: # `runner.max_iters` may not be divisible by `self.interval`. if # `self.ignore_last==True`, the log of remaining iterations will # be recorded (Epoch [4][1000/1007], the logs of 998-1007 @@ -271,7 +272,7 @@ class LoggerHook(Hook): # by iter: Iter [100/100000] if self.by_epoch: log_str = f'Epoch [{cur_epoch}]' \ - f'[{cur_iter}/{len(runner.cur_dataloader)}]\t' + f'[{cur_iter}/{len(runner.train_loop.dataloader)}]\t' else: log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t' log_str += f'{lr_momentum_str}, ' @@ -311,7 +312,7 @@ class LoggerHook(Hook): """ tag = self._collect_info(runner, 'val') # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501 - eval_iter = len(runner.cur_dataloader) + eval_iter = len(runner.val_loop.dataloader) cur_iter = self._get_iter(runner) cur_epoch = self._get_epoch(runner, 'val') # val/test time diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py index 2e05fc59..e8bd3834 100644 --- a/mmengine/hooks/naive_visualization_hook.py +++ b/mmengine/hooks/naive_visualization_hook.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from typing import Any, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple import cv2 import numpy as np @@ -41,26 +41,25 @@ class NaiveVisualizationHook(Hook): self, runner, batch_idx: int, - data_batch: Optional[Sequence[Tuple[Any, BaseDataElement]]] = None, + data_batch: Optional[Sequence[dict]] = None, outputs: Optional[Sequence[BaseDataElement]] = None) -> None: """Show or Write the predicted results. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data + data_batch (Sequence[dict], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataElement], optional): Outputs from model. Defaults to None. """ if self.every_n_iters(runner, self._interval): - inputs, data_samples = data_batch # type: ignore - inputs = tensor2imgs(inputs, - **data_samples[0].get('img_norm_cfg', dict())) - for input, data_sample, output in zip( - inputs, - data_samples, # type: ignore - outputs): # type: ignore + for data, output in zip(data_batch, outputs): # type: ignore + input = data['inputs'] + data_sample = data['data_sample'] + input = tensor2imgs(input, + **data_sample.get('img_norm_cfg', + dict()))[0] # TODO We will implement a function to revert the augmentation # in the future. ori_shape = (data_sample.ori_width, data_sample.ori_height) diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index ff33b54a..9107dbf0 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -1,16 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging -from typing import Any, List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence import torch from torch.nn.parameter import Parameter from torch.nn.utils import clip_grad -from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -77,10 +76,9 @@ class OptimizerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. In order to keep this interface consistent - with other hooks, we keep ``data_batch`` here. - Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + In order to keep this interface consistent with other hooks, + we keep ``data_batch`` here. Defaults to None. outputs (dict, optional): Outputs from model. In order to keep this interface consistent with other hooks, we keep ``outputs`` here. Defaults to None. diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index 9522abcf..c4e7af58 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -1,11 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple +from typing import Optional, Sequence -from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -25,10 +24,9 @@ class ParamSchedulerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. In order to keep this interface consistent - with other hooks, we keep ``data_batch`` here. - Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + In order to keep this interface consistent with other hooks, + we keep ``data_batch`` here. Defaults to None. outputs (dict, optional): Outputs from model. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py index eed3fa90..b90657c8 100644 --- a/mmengine/hooks/sampler_seed_hook.py +++ b/mmengine/hooks/sampler_seed_hook.py @@ -20,9 +20,11 @@ class DistSamplerSeedHook(Hook): Args: runner (Runner): The runner of the training process. """ - if hasattr(runner.cur_dataloader.sampler, 'set_epoch'): + if hasattr(runner.train_loop.dataloader.sampler, 'set_epoch'): # in case the data loader uses `SequentialSampler` in Pytorch - runner.cur_dataloader.sampler.set_epoch(runner.epoch) - elif hasattr(runner.cur_dataloader.batch_sampler.sampler, 'set_epoch'): + runner.train_loop.dataloader.sampler.set_epoch(runner.epoch) + elif hasattr(runner.train_loop.dataloader.batch_sampler.sampler, + 'set_epoch'): # batch sampler in pytorch warps the sampler as its attributes. - runner.cur_dataloader.batch_sampler.sampler.set_epoch(runner.epoch) + runner.train_loop.dataloader.batch_sampler.sampler.set_epoch( + runner.epoch) diff --git a/mmengine/runner/base_loop.py b/mmengine/runner/base_loop.py index 0a0e3ca7..ec3f880d 100644 --- a/mmengine/runner/base_loop.py +++ b/mmengine/runner/base_loop.py @@ -25,9 +25,6 @@ class BaseLoop(metaclass=ABCMeta): else: self.dataloader = dataloader - # TODO, used by `end_of_epoch` of `Hook` - self._runner.data_loader = self.dataloader - @property def runner(self): return self._runner diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index eb5c3454..d791c52c 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, List, Sequence, Tuple, Union +import warnings +from typing import Dict, List, Sequence, Union import torch from torch.utils.data import DataLoader -from mmengine.data import BaseDataElement from mmengine.evaluator import Evaluator from mmengine.registry import LOOPS from mmengine.utils import is_list_of @@ -40,7 +40,6 @@ class EpochBasedTrainLoop(BaseLoop): def run(self) -> None: """Launch training.""" - self.runner.cur_dataloader = self.dataloader self.runner.call_hook('before_train') while self.runner._epoch < self._max_epochs: @@ -62,13 +61,11 @@ class EpochBasedTrainLoop(BaseLoop): self.runner.call_hook('after_train_epoch') self.runner._epoch += 1 - def run_iter(self, idx, - data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: + def run_iter(self, idx, data_batch: Sequence[dict]) -> None: """Iterate one min-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data - from dataloader. + data_batch (Sequence[dict]): Batch of data from dataloader. """ self.runner.call_hook( 'before_train_iter', batch_idx=idx, data_batch=data_batch) @@ -112,7 +109,6 @@ class IterBasedTrainLoop(BaseLoop): def run(self) -> None: """Launch training.""" - self.runner.cur_dataloader = self.dataloader self.runner.call_hook('before_train') # In iteration-based training loop, we treat the whole training process # as a big epoch and execute the corresponding hook. @@ -130,13 +126,11 @@ class IterBasedTrainLoop(BaseLoop): self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train') - def run_iter(self, data_batch: Sequence[Tuple[Any, - BaseDataElement]]) -> None: + def run_iter(self, data_batch: Sequence[dict]) -> None: """Iterate one mini-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data - from dataloader. + data_batch (Sequence[dict]): Batch of data from dataloader. """ self.runner.call_hook( 'before_train_iter', @@ -180,12 +174,17 @@ class ValLoop(BaseLoop): self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator # type: ignore - + if hasattr(self.dataloader.dataset, 'metainfo'): + self.evaluator.dataset_meta = self.dataloader.dataset.metainfo + else: + warnings.warn( + f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' + 'metainfo. ``dataset_meta`` in evaluator and metric will be ' + 'None.') self.interval = interval def run(self): """Launch validation.""" - self.runner.cur_dataloader = self.dataloader self.runner.call_hook('before_val') self.runner.call_hook('before_val_epoch') self.runner.model.eval() @@ -201,11 +200,11 @@ class ValLoop(BaseLoop): self.runner.call_hook('after_val') @torch.no_grad() - def run_iter(self, idx, data_batch: Sequence[Tuple[Any, BaseDataElement]]): + def run_iter(self, idx, data_batch: Sequence[dict]): """Iterate one mini-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data + data_batch (Sequence[dict]): Batch of data from dataloader. """ self.runner.call_hook( @@ -239,10 +238,16 @@ class TestLoop(BaseLoop): self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator # type: ignore + if hasattr(self.dataloader.dataset, 'metainfo'): + self.evaluator.dataset_meta = self.dataloader.dataset.metainfo + else: + warnings.warn( + f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' + 'metainfo. ``dataset_meta`` in evaluator and metric will be ' + 'None.') def run(self) -> None: """Launch test.""" - self.runner.cur_dataloader = self.dataloader self.runner.call_hook('before_test') self.runner.call_hook('before_test_epoch') self.runner.model.eval() @@ -258,13 +263,11 @@ class TestLoop(BaseLoop): self.runner.call_hook('after_test') @torch.no_grad() - def run_iter(self, idx, - data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: + def run_iter(self, idx, data_batch: Sequence[dict]) -> None: """Iterate one mini-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data - from dataloader. + data_batch (Sequence[dict]): Batch of data from dataloader. """ self.runner.call_hook( 'before_test_iter', batch_idx=idx, data_batch=data_batch) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 9d364b34..28d339ca 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1215,6 +1215,8 @@ class Runner: +----------------------+-------------------------+ | IterTimerHook | NORMAL (40) | +----------------------+-------------------------+ + | DistSamplerSeedHook | NORMAL (40) | + +----------------------+-------------------------+ | LoggerHook | BELOW_NORMAL (60) | +----------------------+-------------------------+ | ParamSchedulerHook | LOW (70) | @@ -1228,6 +1230,7 @@ class Runner: default_hooks = dict( optimizer=dict(type='OptimizerHook', grad_clip=None), timer=dict(type='IterTimerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), logger=dict(type='LoggerHook'), param_scheduler=dict(type='ParamSchedulerHook'), checkpoint=dict(type='CheckpointHook', interval=1), @@ -1252,6 +1255,7 @@ class Runner: logger=dict(type='LoggerHook'), param_scheduler=dict(type='ParamSchedulerHook'), checkpoint=dict(type='CheckpointHook', interval=1), + sampler_seed=dict(type='DistSamplerSeedHook'), ) if hooks is not None: for name, hook in hooks.items(): diff --git a/tests/test_evaluator/test_evaluator.py b/tests/test_evaluator/test_evaluator.py index 61be034b..5364c067 100644 --- a/tests/test_evaluator/test_evaluator.py +++ b/tests/test_evaluator/test_evaluator.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import math -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence from unittest import TestCase import numpy as np @@ -40,7 +40,7 @@ class ToyMetric(BaseMetric): def process(self, data_batch, predictions): results = [{ 'pred': pred.get('pred'), - 'label': data[1].get('label') + 'label': data['data_sample'].get('label') } for pred, data in zip(predictions, data_batch)] self.results.extend(results) @@ -66,7 +66,7 @@ class NonPrefixedMetric(BaseMetric): """Evaluator with unassigned `default_prefix` to test the warning information.""" - def process(self, data_batch: Sequence[Tuple[Any, dict]], + def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]) -> None: pass @@ -79,8 +79,11 @@ def generate_test_results(size, batch_size, pred, label): bs_residual = size % batch_size for i in range(num_batch): bs = bs_residual if i == num_batch - 1 else batch_size - data_batch = [(np.zeros((3, 10, 10)), BaseDataElement(label=label)) - for _ in range(bs)] + data_batch = [ + dict( + inputs=np.zeros((3, 10, 10)), + data_sample=BaseDataElement(label=label)) for _ in range(bs) + ] predictions = [BaseDataElement(pred=pred) for _ in range(bs)] yield (data_batch, predictions) @@ -228,7 +231,10 @@ class TestEvaluator(TestCase): size = 10 - all_data = [(np.zeros((3, 10, 10)), BaseDataElement(label=1)) - for _ in range(size)] + all_data = [ + dict( + inputs=np.zeros((3, 10, 10)), + data_sample=BaseDataElement(label=1)) for _ in range(size) + ] all_predictions = [BaseDataElement(pred=0) for _ in range(size)] evaluator.offline_evaluate(all_data, all_predictions) diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py index db80ed4a..771c54f6 100644 --- a/tests/test_hook/test_hook.py +++ b/tests/test_hook/test_hook.py @@ -157,18 +157,17 @@ class TestHook: def test_end_of_epoch(self): hook = Hook() - runner = Mock() # last inner iter batch_idx = 1 - runner.cur_dataloader.__len__ = Mock(return_value=2) - runner.cur_dataloader.__len__ = Mock(return_value=2) - return_val = hook.end_of_epoch(runner, batch_idx) + dataloader = Mock() + dataloader.__len__ = Mock(return_value=2) + return_val = hook.end_of_epoch(dataloader, batch_idx) assert return_val # not the last inner iter batch_idx = 0 - return_val = hook.end_of_epoch(runner, batch_idx) + return_val = hook.end_of_epoch(dataloader, batch_idx) assert not return_val def test_is_last_train_epoch(self): diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index 6fac1a93..024716f2 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -111,7 +111,7 @@ class TestLoggerHook: # Test end of the epoch. logger_hook = LoggerHook(by_epoch=True, ignore_last=False) logger_hook._log_train = MagicMock() - runner.cur_dataloader = [0] * 5 + runner.train_loop.dataloader = [0] * 5 batch_idx = 4 logger_hook.after_train_iter(runner, batch_idx=batch_idx) logger_hook._log_train.assert_called() @@ -341,7 +341,9 @@ class TestLoggerHook: def _setup_runner(self): runner = MagicMock() runner.epoch = 1 - runner.cur_dataloader = [0] * 5 + runner.train_loop.dataloader = [0] * 5 + runner.val_loop.dataloader = [0] * 5 + runner.test_loop.dataloader = [0] * 5 runner.iter = 10 runner.train_loop.max_iters = 50 logger = logging.getLogger() diff --git a/tests/test_hook/test_naive_visualization_hook.py b/tests/test_hook/test_naive_visualization_hook.py index 4d75fedb..0bbe47df 100644 --- a/tests/test_hook/test_naive_visualization_hook.py +++ b/tests/test_hook/test_naive_visualization_hook.py @@ -16,70 +16,56 @@ class TestNaiveVisualizationHook: inputs = torch.randn(1, 3, 15, 15) batch_idx = 10 # test with normalize, resize, pad - gt_datasamples = [ - BaseDataElement( - metainfo=dict( - img_norm_cfg=dict( - mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), - scale=(10, 10), - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) - ] + gt_datasamples = BaseDataElement( + metainfo=dict( + img_norm_cfg=dict( + mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), + scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with resize, pad - gt_datasamples = [ - BaseDataElement( - metainfo=dict( - scale=(10, 10), - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')), - ] + gt_datasamples = BaseDataElement( + metainfo=dict( + scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only resize - gt_datasamples = [ - BaseDataElement( - metainfo=dict( - scale=(15, 15), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')), - ] + gt_datasamples = BaseDataElement( + metainfo=dict( + scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only pad - gt_datasamples = [ - BaseDataElement( - metainfo=dict( - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')), - ] + gt_datasamples = BaseDataElement( + metainfo=dict( + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test no transform - gt_datasamples = [ - BaseDataElement( - metainfo=dict(ori_height=15, ori_width=15, - img_path='tmp.jpg')), - ] + gt_datasamples = BaseDataElement( + metainfo=dict(ori_height=15, ori_width=15, img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) diff --git a/tests/test_hook/test_sampler_seed_hook.py b/tests/test_hook/test_sampler_seed_hook.py index 9d19edf7..c1bf8b54 100644 --- a/tests/test_hook/test_sampler_seed_hook.py +++ b/tests/test_hook/test_sampler_seed_hook.py @@ -12,17 +12,18 @@ class TestDistSamplerSeedHook: # Test dataset sampler runner = Mock() runner.epoch = 1 - runner.cur_dataloader = Mock() - runner.cur_dataloader.sampler = Mock() - runner.cur_dataloader.sampler.set_epoch = Mock() + runner.train_loop.dataloader = Mock() + runner.train_loop.dataloader.sampler = Mock() + runner.train_loop.dataloader.sampler.set_epoch = Mock() hook.before_train_epoch(runner) - runner.cur_dataloader.sampler.set_epoch.assert_called() + runner.train_loop.dataloader.sampler.set_epoch.assert_called() # Test batch sampler runner = Mock() - runner.cur_dataloader = Mock() - runner.cur_dataloader.sampler = Mock(spec_set=True) - runner.cur_dataloader.batch_sampler = Mock() - runner.cur_dataloader.batch_sampler.sampler = Mock() - runner.cur_dataloader.batch_sampler.sampler.set_epoch = Mock() + runner.train_loop.dataloader = Mock() + runner.train_loop.dataloader.sampler = Mock(spec_set=True) + runner.train_loop.dataloader.batch_sampler = Mock() + runner.train_loop.dataloader.batch_sampler.sampler = Mock() + runner.train_loop.dataloader.batch_sampler.sampler.set_epoch = Mock() hook.before_train_epoch(runner) - runner.cur_dataloader.batch_sampler.sampler.set_epoch.assert_called() + runner.train_loop.dataloader.\ + batch_sampler.sampler.set_epoch.assert_called() diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 0fdd8553..fa1b9e5b 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -36,7 +36,8 @@ class ToyModel(nn.Module): self.linear = nn.Linear(2, 1) def forward(self, data_batch, return_loss=False): - inputs, labels = zip(*data_batch) + inputs, labels = zip( + *map(lambda x: (x['inputs'], x['data_sample']), data_batch)) device = 'cuda:0' if torch.cuda.is_available() else 'cpu' inputs = torch.stack(inputs).to(device) labels = torch.stack(labels).to(device) @@ -67,7 +68,7 @@ class CustomModelWrapper(nn.Module): @DATASETS.register_module() class ToyDataset(Dataset): - META = dict() # type: ignore + METAINFO = dict() # type: ignore data = torch.randn(12, 2) label = torch.ones(12) @@ -75,7 +76,7 @@ class ToyDataset(Dataset): return self.data.size(0) def __getitem__(self, index): - return self.data[index], self.label[index] + return dict(inputs=self.data[index], data_sample=self.label[index]) @METRICS.register_module() -- GitLab