diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 34546cc264973db2619e6d2c2da1339073c2d988..c85017608a7c0965a19846df25f4d4bf6f31f62b 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import time
 import warnings
 from typing import Dict, List, Sequence, Union
 
@@ -113,6 +114,51 @@ class EpochBasedTrainLoop(BaseLoop):
         self._iter += 1
 
 
+class _InfiniteDataloaderIterator:
+    """An infinite dataloader iterator wrapper for IterBasedTrainLoop.
+
+    It resets the dataloader to continue iterating when the iterator has
+    iterated over all the data. However, this approach is not efficient, as the
+    workers need to be restarted every time the dataloader is reset. It is
+    recommended to use `mmengine.data.InfiniteSampler` to enable the dataloader
+    to iterate infinitely.
+    """
+
+    def __init__(self, dataloader: DataLoader) -> None:
+        self._dataloader = dataloader
+        self._iterator = iter(self._dataloader)
+        self._epoch = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> Sequence[dict]:
+        try:
+            data = next(self._iterator)
+        except StopIteration:
+            warnings.warn('Reached the end of the dataloader, it will be '
+                          'restarted and continue to iterate. It is '
+                          'recommended to use `mmengine.data.InfiniteSampler` '
+                          'to enable the dataloader to iterate infinitely.')
+            self._epoch += 1
+            if hasattr(self._dataloader, 'sampler') and hasattr(
+                    self._dataloader.sampler, 'set_epoch'):
+                # In case the `_SingleProcessDataLoaderIter` has no sampler,
+                # or the dataloader uses `SequentialSampler` in PyTorch.
+                self._dataloader.sampler.set_epoch(self._epoch)
+
+            elif hasattr(self._dataloader, 'batch_sampler') and hasattr(
+                    self._dataloader.batch_sampler.sampler, 'set_epoch'):
+                # In case the `_SingleProcessDataLoaderIter` has no batch
+                # sampler. The batch sampler in PyTorch wraps the sampler as
+                # its attribute.
+                self._dataloader.batch_sampler.sampler.set_epoch(self._epoch)
+            time.sleep(2)  # Prevent possible deadlock during epoch transition
+            self._iterator = iter(self._dataloader)
+            data = next(self._iterator)
+        return data
+
+
 @LOOPS.register_module()
 class IterBasedTrainLoop(BaseLoop):
     """Loop for iter-based training.
@@ -149,7 +195,7 @@ class IterBasedTrainLoop(BaseLoop):
                 'metainfo. ``dataset_meta`` in visualizer will be '
                 'None.')
         # get the iterator of the dataloader
-        self.dataloader_iterator = iter(self.dataloader)
+        self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
 
     @property
     def max_epochs(self):
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 1097cffa3e1b0342508cb628b6f4bd3e09cd3933..5dd91996048ab5f068e70f9cf1512d1f1e2b57b2 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -29,6 +29,7 @@ from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
                                RUNNERS, Registry)
 from mmengine.runner import (BaseLoop, EpochBasedTrainLoop,
                              IterBasedTrainLoop, Runner, TestLoop, ValLoop)
+from mmengine.runner.loops import _InfiniteDataloaderIterator
 from mmengine.runner.priority import Priority, get_priority
 from mmengine.utils import is_list_of
 from mmengine.visualization import Visualizer
@@ -1202,6 +1203,50 @@ class TestRunner(TestCase):
                                   val_batch_idx_targets):
             self.assertEqual(result, target)
 
+        # 4. test iter and epoch counter of IterBasedTrainLoop and timing of
+        #    running ValLoop without InfiniteSampler
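+        # With the finite `DefaultSampler` configured below, the wrapped
+        # iterator has to restart the dataloader once it is exhausted,
+        # which emits the restart warning asserted on `runner.train()`.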
+        epoch_results = []
+        iter_results = []
+        batch_idx_results = []
+        val_iter_results = []
+        val_batch_idx_results = []
+        iter_targets = [i for i in range(12)]
+        batch_idx_targets = [i for i in range(12)]
+        val_iter_targets = [i for i in range(4, 12)]
+        val_batch_idx_targets = [i for i in range(4)] * 2
+
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg.experiment_name = 'test_train4'
+        cfg.train_dataloader.sampler = dict(
+            type='DefaultSampler', shuffle=True)
+        cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
+        cfg.train_cfg = dict(
+            by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
+        runner = Runner.from_cfg(cfg)
+        with self.assertWarnsRegex(
+                Warning,
+                'Reached the end of the dataloader, it will be restarted '
+                'and continue to iterate.'):
+            runner.train()
+
+        assert isinstance(runner.train_loop, IterBasedTrainLoop)
+        assert isinstance(runner.train_loop.dataloader_iterator,
+                          _InfiniteDataloaderIterator)
+
+        self.assertEqual(len(epoch_results), 1)
+        self.assertEqual(epoch_results[0], 0)
+        self.assertEqual(runner.val_interval, 4)
+        self.assertEqual(runner.val_begin, 4)
+        for result, target in zip(iter_results, iter_targets):
+            self.assertEqual(result, target)
+        for result, target in zip(batch_idx_results, batch_idx_targets):
+            self.assertEqual(result, target)
+        for result, target in zip(val_iter_results, val_iter_targets):
+            self.assertEqual(result, target)
+        for result, target in zip(val_batch_idx_results,
+                                  val_batch_idx_targets):
+            self.assertEqual(result, target)
+
     def test_val(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_val1'
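
A minimal sketch of the new wrapper's contract, outside the patch (the toy `range(4)` dataset and the step count are illustrative assumptions, not part of the change):

```python
# Sketch: _InfiniteDataloaderIterator never raises StopIteration, so an
# iter-based loop can keep calling next(); a finite dataloader is
# restarted (with a warning, plus a short sleep) each time it is exhausted.
from torch.utils.data import DataLoader

from mmengine.runner.loops import _InfiniteDataloaderIterator

loader = DataLoader(range(4), batch_size=1)  # toy finite dataset: 4 batches
iterator = _InfiniteDataloaderIterator(loader)
for _ in range(10):  # more steps than one pass over the data
    batch = next(iterator)  # restarts the loader instead of stopping
```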