From a7961407e4da6ee7aaf630792b851f42c6ba1706 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Sun, 13 Mar 2022 16:48:09 +0800 Subject: [PATCH] [Refactor] Refactor the interfaces of Hook and its subclassed (#117) * Fix hook * Fix * Fix docs * FIx * Fix * Fix as comment * update * Fix hook * Fix hook * Fix hook * Fix itertimerhook * Fix iter_timer_hook * Fix * Fix * fix logger hook * Fix loggerhook * update cur_dataloader * Fix docstring * Fix docstring * Fix as commet * Fix as commet * Fix as comment * rename is_last_epoch, enhance and add after_val before_val .etc * fix typo in docstring * remove resolved TODO * refactor docstring --- mmengine/hooks/checkpoint_hook.py | 10 +- mmengine/hooks/empty_cache_hook.py | 31 +-- mmengine/hooks/hook.py | 242 ++++++++++++---------- mmengine/hooks/iter_timer_hook.py | 27 ++- mmengine/hooks/logger_hook.py | 9 +- mmengine/hooks/sampler_seed_hook.py | 10 +- mmengine/hooks/sync_buffer_hook.py | 2 +- tests/test_hook/test_empty_cache_hook.py | 6 +- tests/test_hook/test_hook.py | 36 ++-- tests/test_hook/test_iter_timer_hook.py | 14 +- tests/test_hook/test_logger_hook.py | 8 +- tests/test_hook/test_sampler_seed_hook.py | 24 +-- tests/test_hook/test_sync_buffers_hook.py | 2 +- 13 files changed, 235 insertions(+), 186 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index d8c75d1f..313900b7 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -119,9 +119,8 @@ class CheckpointHook(Hook): # save checkpoint for following cases: # 1. every ``self.interval`` epochs # 2. reach the last epoch of training - if self.every_n_epochs( - runner, self.interval) or (self.save_last - and self.is_last_epoch(runner)): + if self.every_n_epochs(runner, self.interval) or ( + self.save_last and self.is_last_train_epoch(runner)): runner.logger.info(f'Saving checkpoint at \ {runner.epoch + 1} epochs') if self.sync_buffer: @@ -187,9 +186,8 @@ class CheckpointHook(Hook): # save checkpoint for following cases: # 1. every ``self.interval`` iterations # 2. reach the last iteration of training - if self.every_n_iters( - runner, self.interval) or (self.save_last - and self.is_last_iter(runner)): + if self.every_n_iters(runner, self.interval) or \ + (self.save_last and self.is_last_iter(runner, mode='train')): runner.logger.info(f'Saving checkpoint at \ {runner.iter + 1} iterations') if self.sync_buffer: diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index a37de337..cb20a744 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -30,16 +30,16 @@ class EmptyCacheHook(Hook): before_epoch: bool = False, after_epoch: bool = True, after_iter: bool = False) -> None: - self._before_epoch = before_epoch - self._after_epoch = after_epoch - self._after_iter = after_iter + self._do_before_epoch = before_epoch + self._do_after_epoch = after_epoch + self._do_after_iter = after_iter - def after_iter(self, - runner, - data_batch: DATA_BATCH = None, - outputs: - Optional[Union[dict, Sequence[BaseDataSample]]] = None)\ - -> None: + def _after_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[Union[dict, + Sequence[BaseDataSample]]] = None, + mode: str = 'train') -> None: """Empty cache after an iteration. Args: @@ -48,24 +48,27 @@ class EmptyCacheHook(Hook): from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. """ - if self._after_iter: + if self._do_after_iter: torch.cuda.empty_cache() - def before_epoch(self, runner) -> None: + def _before_epoch(self, runner, mode: str = 'train') -> None: """Empty cache before an epoch. Args: runner (Runner): The runner of the training process. + mode (str): Current mode of runner. Defaults to 'train'. """ - if self._before_epoch: + if self._do_before_epoch: torch.cuda.empty_cache() - def after_epoch(self, runner) -> None: + def _after_epoch(self, runner, mode: str = 'train') -> None: """Empty cache after an epoch. Args: runner (Runner): The runner of the training process. + mode (str): Current mode of runner. Defaults to 'train'. """ - if self._after_epoch: + if self._do_after_epoch: torch.cuda.empty_cache() diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 91729375..7eeb5e54 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -16,20 +16,20 @@ class Hook: def before_run(self, runner) -> None: """All subclasses should override this method, if they need any - operations before the training process. + operations before the training validation or testing process. Args: - runner (Runner): The runner of the training/validation/testing + runner (Runner): The runner of the training, validation or testing process. """ pass def after_run(self, runner) -> None: """All subclasses should override this method, if they need any - operations after the training process. + operations before the training validation or testing process. Args: - runner (Runner): The runner of the training/validation/testing + runner (Runner): The runner of the training, validation or testing process. """ pass @@ -54,7 +54,7 @@ class Hook: def before_val(self, runner) -> None: """All subclasses should override this method, if they need any - operations before val. + operations before validation. Args: runner (Runner): The runner of the validation process. @@ -63,7 +63,7 @@ class Hook: def after_val(self, runner) -> None: """All subclasses should override this method, if they need any - operations after val. + operations after validation. Args: runner (Runner): The runner of the validation process. @@ -72,7 +72,7 @@ class Hook: def before_test(self, runner) -> None: """All subclasses should override this method, if they need any - operations before test. + operations before testing. Args: runner (Runner): The runner of the testing process. @@ -81,67 +81,21 @@ class Hook: def after_test(self, runner) -> None: """All subclasses should override this method, if they need any - operations after test. + operations after testing. Args: runner (Runner): The runner of the testing process. """ pass - def before_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before each epoch. - - Args: - runner (Runner): The runner of the training process. - """ - pass - - def after_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after each epoch. - - Args: - runner (Runner): The runner of the training process. - """ - pass - - def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None: - """All subclasses should override this method, if they need any - operations before each iter. - - Args: - runner (Runner): The runner of the training process. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): - Data from dataloader. Defaults to None. - """ - pass - - def after_iter(self, - runner, - data_batch: DATA_BATCH = None, - outputs: - Optional[Union[dict, Sequence[BaseDataSample]]] = None) \ - -> None: - """All subclasses should override this method, if they need any - operations after each epoch. - - Args: - runner (Runner): The runner of the training process. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): - Data from dataloader. Defaults to None. - outputs (dict or sequence, optional): Outputs from model. Defaults - to None. - """ - pass - def before_save_checkpoint(self, runner, checkpoint: dict) -> None: """All subclasses should override this method, if they need any operations before saving the checkpoint. Args: - runner (Runner): The runner of the training process. - checkpoints (dict): Model's checkpoint. + runner (Runner): The runner of the training, validation or testing + process. + checkpoint (dict): Model's checkpoint. """ pass @@ -150,8 +104,9 @@ class Hook: operations after loading the checkpoint. Args: - runner (Runner): The runner of the training process. - checkpoints (dict): Model's checkpoint. + runner (Runner): The runner of the training, validation or testing + process. + checkpoint (dict): Model's checkpoint. """ pass @@ -162,25 +117,25 @@ class Hook: Args: runner (Runner): The runner of the training process. """ - self.before_epoch(runner) + self._before_epoch(runner, mode='train') def before_val_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each validation epoch. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the validation process. """ - self.before_epoch(runner) + self._before_epoch(runner, mode='val') def before_test_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each test epoch. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the testing process. """ - self.before_epoch(runner) + self._before_epoch(runner, mode='test') def after_train_epoch(self, runner) -> None: """All subclasses should override this method, if they need any @@ -189,25 +144,25 @@ class Hook: Args: runner (Runner): The runner of the training process. """ - self.after_epoch(runner) + self._after_epoch(runner, mode='train') def after_val_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations after each validation epoch. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the validation process. """ - self.after_epoch(runner) + self._after_epoch(runner, mode='val') def after_test_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations after each test epoch. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the testing process. """ - self.after_epoch(runner) + self._after_epoch(runner, mode='test') def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any @@ -218,29 +173,29 @@ class Hook: data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ - self.before_iter(runner, data_batch=None) + self._before_iter(runner, data_batch=data_batch, mode='train') def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each validation iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the validation process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ - self.before_iter(runner, data_batch=None) + self._before_iter(runner, data_batch=data_batch, mode='val') def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each test iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the testing process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ - self.before_iter(runner, data_batch=None) + self._before_iter(runner, data_batch=data_batch, mode='test') def after_train_iter(self, runner, @@ -256,7 +211,8 @@ class Hook: outputs (dict, optional): Outputs from model. Defaults to None. """ - self.after_iter(runner, data_batch=None, outputs=None) + self._after_iter( + runner, data_batch=data_batch, outputs=outputs, mode='train') def after_val_iter(self, runner, @@ -267,13 +223,14 @@ class Hook: operations after each validation iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the validation process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. """ - self.after_iter(runner, data_batch=None, outputs=None) + self._after_iter( + runner, data_batch=data_batch, outputs=outputs, mode='val') def after_test_iter( self, @@ -284,48 +241,108 @@ class Hook: operations after each test iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ - self.after_iter(runner, data_batch=None, outputs=None) + self._after_iter( + runner, data_batch=data_batch, outputs=outputs, mode='test') + + def _before_epoch(self, runner, mode: str = 'train') -> None: + """All subclasses should override this method, if they need any + operations before each epoch. + + Args: + runner (Runner): The runner of the training, validation or testing + process. + mode (str): Current mode of runner. Defaults to 'train'. + """ + pass + + def _after_epoch(self, runner, mode: str = 'train') -> None: + """All subclasses should override this method, if they need any + operations after each epoch. + + Args: + runner (Runner): The runner of the training, validation or testing + process. + mode (str): Current mode of runner. Defaults to 'train'. + """ + pass + + def _before_iter(self, + runner, + data_batch: DATA_BATCH = None, + mode: str = 'train') -> None: + """All subclasses should override this method, if they need any + operations before each iter. + + Args: + runner (Runner): The runner of the training, validation or testing + process. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. + """ + pass + + def _after_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[Union[Sequence[BaseDataSample], + dict]] = None, + mode: str = 'train') -> None: + """All subclasses should override this method, if they need any + operations after each epoch. + + Args: + runner (Runner): The runner of the training, validation or testing + process. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. + outputs (Sequence[BaseDataSample], optional): Outputs from model. + Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. + """ + pass def every_n_epochs(self, runner, n: int) -> bool: - """Test whether or not current epoch can be evenly divided by n. + """Test whether current epoch can be evenly divided by n. Args: - runner (Runner): The runner of the training process. - n (int): Whether or not current epoch can be evenly divided by n. + runner (Runner): The runner of the training, validation or testing + process. + n (int): Whether current epoch can be evenly divided by n. Returns: - bool: whether or not current epoch can be evenly divided by n. + bool: Whether current epoch can be evenly divided by n. """ return (runner.epoch + 1) % n == 0 if n > 0 else False def every_n_inner_iters(self, runner, n: int) -> bool: - """Test whether or not current inner iteration can be evenly divided by - n. + """Test whether current inner iteration can be evenly divided by n. Args: - runner (Runner): The runner of the training process. - n (int): Whether or not current inner iteration can be evenly + runner (Runner): The runner of the training, validation or testing + process. + n (int): Whether current inner iteration can be evenly divided by n. Returns: - bool: whether or not current inner iteration can be evenly + bool: Whether current inner iteration can be evenly divided by n. """ return (runner.inner_iter + 1) % n == 0 if n > 0 else False def every_n_iters(self, runner, n: int) -> bool: - """Test whether or not current iteration can be evenly divided by n. + """Test whether current iteration can be evenly divided by n. Args: - runner (Runner): The runner of the training process. - n (int): Whether or not current iteration can be - evenly divided by n. + runner (Runner): The runner of the training, validation or testing + process. + n (int): Whether current iteration can be evenly divided by n. Returns: bool: Return True if the current iteration can be evenly divided @@ -334,35 +351,46 @@ class Hook: return (runner.iter + 1) % n == 0 if n > 0 else False def end_of_epoch(self, runner) -> bool: - """Check whether the current epoch reaches the `max_epochs` or not. + """Check whether the current iteration reaches the last iteration of + current dataloader. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training, validation or testing + process. Returns: - bool: whether the end of current epoch or not. + bool: Whether reaches the end of current epoch or not. """ - return runner.inner_iter + 1 == len(runner.data_loader) + return runner.inner_iter + 1 == len(runner.cur_dataloader) - def is_last_epoch(self, runner) -> bool: - """Test whether or not current epoch is the last epoch. + def is_last_train_epoch(self, runner) -> bool: + """Test whether current epoch is the last train epoch. Args: runner (Runner): The runner of the training process. Returns: - bool: bool: Return True if the current epoch reaches the - `max_epochs`, otherwise False. + bool: Whether reaches the end of training epoch. """ - return runner.epoch + 1 == runner._max_epochs + return runner.epoch + 1 == runner.train_loop.max_epochs - def is_last_iter(self, runner) -> bool: - """Test whether or not current epoch is the last iteration. + def is_last_iter(self, runner, mode='train') -> bool: + """Test whether current iteration is the last iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training, validation or testing + process. Returns: - bool: whether or not current iteration is the last iteration. - """ - return runner.iter + 1 == runner._max_iters + bool: Whether current iteration is the last iteration. + mode (str): Current mode of runner. Defaults to 'train'. + """ + if mode == 'train': + return runner.iter + 1 == runner.train_loop.max_iters + elif mode == 'val': + return runner.iter + 1 == runner.val_loop.max_iters + elif mode == 'test': + return runner.iter + 1 == runner.test_loop.max_iters + else: + raise ValueError('mode should be train, val or test but got' + f'{mode}') diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index 72701504..824774ed 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -18,30 +18,37 @@ class IterTimerHook(Hook): priority = 'NORMAL' - def before_epoch(self, runner) -> None: + def _before_epoch(self, runner, mode: str = 'train') -> None: """Record time flag before start a epoch. Args: runner (Runner): The runner of the training process. + mode (str): Current mode of runner. Defaults to 'train'. """ self.t = time.time() - def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None: + def _before_iter(self, + runner, + data_batch: DATA_BATCH = None, + mode: str = 'train') -> None: """Logging time for loading data and update the time flag. Args: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. """ # TODO: update for new logging system - runner.log_buffer.update({'data_time': time.time() - self.t}) + runner.message_hub.update_log(f'{mode}/data_time', + time.time() - self.t) - def after_iter(self, - runner, - data_batch: DATA_BATCH = None, - outputs: - Optional[Union[dict, Sequence[BaseDataSample]]] = None) \ + def _after_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: + Optional[Union[dict, Sequence[BaseDataSample]]] = None, + mode: str = 'train') \ -> None: """Logging time for a iteration and update the time flag. @@ -51,7 +58,9 @@ class IterTimerHook(Hook): from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. """ # TODO: update for new logging system - runner.log_buffer.update({'time': time.time() - self.t}) + + runner.message_hub.update_log(f'{mode}/time', time.time() - self.t) self.t = time.time() diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index b557c2d0..9e2518bf 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -264,14 +264,15 @@ class LoggerHook(Hook): # by iter: Iter [100/100000] if self.by_epoch: log_str = f'Epoch [{cur_epoch}]' \ - f'[{cur_iter}/{len(runner.data_loader)}]\t' + f'[{cur_iter}/{len(runner.cur_dataloader)}]\t' else: - log_str = f'Iter [{cur_iter}/{runner.max_iters}]\t' + log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t' log_str += f'{lr_momentum_str}, ' # Calculate eta time. self.time_sec_tot += (tag['time'] * self.interval) time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1) - eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) + eta_sec = time_sec_avg * ( + runner.train_loop.max_iters - runner.iter - 1) eta_str = str(datetime.timedelta(seconds=int(eta_sec))) log_str += f'eta: {eta_str}, ' log_str += f'time: {tag["time"]:.3f}, ' \ @@ -302,7 +303,7 @@ class LoggerHook(Hook): """ tag = self._collect_info(runner, 'val') # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501 - eval_iter = len(runner.data_loader) + eval_iter = len(runner.cur_dataloader) cur_iter = self._get_iter(runner) cur_epoch = self._get_epoch(runner, 'val') # val/test time diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py index d7d243d8..eed3fa90 100644 --- a/mmengine/hooks/sampler_seed_hook.py +++ b/mmengine/hooks/sampler_seed_hook.py @@ -14,15 +14,15 @@ class DistSamplerSeedHook(Hook): priority = 'NORMAL' - def before_epoch(self, runner) -> None: + def before_train_epoch(self, runner, mode: str = 'train') -> None: """Set the seed for sampler and batch_sampler. Args: runner (Runner): The runner of the training process. """ - if hasattr(runner.data_loader.sampler, 'set_epoch'): + if hasattr(runner.cur_dataloader.sampler, 'set_epoch'): # in case the data loader uses `SequentialSampler` in Pytorch - runner.data_loader.sampler.set_epoch(runner.epoch) - elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'): + runner.cur_dataloader.sampler.set_epoch(runner.epoch) + elif hasattr(runner.cur_dataloader.batch_sampler.sampler, 'set_epoch'): # batch sampler in pytorch warps the sampler as its attributes. - runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch) + runner.cur_dataloader.batch_sampler.sampler.set_epoch(runner.epoch) diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py index 9aa6402d..37b62f98 100644 --- a/mmengine/hooks/sync_buffer_hook.py +++ b/mmengine/hooks/sync_buffer_hook.py @@ -89,7 +89,7 @@ class SyncBuffersHook(Hook): def __init__(self) -> None: self.distributed = dist.IS_DIST - def after_epoch(self, runner) -> None: + def after_train_epoch(self, runner) -> None: """All-reduce model buffers at the end of each epoch. Args: diff --git a/tests/test_hook/test_empty_cache_hook.py b/tests/test_hook/test_empty_cache_hook.py index a09806d0..dc8ce8fc 100644 --- a/tests/test_hook/test_empty_cache_hook.py +++ b/tests/test_hook/test_empty_cache_hook.py @@ -9,6 +9,6 @@ class TestEmptyCacheHook: def test_emtpy_cache_hook(self): Hook = EmptyCacheHook(True, True, True) Runner = Mock() - Hook.after_iter(Runner) - Hook.before_epoch(Runner) - Hook.after_epoch(Runner) + Hook._after_iter(Runner) + Hook._before_epoch(Runner) + Hook._after_epoch(Runner) diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py index 5884a161..d55f8a66 100644 --- a/tests/test_hook/test_hook.py +++ b/tests/test_hook/test_hook.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest.mock import Mock +import pytest + from mmengine.hooks import Hook @@ -19,25 +21,25 @@ class TestHook: def test_before_epoch(self): hook = Hook() runner = Mock() - hook.before_epoch(runner) + hook._before_epoch(runner) def test_after_epoch(self): hook = Hook() runner = Mock() - hook.after_epoch(runner) + hook._after_epoch(runner) def test_before_iter(self): hook = Hook() runner = Mock() data_batch = {} - hook.before_iter(runner, data_batch) + hook._before_iter(runner, data_batch) def test_after_iter(self): hook = Hook() runner = Mock() data_batch = {} outputs = {} - hook.after_iter(runner, data_batch, outputs) + hook._after_iter(runner, data_batch, outputs) def test_before_save_checkpoint(self): hook = Hook() @@ -161,7 +163,8 @@ class TestHook: # last inner iter runner.inner_iter = 1 - runner.data_loader.__len__ = Mock(return_value=2) + runner.cur_dataloader.__len__ = Mock(return_value=2) + runner.cur_dataloader.__len__ = Mock(return_value=2) return_val = hook.end_of_epoch(runner) assert return_val @@ -170,19 +173,19 @@ class TestHook: return_val = hook.end_of_epoch(runner) assert not return_val - def test_is_last_epoch(self): + def test_is_last_train_epoch(self): hook = Hook() runner = Mock() # last epoch runner.epoch = 1 - runner._max_epochs = 2 - return_val = hook.is_last_epoch(runner) + runner.train_loop.max_epochs = 2 + return_val = hook.is_last_train_epoch(runner) assert return_val # not the last epoch - runner.epoch = 0 - return_val = hook.is_last_epoch(runner) + runner.train_loop.max_epochs = 0 + return_val = hook.is_last_train_epoch(runner) assert not return_val def test_is_last_iter(self): @@ -191,11 +194,18 @@ class TestHook: # last iter runner.iter = 1 - runner._max_iters = 2 + runner.train_loop.max_iters = 2 return_val = hook.is_last_iter(runner) assert return_val # not the last iter - runner.iter = 0 - return_val = hook.is_last_iter(runner) + runner.val_loop.max_iters = 0 + return_val = hook.is_last_iter(runner, mode='val') assert not return_val + + runner.test_loop.max_iters = 0 + return_val = hook.is_last_iter(runner, mode='test') + assert not return_val + + with pytest.raises(ValueError): + hook.is_last_iter(runner, mode='error_mode') diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py index 5e3b6e71..8b20d4f1 100644 --- a/tests/test_hook/test_iter_timer_hook.py +++ b/tests/test_hook/test_iter_timer_hook.py @@ -9,21 +9,21 @@ class TestIterTimerHook: def test_before_epoch(self): Hook = IterTimerHook() Runner = Mock() - Hook.before_epoch(Runner) + Hook._before_epoch(Runner) assert isinstance(Hook.t, float) def test_before_iter(self): Hook = IterTimerHook() Runner = Mock() Runner.log_buffer = dict() - Hook.before_epoch(Runner) - Hook.before_iter(Runner) - assert 'data_time' in Runner.log_buffer + Hook._before_epoch(Runner) + Hook._before_iter(Runner) + Runner.message_hub.update_log.assert_called() def test_after_iter(self): Hook = IterTimerHook() Runner = Mock() Runner.log_buffer = dict() - Hook.before_epoch(Runner) - Hook.after_iter(Runner) - assert 'time' in Runner.log_buffer + Hook._before_epoch(Runner) + Hook._after_iter(Runner) + Runner.message_hub.update_log.assert_called() diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index 90faad77..d92ac9bc 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -110,7 +110,7 @@ class TestLoggerHook: # Test end of the epoch. logger_hook = LoggerHook(by_epoch=True, ignore_last=False) logger_hook._log_train = MagicMock() - runner.data_loader = [0] * 5 + runner.cur_dataloader = [0] * 5 runner.inner_iter = 4 logger_hook.after_train_iter(runner) logger_hook._log_train.assert_called() @@ -155,7 +155,7 @@ class TestLoggerHook: out, _ = capsys.readouterr() time_avg = logger_hook.time_sec_tot / ( runner.iter + 1 - logger_hook.start_iter) - eta_second = time_avg * (runner.max_iters - runner.iter - 1) + eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1) eta_str = str(datetime.timedelta(seconds=int(eta_second))) if by_epoch: if torch.cuda.is_available(): @@ -337,10 +337,10 @@ class TestLoggerHook: def _setup_runner(self): runner = MagicMock() runner.epoch = 1 - runner.data_loader = [0] * 5 + runner.cur_dataloader = [0] * 5 runner.inner_iter = 1 runner.iter = 10 - runner.max_iters = 50 + runner.train_loop.max_iters = 50 logger = logging.getLogger() logger.setLevel(logging.INFO) for handler in logger.handlers: diff --git a/tests/test_hook/test_sampler_seed_hook.py b/tests/test_hook/test_sampler_seed_hook.py index 0bf96743..9d19edf7 100644 --- a/tests/test_hook/test_sampler_seed_hook.py +++ b/tests/test_hook/test_sampler_seed_hook.py @@ -12,17 +12,17 @@ class TestDistSamplerSeedHook: # Test dataset sampler runner = Mock() runner.epoch = 1 - runner.data_loader = Mock() - runner.data_loader.sampler = Mock() - runner.data_loader.sampler.set_epoch = Mock() - hook.before_epoch(runner) - runner.data_loader.sampler.set_epoch.assert_called() + runner.cur_dataloader = Mock() + runner.cur_dataloader.sampler = Mock() + runner.cur_dataloader.sampler.set_epoch = Mock() + hook.before_train_epoch(runner) + runner.cur_dataloader.sampler.set_epoch.assert_called() # Test batch sampler runner = Mock() - runner.data_loader = Mock() - runner.data_loader.sampler = Mock(spec_set=True) - runner.data_loader.batch_sampler = Mock() - runner.data_loader.batch_sampler.sampler = Mock() - runner.data_loader.batch_sampler.sampler.set_epoch = Mock() - hook.before_epoch(runner) - runner.data_loader.batch_sampler.sampler.set_epoch.assert_called() + runner.cur_dataloader = Mock() + runner.cur_dataloader.sampler = Mock(spec_set=True) + runner.cur_dataloader.batch_sampler = Mock() + runner.cur_dataloader.batch_sampler.sampler = Mock() + runner.cur_dataloader.batch_sampler.sampler.set_epoch = Mock() + hook.before_train_epoch(runner) + runner.cur_dataloader.batch_sampler.sampler.set_epoch.assert_called() diff --git a/tests/test_hook/test_sync_buffers_hook.py b/tests/test_hook/test_sync_buffers_hook.py index 6bba7de5..1c0b6295 100644 --- a/tests/test_hook/test_sync_buffers_hook.py +++ b/tests/test_hook/test_sync_buffers_hook.py @@ -10,4 +10,4 @@ class TestSyncBuffersHook: Runner = Mock() Runner.model = Mock() Hook = SyncBuffersHook() - Hook.after_epoch(Runner) + Hook._after_epoch(Runner) -- GitLab