From be9971781e00b99bdb482dd594b2b2dd01b1aa51 Mon Sep 17 00:00:00 2001 From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Date: Mon, 7 Mar 2022 14:00:05 +0800 Subject: [PATCH] [Fix]: Change the type of runner in docstring to Runner (#103) * [Fix]: Change after inter and epoch to after train iter and epoch * [Fix]: Add new UT to param scheduler hook * [Fix]: Change the type of runner in docstring to Runner Co-authored-by: Your <you@example.com> --- mmengine/hooks/checkpoint_hook.py | 8 +-- mmengine/hooks/hook.py | 52 ++++++++++---------- mmengine/hooks/optimizer_hook.py | 4 +- mmengine/hooks/param_scheduler_hook.py | 16 +++--- mmengine/hooks/sampler_seed_hook.py | 2 +- tests/test_hook/test_param_scheduler_hook.py | 4 +- 6 files changed, 43 insertions(+), 43 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 89c742f3..af6c1de3 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -72,7 +72,7 @@ class CheckpointHook(Hook): to save these checkpoints of the model. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ if not self.out_dir: self.out_dir = runner.work_dir # type: ignore @@ -113,7 +113,7 @@ class CheckpointHook(Hook): """Save the checkpoint and synchronize buffers after each epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ if not self.by_epoch: return @@ -137,7 +137,7 @@ class CheckpointHook(Hook): """Save the current checkpoint and delete outdated checkpoint. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ runner.save_checkpoint( # type: ignore self.out_dir, @@ -184,7 +184,7 @@ class CheckpointHook(Hook): """Save the checkpoint and synchronize buffers after each iteration. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 2582bc7b..684196de 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -17,7 +17,7 @@ class Hook: operations before the training process. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ pass @@ -26,7 +26,7 @@ class Hook: operations after the training process. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ pass @@ -35,7 +35,7 @@ class Hook: operations before each epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ pass @@ -44,7 +44,7 @@ class Hook: operations after each epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ pass @@ -57,7 +57,7 @@ class Hook: operations before each iter. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ @@ -72,7 +72,7 @@ class Hook: operations after each epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. @@ -85,7 +85,7 @@ class Hook: operations before saving the checkpoint. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. checkpoints (dict): Model's checkpoint. """ pass @@ -95,7 +95,7 @@ class Hook: operations after loading the checkpoint. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. checkpoints (dict): Model's checkpoint. """ pass @@ -105,7 +105,7 @@ class Hook: operations before each training epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ self.before_epoch(runner) @@ -114,7 +114,7 @@ class Hook: operations before each validation epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ self.before_epoch(runner) @@ -123,7 +123,7 @@ class Hook: operations before each test epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ self.before_epoch(runner) @@ -132,7 +132,7 @@ class Hook: operations after each training epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ self.after_epoch(runner) @@ -141,7 +141,7 @@ class Hook: operations after each validation epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ self.after_epoch(runner) @@ -150,7 +150,7 @@ class Hook: operations after each test epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ self.after_epoch(runner) @@ -163,7 +163,7 @@ class Hook: operations before each training iteration. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ @@ -178,7 +178,7 @@ class Hook: operations before each validation iteration. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ @@ -193,7 +193,7 @@ class Hook: operations before each test iteration. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. """ @@ -208,7 +208,7 @@ class Hook: operations after each training iteration. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. @@ -225,7 +225,7 @@ class Hook: operations after each validation iteration. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from @@ -242,7 +242,7 @@ class Hook: operations after each test iteration. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. @@ -254,7 +254,7 @@ class Hook: """Test whether or not current epoch can be evenly divided by n. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. n (int): Whether or not current epoch can be evenly divided by n. Returns: @@ -267,7 +267,7 @@ class Hook: n. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. n (int): Whether or not current inner iteration can be evenly divided by n. @@ -282,7 +282,7 @@ class Hook: """Test whether or not current iteration can be evenly divided by n. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. n (int): Whether or not current iteration can be evenly divided by n. @@ -296,7 +296,7 @@ class Hook: """Check whether the current epoch reaches the `max_epochs` or not. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. Returns: bool: whether the end of current epoch or not. @@ -307,7 +307,7 @@ class Hook: """Test whether or not current epoch is the last epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. Returns: bool: bool: Return True if the current epoch reaches the @@ -319,7 +319,7 @@ class Hook: """Test whether or not current epoch is the last iteration. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. Returns: bool: whether or not current iteration is the last iteration. diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index a22cb1bb..bbccd667 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -73,7 +73,7 @@ class OptimizerHook(Hook): - Update model parameters with gradients. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. @@ -105,7 +105,7 @@ class OptimizerHook(Hook): Args: loss (torch.Tensor): The loss of current iteration. - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ logger = runner.logger # type: ignore parameters_in_graph = set() diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index e99d9251..fe162d12 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -13,15 +13,15 @@ class ParamSchedulerHook(Hook): priority = 'LOW' - def after_iter(self, - runner: object, - data_batch: Optional[Sequence[Tuple[ - Any, BaseDataSample]]] = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + def after_train_iter( + self, + runner: object, + data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """Call step function for each scheduler after each iteration. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. @@ -34,11 +34,11 @@ class ParamSchedulerHook(Hook): if not scheduler.by_epoch: scheduler.step() - def after_epoch(self, runner: object) -> None: + def after_train_epoch(self, runner: object) -> None: """Call step function for each scheduler after each epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ for scheduler in runner.schedulers: # type: ignore if scheduler.by_epoch: diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py index 6d665172..0896636d 100644 --- a/mmengine/hooks/sampler_seed_hook.py +++ b/mmengine/hooks/sampler_seed_hook.py @@ -18,7 +18,7 @@ class DistSamplerSeedHook(Hook): """Set the seed for sampler and batch_sampler. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ if hasattr(runner.data_loader.sampler, 'set_epoch'): # type: ignore # in case the data loader uses `SequentialSampler` in Pytorch diff --git a/tests/test_hook/test_param_scheduler_hook.py b/tests/test_hook/test_param_scheduler_hook.py index 75f12c4a..f944d3ed 100644 --- a/tests/test_hook/test_param_scheduler_hook.py +++ b/tests/test_hook/test_param_scheduler_hook.py @@ -13,7 +13,7 @@ class TestParamSchedulerHook: scheduler.step = Mock() scheduler.by_epoch = False Runner.schedulers = [scheduler] - Hook.after_iter(Runner) + Hook.after_train_iter(Runner) scheduler.step.assert_called() def test_after_epoch(self): @@ -23,5 +23,5 @@ class TestParamSchedulerHook: scheduler.step = Mock() scheduler.by_epoch = True Runner.schedulers = [scheduler] - Hook.after_epoch(Runner) + Hook.after_train_epoch(Runner) scheduler.step.assert_called() -- GitLab