diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index d8c75d1f2f083cbaf0b8c30e4ed213f8221d2b02..313900b7a234773f10f922d511bba623cb55e0f9 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -119,9 +119,8 @@ class CheckpointHook(Hook): # save checkpoint for following cases: # 1. every ``self.interval`` epochs # 2. reach the last epoch of training - if self.every_n_epochs( - runner, self.interval) or (self.save_last - and self.is_last_epoch(runner)): + if self.every_n_epochs(runner, self.interval) or ( + self.save_last and self.is_last_train_epoch(runner)): runner.logger.info(f'Saving checkpoint at \ {runner.epoch + 1} epochs') if self.sync_buffer: @@ -187,9 +186,8 @@ class CheckpointHook(Hook): # save checkpoint for following cases: # 1. every ``self.interval`` iterations # 2. reach the last iteration of training - if self.every_n_iters( - runner, self.interval) or (self.save_last - and self.is_last_iter(runner)): + if self.every_n_iters(runner, self.interval) or \ + (self.save_last and self.is_last_iter(runner, mode='train')): runner.logger.info(f'Saving checkpoint at \ {runner.iter + 1} iterations') if self.sync_buffer: diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index a37de3377ad3f38f2c7e567ade4517e31eb9104b..cb20a744643607e7137de743228ff721673623ac 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -30,16 +30,16 @@ class EmptyCacheHook(Hook): before_epoch: bool = False, after_epoch: bool = True, after_iter: bool = False) -> None: - self._before_epoch = before_epoch - self._after_epoch = after_epoch - self._after_iter = after_iter + self._do_before_epoch = before_epoch + self._do_after_epoch = after_epoch + self._do_after_iter = after_iter - def after_iter(self, - runner, - data_batch: DATA_BATCH = None, - outputs: - Optional[Union[dict, Sequence[BaseDataSample]]] = None)\ - -> None: + def _after_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[Union[dict, + Sequence[BaseDataSample]]] = None, + mode: str = 'train') -> None: """Empty cache after an iteration. Args: @@ -48,24 +48,27 @@ class EmptyCacheHook(Hook): from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. """ - if self._after_iter: + if self._do_after_iter: torch.cuda.empty_cache() - def before_epoch(self, runner) -> None: + def _before_epoch(self, runner, mode: str = 'train') -> None: """Empty cache before an epoch. Args: runner (Runner): The runner of the training process. + mode (str): Current mode of runner. Defaults to 'train'. """ - if self._before_epoch: + if self._do_before_epoch: torch.cuda.empty_cache() - def after_epoch(self, runner) -> None: + def _after_epoch(self, runner, mode: str = 'train') -> None: """Empty cache after an epoch. Args: runner (Runner): The runner of the training process. + mode (str): Current mode of runner. Defaults to 'train'. """ - if self._after_epoch: + if self._do_after_epoch: torch.cuda.empty_cache() diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 917293750f70863fdb5db336eb85088ffd466837..7eeb5e54136baa1b474b3328b1ca7581dfaf0e7f 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -16,20 +16,20 @@ class Hook: def before_run(self, runner) -> None: """All subclasses should override this method, if they need any - operations before the training process. + operations before the training validation or testing process. Args: - runner (Runner): The runner of the training/validation/testing + runner (Runner): The runner of the training, validation or testing process. """ pass def after_run(self, runner) -> None: """All subclasses should override this method, if they need any - operations after the training process. + operations before the training validation or testing process. Args: - runner (Runner): The runner of the training/validation/testing + runner (Runner): The runner of the training, validation or testing process. """ pass @@ -54,7 +54,7 @@ class Hook: def before_val(self, runner) -> None: """All subclasses should override this method, if they need any - operations before val. + operations before validation. Args: runner (Runner): The runner of the validation process. @@ -63,7 +63,7 @@ class Hook: def after_val(self, runner) -> None: """All subclasses should override this method, if they need any - operations after val. + operations after validation. Args: runner (Runner): The runner of the validation process. @@ -72,7 +72,7 @@ class Hook: def before_test(self, runner) -> None: """All subclasses should override this method, if they need any - operations before test. + operations before testing. Args: runner (Runner): The runner of the testing process. @@ -81,67 +81,21 @@ class Hook: def after_test(self, runner) -> None: """All subclasses should override this method, if they need any - operations after test. + operations after testing. Args: runner (Runner): The runner of the testing process. """ pass - def before_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations before each epoch. - - Args: - runner (Runner): The runner of the training process. - """ - pass - - def after_epoch(self, runner) -> None: - """All subclasses should override this method, if they need any - operations after each epoch. - - Args: - runner (Runner): The runner of the training process. - """ - pass - - def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None: - """All subclasses should override this method, if they need any - operations before each iter. - - Args: - runner (Runner): The runner of the training process. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): - Data from dataloader. Defaults to None. - """ - pass - - def after_iter(self, - runner, - data_batch: DATA_BATCH = None, - outputs: - Optional[Union[dict, Sequence[BaseDataSample]]] = None) \ - -> None: - """All subclasses should override this method, if they need any - operations after each epoch. - - Args: - runner (Runner): The runner of the training process. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): - Data from dataloader. Defaults to None. - outputs (dict or sequence, optional): Outputs from model. Defaults - to None. - """ - pass - def before_save_checkpoint(self, runner, checkpoint: dict) -> None: """All subclasses should override this method, if they need any operations before saving the checkpoint. Args: - runner (Runner): The runner of the training process. - checkpoints (dict): Model's checkpoint. + runner (Runner): The runner of the training, validation or testing + process. + checkpoint (dict): Model's checkpoint. """ pass @@ -150,8 +104,9 @@ class Hook: operations after loading the checkpoint. Args: - runner (Runner): The runner of the training process. - checkpoints (dict): Model's checkpoint. + runner (Runner): The runner of the training, validation or testing + process. + checkpoint (dict): Model's checkpoint. """ pass @@ -162,25 +117,25 @@ class Hook: Args: runner (Runner): The runner of the training process. """ - self.before_epoch(runner) + self._before_epoch(runner, mode='train') def before_val_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each validation epoch. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the validation process. """ - self.before_epoch(runner) + self._before_epoch(runner, mode='val') def before_test_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each test epoch. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the testing process. """ - self.before_epoch(runner) + self._before_epoch(runner, mode='test') def after_train_epoch(self, runner) -> None: """All subclasses should override this method, if they need any @@ -189,25 +144,25 @@ class Hook: Args: runner (Runner): The runner of the training process. """ - self.after_epoch(runner) + self._after_epoch(runner, mode='train') def after_val_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations after each validation epoch. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the validation process. """ - self.after_epoch(runner) + self._after_epoch(runner, mode='val') def after_test_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations after each test epoch. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the testing process. """ - self.after_epoch(runner) + self._after_epoch(runner, mode='test') def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any @@ -218,29 +173,29 @@ class Hook: data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ - self.before_iter(runner, data_batch=None) + self._before_iter(runner, data_batch=data_batch, mode='train') def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each validation iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the validation process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ - self.before_iter(runner, data_batch=None) + self._before_iter(runner, data_batch=data_batch, mode='val') def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each test iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the testing process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ - self.before_iter(runner, data_batch=None) + self._before_iter(runner, data_batch=data_batch, mode='test') def after_train_iter(self, runner, @@ -256,7 +211,8 @@ class Hook: outputs (dict, optional): Outputs from model. Defaults to None. """ - self.after_iter(runner, data_batch=None, outputs=None) + self._after_iter( + runner, data_batch=data_batch, outputs=outputs, mode='train') def after_val_iter(self, runner, @@ -267,13 +223,14 @@ class Hook: operations after each validation iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the validation process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. """ - self.after_iter(runner, data_batch=None, outputs=None) + self._after_iter( + runner, data_batch=data_batch, outputs=outputs, mode='val') def after_test_iter( self, @@ -284,48 +241,108 @@ class Hook: operations after each test iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ - self.after_iter(runner, data_batch=None, outputs=None) + self._after_iter( + runner, data_batch=data_batch, outputs=outputs, mode='test') + + def _before_epoch(self, runner, mode: str = 'train') -> None: + """All subclasses should override this method, if they need any + operations before each epoch. + + Args: + runner (Runner): The runner of the training, validation or testing + process. + mode (str): Current mode of runner. Defaults to 'train'. + """ + pass + + def _after_epoch(self, runner, mode: str = 'train') -> None: + """All subclasses should override this method, if they need any + operations after each epoch. + + Args: + runner (Runner): The runner of the training, validation or testing + process. + mode (str): Current mode of runner. Defaults to 'train'. + """ + pass + + def _before_iter(self, + runner, + data_batch: DATA_BATCH = None, + mode: str = 'train') -> None: + """All subclasses should override this method, if they need any + operations before each iter. + + Args: + runner (Runner): The runner of the training, validation or testing + process. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. + """ + pass + + def _after_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[Union[Sequence[BaseDataSample], + dict]] = None, + mode: str = 'train') -> None: + """All subclasses should override this method, if they need any + operations after each epoch. + + Args: + runner (Runner): The runner of the training, validation or testing + process. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. + outputs (Sequence[BaseDataSample], optional): Outputs from model. + Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. + """ + pass def every_n_epochs(self, runner, n: int) -> bool: - """Test whether or not current epoch can be evenly divided by n. + """Test whether current epoch can be evenly divided by n. Args: - runner (Runner): The runner of the training process. - n (int): Whether or not current epoch can be evenly divided by n. + runner (Runner): The runner of the training, validation or testing + process. + n (int): Whether current epoch can be evenly divided by n. Returns: - bool: whether or not current epoch can be evenly divided by n. + bool: Whether current epoch can be evenly divided by n. """ return (runner.epoch + 1) % n == 0 if n > 0 else False def every_n_inner_iters(self, runner, n: int) -> bool: - """Test whether or not current inner iteration can be evenly divided by - n. + """Test whether current inner iteration can be evenly divided by n. Args: - runner (Runner): The runner of the training process. - n (int): Whether or not current inner iteration can be evenly + runner (Runner): The runner of the training, validation or testing + process. + n (int): Whether current inner iteration can be evenly divided by n. Returns: - bool: whether or not current inner iteration can be evenly + bool: Whether current inner iteration can be evenly divided by n. """ return (runner.inner_iter + 1) % n == 0 if n > 0 else False def every_n_iters(self, runner, n: int) -> bool: - """Test whether or not current iteration can be evenly divided by n. + """Test whether current iteration can be evenly divided by n. Args: - runner (Runner): The runner of the training process. - n (int): Whether or not current iteration can be - evenly divided by n. + runner (Runner): The runner of the training, validation or testing + process. + n (int): Whether current iteration can be evenly divided by n. Returns: bool: Return True if the current iteration can be evenly divided @@ -334,35 +351,46 @@ class Hook: return (runner.iter + 1) % n == 0 if n > 0 else False def end_of_epoch(self, runner) -> bool: - """Check whether the current epoch reaches the `max_epochs` or not. + """Check whether the current iteration reaches the last iteration of + current dataloader. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training, validation or testing + process. Returns: - bool: whether the end of current epoch or not. + bool: Whether reaches the end of current epoch or not. """ - return runner.inner_iter + 1 == len(runner.data_loader) + return runner.inner_iter + 1 == len(runner.cur_dataloader) - def is_last_epoch(self, runner) -> bool: - """Test whether or not current epoch is the last epoch. + def is_last_train_epoch(self, runner) -> bool: + """Test whether current epoch is the last train epoch. Args: runner (Runner): The runner of the training process. Returns: - bool: bool: Return True if the current epoch reaches the - `max_epochs`, otherwise False. + bool: Whether reaches the end of training epoch. """ - return runner.epoch + 1 == runner._max_epochs + return runner.epoch + 1 == runner.train_loop.max_epochs - def is_last_iter(self, runner) -> bool: - """Test whether or not current epoch is the last iteration. + def is_last_iter(self, runner, mode='train') -> bool: + """Test whether current iteration is the last iteration. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training, validation or testing + process. Returns: - bool: whether or not current iteration is the last iteration. - """ - return runner.iter + 1 == runner._max_iters + bool: Whether current iteration is the last iteration. + mode (str): Current mode of runner. Defaults to 'train'. + """ + if mode == 'train': + return runner.iter + 1 == runner.train_loop.max_iters + elif mode == 'val': + return runner.iter + 1 == runner.val_loop.max_iters + elif mode == 'test': + return runner.iter + 1 == runner.test_loop.max_iters + else: + raise ValueError('mode should be train, val or test but got' + f'{mode}') diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index 727015045efd57fabc679ba22096f0512f9e4ced..824774ed646638209973ac0cf31b26e204832878 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -18,30 +18,37 @@ class IterTimerHook(Hook): priority = 'NORMAL' - def before_epoch(self, runner) -> None: + def _before_epoch(self, runner, mode: str = 'train') -> None: """Record time flag before start a epoch. Args: runner (Runner): The runner of the training process. + mode (str): Current mode of runner. Defaults to 'train'. """ self.t = time.time() - def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None: + def _before_iter(self, + runner, + data_batch: DATA_BATCH = None, + mode: str = 'train') -> None: """Logging time for loading data and update the time flag. Args: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. """ # TODO: update for new logging system - runner.log_buffer.update({'data_time': time.time() - self.t}) + runner.message_hub.update_log(f'{mode}/data_time', + time.time() - self.t) - def after_iter(self, - runner, - data_batch: DATA_BATCH = None, - outputs: - Optional[Union[dict, Sequence[BaseDataSample]]] = None) \ + def _after_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: + Optional[Union[dict, Sequence[BaseDataSample]]] = None, + mode: str = 'train') \ -> None: """Logging time for a iteration and update the time flag. @@ -51,7 +58,9 @@ class IterTimerHook(Hook): from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. + mode (str): Current mode of runner. Defaults to 'train'. """ # TODO: update for new logging system - runner.log_buffer.update({'time': time.time() - self.t}) + + runner.message_hub.update_log(f'{mode}/time', time.time() - self.t) self.t = time.time() diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index b557c2d0988648554190b02e01ee18aa65466b45..9e2518bf2ce87fc5bed6c41b16deebad69a9b133 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -264,14 +264,15 @@ class LoggerHook(Hook): # by iter: Iter [100/100000] if self.by_epoch: log_str = f'Epoch [{cur_epoch}]' \ - f'[{cur_iter}/{len(runner.data_loader)}]\t' + f'[{cur_iter}/{len(runner.cur_dataloader)}]\t' else: - log_str = f'Iter [{cur_iter}/{runner.max_iters}]\t' + log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t' log_str += f'{lr_momentum_str}, ' # Calculate eta time. self.time_sec_tot += (tag['time'] * self.interval) time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1) - eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) + eta_sec = time_sec_avg * ( + runner.train_loop.max_iters - runner.iter - 1) eta_str = str(datetime.timedelta(seconds=int(eta_sec))) log_str += f'eta: {eta_str}, ' log_str += f'time: {tag["time"]:.3f}, ' \ @@ -302,7 +303,7 @@ class LoggerHook(Hook): """ tag = self._collect_info(runner, 'val') # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501 - eval_iter = len(runner.data_loader) + eval_iter = len(runner.cur_dataloader) cur_iter = self._get_iter(runner) cur_epoch = self._get_epoch(runner, 'val') # val/test time diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py index d7d243d8fea9fbde20aed3010fcaac859a771e7f..eed3fa90d46f5abd98bec0bb2c60978c82795855 100644 --- a/mmengine/hooks/sampler_seed_hook.py +++ b/mmengine/hooks/sampler_seed_hook.py @@ -14,15 +14,15 @@ class DistSamplerSeedHook(Hook): priority = 'NORMAL' - def before_epoch(self, runner) -> None: + def before_train_epoch(self, runner, mode: str = 'train') -> None: """Set the seed for sampler and batch_sampler. Args: runner (Runner): The runner of the training process. """ - if hasattr(runner.data_loader.sampler, 'set_epoch'): + if hasattr(runner.cur_dataloader.sampler, 'set_epoch'): # in case the data loader uses `SequentialSampler` in Pytorch - runner.data_loader.sampler.set_epoch(runner.epoch) - elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'): + runner.cur_dataloader.sampler.set_epoch(runner.epoch) + elif hasattr(runner.cur_dataloader.batch_sampler.sampler, 'set_epoch'): # batch sampler in pytorch warps the sampler as its attributes. - runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch) + runner.cur_dataloader.batch_sampler.sampler.set_epoch(runner.epoch) diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py index 9aa6402d8cd4bdbbd93433416c23a72e9a68b129..37b62f986077fea3ebe0c1ba1add66d92cdc41af 100644 --- a/mmengine/hooks/sync_buffer_hook.py +++ b/mmengine/hooks/sync_buffer_hook.py @@ -89,7 +89,7 @@ class SyncBuffersHook(Hook): def __init__(self) -> None: self.distributed = dist.IS_DIST - def after_epoch(self, runner) -> None: + def after_train_epoch(self, runner) -> None: """All-reduce model buffers at the end of each epoch. Args: diff --git a/tests/test_hook/test_empty_cache_hook.py b/tests/test_hook/test_empty_cache_hook.py index a09806d0f6697c8f03a4cf60992f3927fa77e80c..dc8ce8fc5fc355252a02726d9cac2080a2776b05 100644 --- a/tests/test_hook/test_empty_cache_hook.py +++ b/tests/test_hook/test_empty_cache_hook.py @@ -9,6 +9,6 @@ class TestEmptyCacheHook: def test_emtpy_cache_hook(self): Hook = EmptyCacheHook(True, True, True) Runner = Mock() - Hook.after_iter(Runner) - Hook.before_epoch(Runner) - Hook.after_epoch(Runner) + Hook._after_iter(Runner) + Hook._before_epoch(Runner) + Hook._after_epoch(Runner) diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py index 5884a161d7358022b16c8454fa5febc9f365d0fa..d55f8a66d9d530593c4f5bb4e5e6da23d8a883f1 100644 --- a/tests/test_hook/test_hook.py +++ b/tests/test_hook/test_hook.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest.mock import Mock +import pytest + from mmengine.hooks import Hook @@ -19,25 +21,25 @@ class TestHook: def test_before_epoch(self): hook = Hook() runner = Mock() - hook.before_epoch(runner) + hook._before_epoch(runner) def test_after_epoch(self): hook = Hook() runner = Mock() - hook.after_epoch(runner) + hook._after_epoch(runner) def test_before_iter(self): hook = Hook() runner = Mock() data_batch = {} - hook.before_iter(runner, data_batch) + hook._before_iter(runner, data_batch) def test_after_iter(self): hook = Hook() runner = Mock() data_batch = {} outputs = {} - hook.after_iter(runner, data_batch, outputs) + hook._after_iter(runner, data_batch, outputs) def test_before_save_checkpoint(self): hook = Hook() @@ -161,7 +163,8 @@ class TestHook: # last inner iter runner.inner_iter = 1 - runner.data_loader.__len__ = Mock(return_value=2) + runner.cur_dataloader.__len__ = Mock(return_value=2) + runner.cur_dataloader.__len__ = Mock(return_value=2) return_val = hook.end_of_epoch(runner) assert return_val @@ -170,19 +173,19 @@ class TestHook: return_val = hook.end_of_epoch(runner) assert not return_val - def test_is_last_epoch(self): + def test_is_last_train_epoch(self): hook = Hook() runner = Mock() # last epoch runner.epoch = 1 - runner._max_epochs = 2 - return_val = hook.is_last_epoch(runner) + runner.train_loop.max_epochs = 2 + return_val = hook.is_last_train_epoch(runner) assert return_val # not the last epoch - runner.epoch = 0 - return_val = hook.is_last_epoch(runner) + runner.train_loop.max_epochs = 0 + return_val = hook.is_last_train_epoch(runner) assert not return_val def test_is_last_iter(self): @@ -191,11 +194,18 @@ class TestHook: # last iter runner.iter = 1 - runner._max_iters = 2 + runner.train_loop.max_iters = 2 return_val = hook.is_last_iter(runner) assert return_val # not the last iter - runner.iter = 0 - return_val = hook.is_last_iter(runner) + runner.val_loop.max_iters = 0 + return_val = hook.is_last_iter(runner, mode='val') assert not return_val + + runner.test_loop.max_iters = 0 + return_val = hook.is_last_iter(runner, mode='test') + assert not return_val + + with pytest.raises(ValueError): + hook.is_last_iter(runner, mode='error_mode') diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py index 5e3b6e71b2c077e950914bceea69c1621ce95f09..8b20d4f1255db29cda905e35338b75bbfe31ac4b 100644 --- a/tests/test_hook/test_iter_timer_hook.py +++ b/tests/test_hook/test_iter_timer_hook.py @@ -9,21 +9,21 @@ class TestIterTimerHook: def test_before_epoch(self): Hook = IterTimerHook() Runner = Mock() - Hook.before_epoch(Runner) + Hook._before_epoch(Runner) assert isinstance(Hook.t, float) def test_before_iter(self): Hook = IterTimerHook() Runner = Mock() Runner.log_buffer = dict() - Hook.before_epoch(Runner) - Hook.before_iter(Runner) - assert 'data_time' in Runner.log_buffer + Hook._before_epoch(Runner) + Hook._before_iter(Runner) + Runner.message_hub.update_log.assert_called() def test_after_iter(self): Hook = IterTimerHook() Runner = Mock() Runner.log_buffer = dict() - Hook.before_epoch(Runner) - Hook.after_iter(Runner) - assert 'time' in Runner.log_buffer + Hook._before_epoch(Runner) + Hook._after_iter(Runner) + Runner.message_hub.update_log.assert_called() diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index 90faad7741c09e27fc8f21d0f9045166255bf4ce..d92ac9bc208c4d52974799c7e44ae20a4ab3f652 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -110,7 +110,7 @@ class TestLoggerHook: # Test end of the epoch. logger_hook = LoggerHook(by_epoch=True, ignore_last=False) logger_hook._log_train = MagicMock() - runner.data_loader = [0] * 5 + runner.cur_dataloader = [0] * 5 runner.inner_iter = 4 logger_hook.after_train_iter(runner) logger_hook._log_train.assert_called() @@ -155,7 +155,7 @@ class TestLoggerHook: out, _ = capsys.readouterr() time_avg = logger_hook.time_sec_tot / ( runner.iter + 1 - logger_hook.start_iter) - eta_second = time_avg * (runner.max_iters - runner.iter - 1) + eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1) eta_str = str(datetime.timedelta(seconds=int(eta_second))) if by_epoch: if torch.cuda.is_available(): @@ -337,10 +337,10 @@ class TestLoggerHook: def _setup_runner(self): runner = MagicMock() runner.epoch = 1 - runner.data_loader = [0] * 5 + runner.cur_dataloader = [0] * 5 runner.inner_iter = 1 runner.iter = 10 - runner.max_iters = 50 + runner.train_loop.max_iters = 50 logger = logging.getLogger() logger.setLevel(logging.INFO) for handler in logger.handlers: diff --git a/tests/test_hook/test_sampler_seed_hook.py b/tests/test_hook/test_sampler_seed_hook.py index 0bf967437dabd7d60f5c0e2af78f750c6f30a9ad..9d19edf713cafc665a8b75a8a800b3ecbd6a4f68 100644 --- a/tests/test_hook/test_sampler_seed_hook.py +++ b/tests/test_hook/test_sampler_seed_hook.py @@ -12,17 +12,17 @@ class TestDistSamplerSeedHook: # Test dataset sampler runner = Mock() runner.epoch = 1 - runner.data_loader = Mock() - runner.data_loader.sampler = Mock() - runner.data_loader.sampler.set_epoch = Mock() - hook.before_epoch(runner) - runner.data_loader.sampler.set_epoch.assert_called() + runner.cur_dataloader = Mock() + runner.cur_dataloader.sampler = Mock() + runner.cur_dataloader.sampler.set_epoch = Mock() + hook.before_train_epoch(runner) + runner.cur_dataloader.sampler.set_epoch.assert_called() # Test batch sampler runner = Mock() - runner.data_loader = Mock() - runner.data_loader.sampler = Mock(spec_set=True) - runner.data_loader.batch_sampler = Mock() - runner.data_loader.batch_sampler.sampler = Mock() - runner.data_loader.batch_sampler.sampler.set_epoch = Mock() - hook.before_epoch(runner) - runner.data_loader.batch_sampler.sampler.set_epoch.assert_called() + runner.cur_dataloader = Mock() + runner.cur_dataloader.sampler = Mock(spec_set=True) + runner.cur_dataloader.batch_sampler = Mock() + runner.cur_dataloader.batch_sampler.sampler = Mock() + runner.cur_dataloader.batch_sampler.sampler.set_epoch = Mock() + hook.before_train_epoch(runner) + runner.cur_dataloader.batch_sampler.sampler.set_epoch.assert_called() diff --git a/tests/test_hook/test_sync_buffers_hook.py b/tests/test_hook/test_sync_buffers_hook.py index 6bba7de5943032f0dd2801b4ad721fd0d38002c1..1c0b629569522250ef1a230ac998906716d4dc9e 100644 --- a/tests/test_hook/test_sync_buffers_hook.py +++ b/tests/test_hook/test_sync_buffers_hook.py @@ -10,4 +10,4 @@ class TestSyncBuffersHook: Runner = Mock() Runner.model = Mock() Hook = SyncBuffersHook() - Hook.after_epoch(Runner) + Hook._after_epoch(Runner)