diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 1a6f26517ee837bebaf098a77d2da09ce6946fe2..3e32e54960cde61a1afb0a70249392cce1b4cc1e 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +import bisect import time import warnings -from typing import Dict, List, Sequence, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import torch from torch.utils.data import DataLoader @@ -11,6 +12,7 @@ from mmengine.registry import LOOPS from mmengine.utils import is_list_of from .amp import autocast from .base_loop import BaseLoop +from .utils import calc_dynamic_intervals @LOOPS.register_module() @@ -25,14 +27,20 @@ class EpochBasedTrainLoop(BaseLoop): val_begin (int): The epoch that begins validating. Defaults to 1. val_interval (int): Validation interval. Defaults to 1. + dynamic_intervals (List[Tuple[int, int]], optional): The + first element in the tuple is a milestone and the second + element is a interval. The interval is used after the + corresponding milestone. Defaults to None. """ - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - max_epochs: int, - val_begin: int = 1, - val_interval: int = 1) -> None: + def __init__( + self, + runner, + dataloader: Union[DataLoader, Dict], + max_epochs: int, + val_begin: int = 1, + val_interval: int = 1, + dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: super().__init__(runner, dataloader) self._max_epochs = max_epochs self._max_iters = max_epochs * len(self.dataloader) @@ -49,6 +57,10 @@ class EpochBasedTrainLoop(BaseLoop): 'metainfo. ``dataset_meta`` in visualizer will be ' 'None.') + self.dynamic_milestones, self.dynamic_intervals = \ + calc_dynamic_intervals( + self.val_interval, dynamic_intervals) + @property def max_epochs(self): """int: Total epochs to train model.""" @@ -76,6 +88,7 @@ class EpochBasedTrainLoop(BaseLoop): while self._epoch < self._max_epochs: self.run_epoch() + self._decide_current_val_interval() if (self.runner.val_loop is not None and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): @@ -114,6 +127,11 @@ class EpochBasedTrainLoop(BaseLoop): outputs=outputs) self._iter += 1 + def _decide_current_val_interval(self) -> None: + """Dynamically modify the ``val_interval``.""" + step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1)) + self.val_interval = self.dynamic_intervals[step - 1] + class _InfiniteDataloaderIterator: """An infinite dataloader iterator wrapper for IterBasedTrainLoop. @@ -172,14 +190,20 @@ class IterBasedTrainLoop(BaseLoop): val_begin (int): The iteration that begins validating. Defaults to 1. val_interval (int): Validation interval. Defaults to 1000. + dynamic_intervals (List[Tuple[int, int]], optional): The + first element in the tuple is a milestone and the second + element is a interval. The interval is used after the + corresponding milestone. Defaults to None. """ - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - max_iters: int, - val_begin: int = 1, - val_interval: int = 1000) -> None: + def __init__( + self, + runner, + dataloader: Union[DataLoader, Dict], + max_iters: int, + val_begin: int = 1, + val_interval: int = 1000, + dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: super().__init__(runner, dataloader) self._max_iters = max_iters self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop @@ -198,6 +222,10 @@ class IterBasedTrainLoop(BaseLoop): # get the iterator of the dataloader self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader) + self.dynamic_milestones, self.dynamic_intervals = \ + calc_dynamic_intervals( + self.val_interval, dynamic_intervals) + @property def max_epochs(self): """int: Total epochs to train model.""" @@ -230,6 +258,7 @@ class IterBasedTrainLoop(BaseLoop): data_batch = next(self.dataloader_iterator) self.run_iter(data_batch) + self._decide_current_val_interval() if (self.runner.val_loop is not None and self._iter >= self.val_begin and self._iter % self.val_interval == 0): @@ -260,6 +289,11 @@ class IterBasedTrainLoop(BaseLoop): outputs=outputs) self._iter += 1 + def _decide_current_val_interval(self) -> None: + """Dynamically modify the ``val_interval``.""" + step = bisect.bisect(self.dynamic_milestones, (self._iter + 1)) + self.val_interval = self.dynamic_intervals[step - 1] + @LOOPS.register_module() class ValLoop(BaseLoop): diff --git a/mmengine/runner/utils.py b/mmengine/runner/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..02e12dc84aa44d91e957cd39936849b51a779625 --- /dev/null +++ b/mmengine/runner/utils.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +from mmengine.utils.misc import is_list_of + + +def calc_dynamic_intervals( + start_interval: int, + dynamic_interval_list: Optional[List[Tuple[int, int]]] = None +) -> Tuple[List[int], List[int]]: + """Calculate dynamic intervals. + + Args: + start_interval (int): The interval used in the beginning. + dynamic_interval_list (List[Tuple[int, int]], optional): The + first element in the tuple is a milestone and the second + element is a interval. The interval is used after the + corresponding milestone. Defaults to None. + + Returns: + Tuple[List[int], List[int]]: a list of milestone and its corresponding + intervals. + """ + if dynamic_interval_list is None: + return [0], [start_interval] + + assert is_list_of(dynamic_interval_list, tuple) + + dynamic_milestones = [0] + dynamic_milestones.extend( + [dynamic_interval[0] for dynamic_interval in dynamic_interval_list]) + dynamic_intervals = [start_interval] + dynamic_intervals.extend( + [dynamic_interval[1] for dynamic_interval in dynamic_interval_list]) + return dynamic_milestones, dynamic_intervals diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 5c75094eb1b7fa4c64cc208635506b77f5ba77bd..e9afa693e377f37e064585aa07e06f35a15c09cb 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -1283,6 +1283,80 @@ class TestRunner(TestCase): val_batch_idx_targets): self.assertEqual(result, target) + # 5. test dynamic interval in IterBasedTrainLoop + max_iters = 12 + interval = 5 + dynamic_intervals = [(11, 2)] + iter_results = [] + iter_targets = [5, 10, 12] + val_interval_results = [] + val_interval_targets = [5] * 10 + [2] * 2 + + @HOOKS.register_module() + class TestIterDynamicIntervalHook(Hook): + + def before_val(self, runner): + iter_results.append(runner.iter) + + def before_train_iter(self, runner, batch_idx, data_batch=None): + val_interval_results.append(runner.train_loop.val_interval) + + cfg = copy.deepcopy(self.iter_based_cfg) + cfg.experiment_name = 'test_train5' + cfg.train_dataloader.sampler = dict( + type='DefaultSampler', shuffle=True) + cfg.custom_hooks = [ + dict(type='TestIterDynamicIntervalHook', priority=50) + ] + cfg.train_cfg = dict( + by_epoch=False, + max_iters=max_iters, + val_interval=interval, + dynamic_intervals=dynamic_intervals) + runner = Runner.from_cfg(cfg) + runner.train() + for result, target, in zip(iter_results, iter_targets): + self.assertEqual(result, target) + for result, target, in zip(val_interval_results, val_interval_targets): + self.assertEqual(result, target) + + # 6. test dynamic interval in EpochBasedTrainLoop + max_epochs = 12 + interval = 5 + dynamic_intervals = [(11, 2)] + epoch_results = [] + epoch_targets = [5, 10, 12] + val_interval_results = [] + val_interval_targets = [5] * 10 + [2] * 2 + + @HOOKS.register_module() + class TestEpochDynamicIntervalHook(Hook): + + def before_val_epoch(self, runner): + epoch_results.append(runner.epoch) + + def before_train_epoch(self, runner): + val_interval_results.append(runner.train_loop.val_interval) + + cfg = copy.deepcopy(self.epoch_based_cfg) + cfg.experiment_name = 'test_train6' + cfg.train_dataloader.sampler = dict( + type='DefaultSampler', shuffle=True) + cfg.custom_hooks = [ + dict(type='TestEpochDynamicIntervalHook', priority=50) + ] + cfg.train_cfg = dict( + by_epoch=True, + max_epochs=max_epochs, + val_interval=interval, + dynamic_intervals=dynamic_intervals) + runner = Runner.from_cfg(cfg) + runner.train() + for result, target, in zip(epoch_results, epoch_targets): + self.assertEqual(result, target) + for result, target, in zip(val_interval_results, val_interval_targets): + self.assertEqual(result, target) + def test_val(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_val1'