Skip to content
Snippets Groups Projects
Unverified Commit 9c55b430 authored by Cedric Luo's avatar Cedric Luo Committed by GitHub
Browse files

[Enhance] Support dynamic interval (#342)


* support dynamic interval in iterbasedtrainloop

* update typehint

* update typehint

* add dynamic interval in epochbasedtrainloop

* update

* fix

Co-authored-by: luochunhua.vendor <luochunhua@pjlab.org.cn>
parent d65350a9
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import bisect
import time import time
import warnings import warnings
from typing import Dict, List, Sequence, Union from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
...@@ -11,6 +12,7 @@ from mmengine.registry import LOOPS ...@@ -11,6 +12,7 @@ from mmengine.registry import LOOPS
from mmengine.utils import is_list_of from mmengine.utils import is_list_of
from .amp import autocast from .amp import autocast
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
@LOOPS.register_module() @LOOPS.register_module()
...@@ -25,14 +27,20 @@ class EpochBasedTrainLoop(BaseLoop): ...@@ -25,14 +27,20 @@ class EpochBasedTrainLoop(BaseLoop):
val_begin (int): The epoch that begins validating. val_begin (int): The epoch that begins validating.
Defaults to 1. Defaults to 1.
val_interval (int): Validation interval. Defaults to 1. val_interval (int): Validation interval. Defaults to 1.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is an interval. The interval is used after the
corresponding milestone. Defaults to None.
""" """
def __init__(self, def __init__(
runner, self,
dataloader: Union[DataLoader, Dict], runner,
max_epochs: int, dataloader: Union[DataLoader, Dict],
val_begin: int = 1, max_epochs: int,
val_interval: int = 1) -> None: val_begin: int = 1,
val_interval: int = 1,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader) super().__init__(runner, dataloader)
self._max_epochs = max_epochs self._max_epochs = max_epochs
self._max_iters = max_epochs * len(self.dataloader) self._max_iters = max_epochs * len(self.dataloader)
...@@ -49,6 +57,10 @@ class EpochBasedTrainLoop(BaseLoop): ...@@ -49,6 +57,10 @@ class EpochBasedTrainLoop(BaseLoop):
'metainfo. ``dataset_meta`` in visualizer will be ' 'metainfo. ``dataset_meta`` in visualizer will be '
'None.') 'None.')
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property @property
def max_epochs(self): def max_epochs(self):
"""int: Total epochs to train model.""" """int: Total epochs to train model."""
...@@ -76,6 +88,7 @@ class EpochBasedTrainLoop(BaseLoop): ...@@ -76,6 +88,7 @@ class EpochBasedTrainLoop(BaseLoop):
while self._epoch < self._max_epochs: while self._epoch < self._max_epochs:
self.run_epoch() self.run_epoch()
self._decide_current_val_interval()
if (self.runner.val_loop is not None if (self.runner.val_loop is not None
and self._epoch >= self.val_begin and self._epoch >= self.val_begin
and self._epoch % self.val_interval == 0): and self._epoch % self.val_interval == 0):
...@@ -114,6 +127,11 @@ class EpochBasedTrainLoop(BaseLoop): ...@@ -114,6 +127,11 @@ class EpochBasedTrainLoop(BaseLoop):
outputs=outputs) outputs=outputs)
self._iter += 1 self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
self.val_interval = self.dynamic_intervals[step - 1]
class _InfiniteDataloaderIterator: class _InfiniteDataloaderIterator:
"""An infinite dataloader iterator wrapper for IterBasedTrainLoop. """An infinite dataloader iterator wrapper for IterBasedTrainLoop.
...@@ -172,14 +190,20 @@ class IterBasedTrainLoop(BaseLoop): ...@@ -172,14 +190,20 @@ class IterBasedTrainLoop(BaseLoop):
val_begin (int): The iteration that begins validating. val_begin (int): The iteration that begins validating.
Defaults to 1. Defaults to 1.
val_interval (int): Validation interval. Defaults to 1000. val_interval (int): Validation interval. Defaults to 1000.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is an interval. The interval is used after the
corresponding milestone. Defaults to None.
""" """
def __init__(self, def __init__(
runner, self,
dataloader: Union[DataLoader, Dict], runner,
max_iters: int, dataloader: Union[DataLoader, Dict],
val_begin: int = 1, max_iters: int,
val_interval: int = 1000) -> None: val_begin: int = 1,
val_interval: int = 1000,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader) super().__init__(runner, dataloader)
self._max_iters = max_iters self._max_iters = max_iters
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
...@@ -198,6 +222,10 @@ class IterBasedTrainLoop(BaseLoop): ...@@ -198,6 +222,10 @@ class IterBasedTrainLoop(BaseLoop):
# get the iterator of the dataloader # get the iterator of the dataloader
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader) self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property @property
def max_epochs(self): def max_epochs(self):
"""int: Total epochs to train model.""" """int: Total epochs to train model."""
...@@ -230,6 +258,7 @@ class IterBasedTrainLoop(BaseLoop): ...@@ -230,6 +258,7 @@ class IterBasedTrainLoop(BaseLoop):
data_batch = next(self.dataloader_iterator) data_batch = next(self.dataloader_iterator)
self.run_iter(data_batch) self.run_iter(data_batch)
self._decide_current_val_interval()
if (self.runner.val_loop is not None if (self.runner.val_loop is not None
and self._iter >= self.val_begin and self._iter >= self.val_begin
and self._iter % self.val_interval == 0): and self._iter % self.val_interval == 0):
...@@ -260,6 +289,11 @@ class IterBasedTrainLoop(BaseLoop): ...@@ -260,6 +289,11 @@ class IterBasedTrainLoop(BaseLoop):
outputs=outputs) outputs=outputs)
self._iter += 1 self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
self.val_interval = self.dynamic_intervals[step - 1]
@LOOPS.register_module() @LOOPS.register_module()
class ValLoop(BaseLoop): class ValLoop(BaseLoop):
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
from mmengine.utils.misc import is_list_of
def calc_dynamic_intervals(
    start_interval: int,
    dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[int], List[int]]:
    """Calculate dynamic intervals.

    Args:
        start_interval (int): The interval used in the beginning.
        dynamic_interval_list (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. Pairs may be supplied in any order;
            they are sorted by milestone before being returned.
            Defaults to None.

    Returns:
        Tuple[List[int], List[int]]: A list of milestones and the list of
        their corresponding intervals. The first milestone is always 0
        and is paired with ``start_interval``; the remaining entries
        follow in ascending milestone order.

    Raises:
        AssertionError: If ``dynamic_interval_list`` is neither None nor
            a list whose items are all tuples.
    """
    if dynamic_interval_list is None:
        return [0], [start_interval]

    # Raise explicitly instead of using an ``assert`` statement so the
    # check is not stripped under ``python -O``; the exception type stays
    # AssertionError for existing callers.
    is_tuple_list = isinstance(dynamic_interval_list, list) and all(
        isinstance(pair, tuple) for pair in dynamic_interval_list)
    if not is_tuple_list:
        raise AssertionError(
            '`dynamic_interval_list` must be a list of '
            '(milestone, interval) tuples, but got '
            f'{dynamic_interval_list!r}')

    # The milestones are searched with ``bisect``, which is only valid on
    # a sorted sequence. ``sorted`` is stable, so pairs sharing a
    # milestone keep their input order.
    ordered = sorted(dynamic_interval_list, key=lambda pair: pair[0])

    dynamic_milestones = [0] + [pair[0] for pair in ordered]
    dynamic_intervals = [start_interval] + [pair[1] for pair in ordered]
    return dynamic_milestones, dynamic_intervals
...@@ -1283,6 +1283,80 @@ class TestRunner(TestCase): ...@@ -1283,6 +1283,80 @@ class TestRunner(TestCase):
val_batch_idx_targets): val_batch_idx_targets):
self.assertEqual(result, target) self.assertEqual(result, target)
# 5. test dynamic interval in IterBasedTrainLoop
max_iters = 12
interval = 5
dynamic_intervals = [(11, 2)]
iter_results = []
iter_targets = [5, 10, 12]
val_interval_results = []
val_interval_targets = [5] * 10 + [2] * 2
@HOOKS.register_module()
class TestIterDynamicIntervalHook(Hook):
    """Test hook that records validation iterations and val intervals.

    Results are appended to the ``iter_results`` and
    ``val_interval_results`` lists defined in the enclosing test.
    """

    def before_val(self, runner):
        """Record ``runner.iter`` each time validation is about to run."""
        iter_results.append(runner.iter)

    def before_train_iter(self, runner, batch_idx, data_batch=None):
        """Record the train loop's current ``val_interval``."""
        val_interval_results.append(runner.train_loop.val_interval)
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train5'
cfg.train_dataloader.sampler = dict(
type='DefaultSampler', shuffle=True)
cfg.custom_hooks = [
dict(type='TestIterDynamicIntervalHook', priority=50)
]
cfg.train_cfg = dict(
by_epoch=False,
max_iters=max_iters,
val_interval=interval,
dynamic_intervals=dynamic_intervals)
runner = Runner.from_cfg(cfg)
runner.train()
for result, target, in zip(iter_results, iter_targets):
self.assertEqual(result, target)
for result, target, in zip(val_interval_results, val_interval_targets):
self.assertEqual(result, target)
# 6. test dynamic interval in EpochBasedTrainLoop
max_epochs = 12
interval = 5
dynamic_intervals = [(11, 2)]
epoch_results = []
epoch_targets = [5, 10, 12]
val_interval_results = []
val_interval_targets = [5] * 10 + [2] * 2
@HOOKS.register_module()
class TestEpochDynamicIntervalHook(Hook):
    """Test hook that records validation epochs and val intervals.

    Results are appended to the ``epoch_results`` and
    ``val_interval_results`` lists defined in the enclosing test.
    """

    def before_val_epoch(self, runner):
        """Record ``runner.epoch`` before each validation epoch."""
        epoch_results.append(runner.epoch)

    def before_train_epoch(self, runner):
        """Record the train loop's current ``val_interval``."""
        val_interval_results.append(runner.train_loop.val_interval)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train6'
cfg.train_dataloader.sampler = dict(
type='DefaultSampler', shuffle=True)
cfg.custom_hooks = [
dict(type='TestEpochDynamicIntervalHook', priority=50)
]
cfg.train_cfg = dict(
by_epoch=True,
max_epochs=max_epochs,
val_interval=interval,
dynamic_intervals=dynamic_intervals)
runner = Runner.from_cfg(cfg)
runner.train()
for result, target, in zip(epoch_results, epoch_targets):
self.assertEqual(result, target)
for result, target, in zip(val_interval_results, val_interval_targets):
self.assertEqual(result, target)
def test_val(self): def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1' cfg.experiment_name = 'test_val1'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment