From ec3034b7650b79fb9d7f692d2b9384403adbcc8e Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 9 Mar 2022 23:10:19 +0800 Subject: [PATCH] [Fix] Fix output argument of after_iter, train_after_ter and val_after_iter (#115) * Fix hook * Fix * Fix docs * FIx * Fix * Fix as comment --- mmengine/hooks/checkpoint_hook.py | 11 ++- mmengine/hooks/empty_cache_hook.py | 8 ++- mmengine/hooks/hook.py | 93 +++++++++++++++++++++----- mmengine/hooks/iter_timer_hook.py | 10 +-- mmengine/hooks/logger_hook.py | 11 ++- mmengine/hooks/optimizer_hook.py | 11 ++- mmengine/hooks/param_scheduler_hook.py | 11 ++- 7 files changed, 106 insertions(+), 49 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index a312307f..d8c75d1f 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -168,18 +168,17 @@ class CheckpointHook(Hook): else: break - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs=Optional[dict]) -> None: """Save the checkpoint and synchronize buffers after each iteration. Args: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. Defaults to None. """ if self.by_epoch: diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index 2f621131..a37de337 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, Union import torch @@ -37,14 +37,16 @@ class EmptyCacheHook(Hook): def after_iter(self, runner, data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + outputs: + Optional[Union[dict, Sequence[BaseDataSample]]] = None)\ + -> None: """Empty cache after an iteration. Args: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample]): Outputs from model. + outputs (dict or sequence, optional): Outputs from model. Defaults to None. """ if self._after_iter: diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index f2d52fc9..91729375 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, Union from mmengine.data import BaseDataSample @@ -19,7 +19,8 @@ class Hook: operations before the training process. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training/validation/testing + process. """ pass @@ -27,11 +28,66 @@ class Hook: """All subclasses should override this method, if they need any operations after the training process. + Args: + runner (Runner): The runner of the training/validation/testing + process. + """ + pass + + def before_train(self, runner) -> None: + """All subclasses should override this method, if they need any + operations before train. + + Args: + runner (Runner): The runner of the training process. + """ + pass + + def after_train(self, runner) -> None: + """All subclasses should override this method, if they need any + operations after train. + Args: runner (Runner): The runner of the training process. """ pass + def before_val(self, runner) -> None: + """All subclasses should override this method, if they need any + operations before val. + + Args: + runner (Runner): The runner of the validation process. + """ + pass + + def after_val(self, runner) -> None: + """All subclasses should override this method, if they need any + operations after val. + + Args: + runner (Runner): The runner of the validation process. + """ + pass + + def before_test(self, runner) -> None: + """All subclasses should override this method, if they need any + operations before test. + + Args: + runner (Runner): The runner of the testing process. + """ + pass + + def after_test(self, runner) -> None: + """All subclasses should override this method, if they need any + operations after test. + + Args: + runner (Runner): The runner of the testing process. + """ + pass + def before_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each epoch. @@ -64,7 +120,9 @@ class Hook: def after_iter(self, runner, data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + outputs: + Optional[Union[dict, Sequence[BaseDataSample]]] = None) \ + -> None: """All subclasses should override this method, if they need any operations after each epoch. @@ -72,8 +130,8 @@ class Hook: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. - Defaults to None. + outputs (dict or sequence, optional): Outputs from model. Defaults + to None. """ pass @@ -184,11 +242,10 @@ class Hook: """ self.before_iter(runner, data_batch=None) - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: """All subclasses should override this method, if they need any operations after each training iteration. @@ -196,16 +253,16 @@ class Hook: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. Defaults to None. """ self.after_iter(runner, data_batch=None, outputs=None) - def after_val_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_val_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[Sequence[BaseDataSample]] = None) \ + -> None: """All subclasses should override this method, if they need any operations after each validation iteration. @@ -213,7 +270,7 @@ class Hook: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from + outputs (dict or sequence, optional): Outputs from model. Defaults to None. """ self.after_iter(runner, data_batch=None, outputs=None) @@ -230,7 +287,7 @@ class Hook: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. Defaults to None. """ self.after_iter(runner, data_batch=None, outputs=None) diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index d1d6404f..72701504 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import time -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, Union from mmengine.data import BaseDataSample from mmengine.registry import HOOKS @@ -40,15 +40,17 @@ class IterTimerHook(Hook): def after_iter(self, runner, data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + outputs: + Optional[Union[dict, Sequence[BaseDataSample]]] = None) \ + -> None: """Logging time for a iteration and update the time flag. Args: runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample]): Outputs from model. - Defaults to None. + outputs (dict or sequence, optional): Outputs from model. Defaults + to None. """ # TODO: update for new logging system runner.log_buffer.update({'time': time.time() - self.t}) diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index 741bfd95..b557c2d0 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -171,18 +171,17 @@ class LoggerHook(Hook): if runner.meta is not None: runner.writer.add_params(runner.meta, file_path=self.yaml_log_path) - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: """Record training logs. Args: runner (Runner): The runner of the training process. data_batch (Sequence[BaseDataSample], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. Defaults to None. """ if runner.meta is not None and 'exp_name' in runner.meta: diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index 3a9dabbc..418bf06d 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -56,11 +56,10 @@ class OptimizerHook(Hook): return clip_grad.clip_grad_norm_(params, **self.grad_clip) return None - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: """All operations need to be finished after each training iteration. This function will finish following 3 operations: @@ -80,7 +79,7 @@ class OptimizerHook(Hook): from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. In order to keep this interface consistent with other hooks, we keep ``outputs`` here. Defaults to None. """ diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index 095b1bb3..a85ef3ac 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -15,11 +15,10 @@ class ParamSchedulerHook(Hook): priority = 'LOW' - def after_train_iter( - self, - runner, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter(self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[dict] = None) -> None: """Call step function for each scheduler after each iteration. Args: @@ -28,7 +27,7 @@ class ParamSchedulerHook(Hook): from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (dict, optional): Outputs from model. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. """ -- GitLab