diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index a312307fe4ef368cb8060d24ff0940c7f30f63b8..d8c75d1f2f083cbaf0b8c30e4ed213f8221d2b02 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -168,18 +168,17 @@ class CheckpointHook(Hook): else: break - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs=Optional[dict]) -> None: """Save the checkpoint and synchronize buffers after each iteration. Args: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. Defaults to None. """ if self.by_epoch: diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index 2f621131441acae39c50e8caa97b2384ba978f53..a37de3377ad3f38f2c7e567ade4517e31eb9104b 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, Union import torch @@ -37,14 +37,16 @@ class EmptyCacheHook(Hook): def after_iter(self, runner, data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + outputs: + Optional[Union[dict, Sequence[BaseDataSample]]] = None)\ + -> None: """Empty cache after an iteration. Args: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample]): Outputs from model. + outputs (dict or sequence, optional): Outputs from model. Defaults to None. """ if self._after_iter: diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index f2d52fc96f51b47a5c66b5f2352fb28e27e1ecbb..917293750f70863fdb5db336eb85088ffd466837 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, Union from mmengine.data import BaseDataSample @@ -19,7 +19,8 @@ class Hook: operations before the training process. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training/validation/testing + process. """ pass @@ -27,11 +28,66 @@ class Hook: """All subclasses should override this method, if they need any operations after the training process. + Args: + runner (Runner): The runner of the training/validation/testing + process. + """ + pass + + def before_train(self, runner) -> None: + """All subclasses should override this method, if they need any + operations before train. + + Args: + runner (Runner): The runner of the training process. + """ + pass + + def after_train(self, runner) -> None: + """All subclasses should override this method, if they need any + operations after train. + Args: runner (Runner): The runner of the training process. """ pass + def before_val(self, runner) -> None: + """All subclasses should override this method, if they need any + operations before val. + + Args: + runner (Runner): The runner of the validation process. + """ + pass + + def after_val(self, runner) -> None: + """All subclasses should override this method, if they need any + operations after val. + + Args: + runner (Runner): The runner of the validation process. + """ + pass + + def before_test(self, runner) -> None: + """All subclasses should override this method, if they need any + operations before test. + + Args: + runner (Runner): The runner of the testing process. + """ + pass + + def after_test(self, runner) -> None: + """All subclasses should override this method, if they need any + operations after test. + + Args: + runner (Runner): The runner of the testing process. + """ + pass + def before_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each epoch. @@ -64,7 +120,9 @@ class Hook: def after_iter(self, runner, data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + outputs: + Optional[Union[dict, Sequence[BaseDataSample]]] = None) \ + -> None: """All subclasses should override this method, if they need any operations after each epoch. @@ -72,8 +130,8 @@ class Hook: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. - Defaults to None. + outputs (dict or sequence, optional): Outputs from model. Defaults + to None. """ pass @@ -184,11 +242,10 @@ class Hook: """ self.before_iter(runner, data_batch=None) - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: """All subclasses should override this method, if they need any operations after each training iteration. @@ -196,16 +253,16 @@ class Hook: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. Defaults to None. """ self.after_iter(runner, data_batch=None, outputs=None) - def after_val_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_val_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[Sequence[BaseDataSample]] = None) \ + -> None: """All subclasses should override this method, if they need any operations after each validation iteration. @@ -213,7 +270,7 @@ class Hook: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from + outputs (dict or sequence, optional): Outputs from model. Defaults to None. """ self.after_iter(runner, data_batch=None, outputs=None) @@ -230,7 +287,7 @@ class Hook: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. Defaults to None. """ self.after_iter(runner, data_batch=None, outputs=None) diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index d1d6404fc1d8779bce10e790a4d3b1cd4fb302c0..727015045efd57fabc679ba22096f0512f9e4ced 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import time -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, Union from mmengine.data import BaseDataSample from mmengine.registry import HOOKS @@ -40,15 +40,17 @@ class IterTimerHook(Hook): def after_iter(self, runner, data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + outputs: + Optional[Union[dict, Sequence[BaseDataSample]]] = None) \ + -> None: """Logging time for a iteration and update the time flag. Args: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample]): Outputs from model. - Defaults to None. + outputs (dict or sequence, optional): Outputs from model. Defaults + to None. """ # TODO: update for new logging system runner.log_buffer.update({'time': time.time() - self.t}) diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index 741bfd95442021855c9f90337012cc5965f8346c..b557c2d0988648554190b02e01ee18aa65466b45 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -171,18 +171,17 @@ class LoggerHook(Hook): if runner.meta is not None: runner.writer.add_params(runner.meta, file_path=self.yaml_log_path) - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: """Record training logs. Args: runner (Runner): The runner of the training process. data_batch (Sequence[BaseDataSample], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. Defaults to None. """ if runner.meta is not None and 'exp_name' in runner.meta: diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index 3a9dabbcdb68706d2af74a772b78b359d7c2373f..418bf06d28528d1dcfa2678a6f1373f55849c43f 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -56,11 +56,10 @@ class OptimizerHook(Hook): return clip_grad.clip_grad_norm_(params, **self.grad_clip) return None - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: """All operations need to be finished after each training iteration. This function will finish following 3 operations: @@ -80,7 +79,7 @@ class OptimizerHook(Hook): from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. In order to keep this interface consistent with other hooks, we keep ``outputs`` here. Defaults to None. """ diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index 095b1bb3e700e7511d9087d2e86757641b677c9e..a85ef3ac432d6421605335cf19ac6464ded48ef5 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -15,11 +15,10 @@ class ParamSchedulerHook(Hook): priority = 'LOW' - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: """Call step function for each scheduler after each iteration. Args: @@ -28,7 +27,7 @@ class ParamSchedulerHook(Hook): from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. """