From ed8dcb4c61dd5faad07bc7bae05f96b38e9af6e5 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Mon, 7 Mar 2022 19:35:37 +0800 Subject: [PATCH] fix type hint in hooks (#106) --- mmengine/hooks/checkpoint_hook.py | 63 ++++++++---------- mmengine/hooks/empty_cache_hook.py | 22 +++--- mmengine/hooks/hook.py | 92 +++++++++++--------------- mmengine/hooks/iter_timer_hook.py | 37 +++++------ mmengine/hooks/optimizer_hook.py | 31 ++++----- mmengine/hooks/param_scheduler_hook.py | 12 ++-- mmengine/hooks/sampler_seed_hook.py | 13 ++-- mmengine/hooks/sync_buffer_hook.py | 6 +- 8 files changed, 122 insertions(+), 154 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index af6c1de3..a312307f 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -9,6 +9,8 @@ from mmengine.fileio import FileClient from mmengine.registry import HOOKS from .hook import Hook +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] + @HOOKS.register_module() class CheckpointHook(Hook): @@ -65,7 +67,7 @@ class CheckpointHook(Hook): self.sync_buffer = sync_buffer self.file_client_args = file_client_args - def before_run(self, runner: object) -> None: + def before_run(self, runner) -> None: """Finish all operations, related to checkpoint. This function will get the appropriate file client, and the directory @@ -75,7 +77,7 @@ class CheckpointHook(Hook): runner (Runner): The runner of the training process. """ if not self.out_dir: - self.out_dir = runner.work_dir # type: ignore + self.out_dir = runner.work_dir self.file_client = FileClient.infer_client(self.file_client_args, self.out_dir) @@ -84,17 +86,13 @@ class CheckpointHook(Hook): # `self.out_dir` is set so the final `self.out_dir` is the # concatenation of `self.out_dir` and the last level directory of # `runner.work_dir` - if self.out_dir != runner.work_dir: # type: ignore - basename = osp.basename( - runner.work_dir.rstrip( # type: ignore - osp.sep)) + if self.out_dir != runner.work_dir: + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) self.out_dir = self.file_client.join_path( - self.out_dir, # type: ignore - basename) + self.out_dir, basename) # type: ignore # noqa: E501 - runner.logger.info(( # type: ignore - f'Checkpoints will be saved to {self.out_dir} by ' - f'{self.file_client.name}.')) + runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by ' + f'{self.file_client.name}.')) # disable the create_symlink option because some file backends do not # allow to create a symlink @@ -109,7 +107,7 @@ class CheckpointHook(Hook): else: self.args['create_symlink'] = self.file_client.allow_symlink - def after_train_epoch(self, runner: object) -> None: + def after_train_epoch(self, runner) -> None: """Save the checkpoint and synchronize buffers after each epoch. Args: @@ -124,46 +122,40 @@ class CheckpointHook(Hook): if self.every_n_epochs( runner, self.interval) or (self.save_last and self.is_last_epoch(runner)): - runner.logger.info( # type: ignore - f'Saving checkpoint at \ - {runner.epoch + 1} epochs') # type: ignore + runner.logger.info(f'Saving checkpoint at \ + {runner.epoch + 1} epochs') if self.sync_buffer: pass # TODO self._save_checkpoint(runner) # TODO Add master_only decorator - def _save_checkpoint(self, runner: object) -> None: + def _save_checkpoint(self, runner) -> None: """Save the current checkpoint and delete outdated checkpoint. Args: runner (Runner): The runner of the training process. """ - runner.save_checkpoint( # type: ignore - self.out_dir, - save_optimizer=self.save_optimizer, - **self.args) - if runner.meta is not None: # type: ignore + runner.save_checkpoint( + self.out_dir, save_optimizer=self.save_optimizer, **self.args) + if runner.meta is not None: if self.by_epoch: cur_ckpt_filename = self.args.get( - 'filename_tmpl', - 'epoch_{}.pth').format(runner.epoch + 1) # type: ignore + 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1) else: cur_ckpt_filename = self.args.get( - 'filename_tmpl', - 'iter_{}.pth').format(runner.iter + 1) # type: ignore - runner.meta.setdefault('hook_msgs', dict()) # type: ignore - runner.meta['hook_msgs'][ # type: ignore - 'last_ckpt'] = self.file_client.join_path( - self.out_dir, cur_ckpt_filename) # type: ignore + 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1) + runner.meta.setdefault('hook_msgs', dict()) + runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path( + self.out_dir, cur_ckpt_filename) # type: ignore # remove other checkpoints if self.max_keep_ckpts > 0: if self.by_epoch: name = 'epoch_{}.pth' - current_ckpt = runner.epoch + 1 # type: ignore + current_ckpt = runner.epoch + 1 else: name = 'iter_{}.pth' - current_ckpt = runner.iter + 1 # type: ignore + current_ckpt = runner.iter + 1 redundant_ckpts = range( current_ckpt - self.max_keep_ckpts * self.interval, 0, -self.interval) @@ -178,8 +170,8 @@ class CheckpointHook(Hook): def after_train_iter( self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, + runner, + data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """Save the checkpoint and synchronize buffers after each iteration. @@ -199,9 +191,8 @@ class CheckpointHook(Hook): if self.every_n_iters( runner, self.interval) or (self.save_last and self.is_last_iter(runner)): - runner.logger.info( # type: ignore - f'Saving checkpoint at \ - {runner.iter + 1} iterations') # type: ignore + runner.logger.info(f'Saving checkpoint at \ + {runner.iter + 1} iterations') if self.sync_buffer: pass # TODO diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index 44bf53ec..2f621131 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Sequence +from typing import Any, Optional, Sequence, Tuple import torch @@ -7,6 +7,8 @@ from mmengine.data import BaseDataSample from mmengine.registry import HOOKS from .hook import Hook +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] + @HOOKS.register_module() class EmptyCacheHook(Hook): @@ -33,35 +35,35 @@ class EmptyCacheHook(Hook): self._after_iter = after_iter def after_iter(self, - runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None, + runner, + data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """Empty cache after an iteration. Args: - runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample]): Data from dataloader. - Defaults to None. + runner (Runner): The runner of the training process. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + from dataloader. Defaults to None. outputs (Sequence[BaseDataSample]): Outputs from model. Defaults to None. """ if self._after_iter: torch.cuda.empty_cache() - def before_epoch(self, runner: object) -> None: + def before_epoch(self, runner) -> None: """Empty cache before an epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ if self._before_epoch: torch.cuda.empty_cache() - def after_epoch(self, runner: object) -> None: + def after_epoch(self, runner) -> None: """Empty cache after an epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ if self._after_epoch: torch.cuda.empty_cache() diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 684196de..f2d52fc9 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -3,6 +3,8 @@ from typing import Any, Optional, Sequence, Tuple from mmengine.data import BaseDataSample +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] + class Hook: """Base hook class. @@ -12,7 +14,7 @@ class Hook: priority = 'NORMAL' - def before_run(self, runner: object) -> None: + def before_run(self, runner) -> None: """All subclasses should override this method, if they need any operations before the training process. @@ -21,7 +23,7 @@ class Hook: """ pass - def after_run(self, runner: object) -> None: + def after_run(self, runner) -> None: """All subclasses should override this method, if they need any operations after the training process. @@ -30,7 +32,7 @@ class Hook: """ pass - def before_epoch(self, runner: object) -> None: + def before_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each epoch. @@ -39,7 +41,7 @@ class Hook: """ pass - def after_epoch(self, runner: object) -> None: + def after_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations after each epoch. @@ -48,11 +50,7 @@ class Hook: """ pass - def before_iter( - self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, - BaseDataSample]]] = None) -> None: + def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each iter. @@ -64,9 +62,8 @@ class Hook: pass def after_iter(self, - runner: object, - data_batch: Optional[Sequence[Tuple[ - Any, BaseDataSample]]] = None, + runner, + data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All subclasses should override this method, if they need any operations after each epoch. @@ -80,7 +77,7 @@ class Hook: """ pass - def before_save_checkpoint(self, runner: object, checkpoint: dict) -> None: + def before_save_checkpoint(self, runner, checkpoint: dict) -> None: """All subclasses should override this method, if they need any operations before saving the checkpoint. @@ -90,7 +87,7 @@ class Hook: """ pass - def after_load_checkpoint(self, runner: object, checkpoint: dict) -> None: + def after_load_checkpoint(self, runner, checkpoint: dict) -> None: """All subclasses should override this method, if they need any operations after loading the checkpoint. @@ -100,7 +97,7 @@ class Hook: """ pass - def before_train_epoch(self, runner: object) -> None: + def before_train_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each training epoch. @@ -109,7 +106,7 @@ class Hook: """ self.before_epoch(runner) - def before_val_epoch(self, runner: object) -> None: + def before_val_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each validation epoch. @@ -118,7 +115,7 @@ class Hook: """ self.before_epoch(runner) - def before_test_epoch(self, runner: object) -> None: + def before_test_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations before each test epoch. @@ -127,7 +124,7 @@ class Hook: """ self.before_epoch(runner) - def after_train_epoch(self, runner: object) -> None: + def after_train_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations after each training epoch. @@ -136,7 +133,7 @@ class Hook: """ self.after_epoch(runner) - def after_val_epoch(self, runner: object) -> None: + def after_val_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations after each validation epoch. @@ -145,7 +142,7 @@ class Hook: """ self.after_epoch(runner) - def after_test_epoch(self, runner: object) -> None: + def after_test_epoch(self, runner) -> None: """All subclasses should override this method, if they need any operations after each test epoch. @@ -154,11 +151,7 @@ class Hook: """ self.after_epoch(runner) - def before_train_iter( - self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, - BaseDataSample]]] = None) -> None: + def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each training iteration. @@ -169,11 +162,7 @@ class Hook: """ self.before_iter(runner, data_batch=None) - def before_val_iter( - self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, - BaseDataSample]]] = None) -> None: + def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each validation iteration. @@ -184,11 +173,7 @@ class Hook: """ self.before_iter(runner, data_batch=None) - def before_test_iter( - self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, - BaseDataSample]]] = None) -> None: + def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """All subclasses should override this method, if they need any operations before each test iteration. @@ -201,8 +186,8 @@ class Hook: def after_train_iter( self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, + runner, + data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All subclasses should override this method, if they need any operations after each training iteration. @@ -218,8 +203,8 @@ class Hook: def after_val_iter( self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, + runner, + data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All subclasses should override this method, if they need any operations after each validation iteration. @@ -235,8 +220,8 @@ class Hook: def after_test_iter( self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, + runner, + data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All subclasses should override this method, if they need any operations after each test iteration. @@ -250,7 +235,7 @@ class Hook: """ self.after_iter(runner, data_batch=None, outputs=None) - def every_n_epochs(self, runner: object, n: int) -> bool: + def every_n_epochs(self, runner, n: int) -> bool: """Test whether or not current epoch can be evenly divided by n. Args: @@ -260,9 +245,9 @@ class Hook: Returns: bool: whether or not current epoch can be evenly divided by n. """ - return (runner.epoch + 1) % n == 0 if n > 0 else False # type: ignore + return (runner.epoch + 1) % n == 0 if n > 0 else False - def every_n_inner_iters(self, runner: object, n: int) -> bool: + def every_n_inner_iters(self, runner, n: int) -> bool: """Test whether or not current inner iteration can be evenly divided by n. @@ -275,10 +260,9 @@ class Hook: bool: whether or not current inner iteration can be evenly divided by n. """ - return (runner.inner_iter + # type: ignore - 1) % n == 0 if n > 0 else False + return (runner.inner_iter + 1) % n == 0 if n > 0 else False - def every_n_iters(self, runner: object, n: int) -> bool: + def every_n_iters(self, runner, n: int) -> bool: """Test whether or not current iteration can be evenly divided by n. Args: @@ -290,9 +274,9 @@ class Hook: bool: Return True if the current iteration can be evenly divided by n, otherwise False. """ - return (runner.iter + 1) % n == 0 if n > 0 else False # type: ignore + return (runner.iter + 1) % n == 0 if n > 0 else False - def end_of_epoch(self, runner: object) -> bool: + def end_of_epoch(self, runner) -> bool: """Check whether the current epoch reaches the `max_epochs` or not. Args: @@ -301,9 +285,9 @@ class Hook: Returns: bool: whether the end of current epoch or not. """ - return runner.inner_iter + 1 == len(runner.data_loader) # type: ignore + return runner.inner_iter + 1 == len(runner.data_loader) - def is_last_epoch(self, runner: object) -> bool: + def is_last_epoch(self, runner) -> bool: """Test whether or not current epoch is the last epoch. Args: @@ -313,9 +297,9 @@ class Hook: bool: bool: Return True if the current epoch reaches the `max_epochs`, otherwise False. """ - return runner.epoch + 1 == runner._max_epochs # type: ignore + return runner.epoch + 1 == runner._max_epochs - def is_last_iter(self, runner: object) -> bool: + def is_last_iter(self, runner) -> bool: """Test whether or not current epoch is the last iteration. Args: @@ -324,4 +308,4 @@ class Hook: Returns: bool: whether or not current iteration is the last iteration. """ - return runner.iter + 1 == runner._max_iters # type: ignore + return runner.iter + 1 == runner._max_iters diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index 3c637056..d1d6404f 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -1,11 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import time -from typing import Optional, Sequence +from typing import Any, Optional, Sequence, Tuple from mmengine.data import BaseDataSample from mmengine.registry import HOOKS from .hook import Hook +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] + @HOOKS.register_module() class IterTimerHook(Hook): @@ -16,45 +18,38 @@ class IterTimerHook(Hook): priority = 'NORMAL' - def before_epoch(self, runner: object) -> None: + def before_epoch(self, runner) -> None: """Record time flag before start a epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ self.t = time.time() - def before_iter( - self, - runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None: """Logging time for loading data and update the time flag. Args: - runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample]): Data from dataloader. - Defaults to None. + runner (Runner): The runner of the training process. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + from dataloader. Defaults to None. """ # TODO: update for new logging system - runner.log_buffer.update({ # type: ignore - 'data_time': time.time() - self.t - }) + runner.log_buffer.update({'data_time': time.time() - self.t}) def after_iter(self, - runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None, + runner, + data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """Logging time for a iteration and update the time flag. Args: - runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample]): Data from dataloader. - Defaults to None. + runner (Runner): The runner of the training process. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + from dataloader. Defaults to None. outputs (Sequence[BaseDataSample]): Outputs from model. Defaults to None. """ # TODO: update for new logging system - runner.log_buffer.update({ # type: ignore - 'time': time.time() - self.t - }) + runner.log_buffer.update({'time': time.time() - self.t}) self.t = time.time() diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index bbccd667..3a9dabbc 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -10,6 +10,8 @@ from mmengine.data import BaseDataSample from mmengine.registry import HOOKS from .hook import Hook +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] + @HOOKS.register_module() class OptimizerHook(Hook): @@ -56,8 +58,8 @@ class OptimizerHook(Hook): def after_train_iter( self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, + runner, + data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All operations need to be finished after each training iteration. @@ -82,32 +84,27 @@ class OptimizerHook(Hook): In order to keep this interface consistent with other hooks, we keep ``outputs`` here. Defaults to None. """ - runner.optimizer.zero_grad() # type: ignore + runner.optimizer.zero_grad() if self.detect_anomalous_params: - self.detect_anomalous_parameters( - runner.outputs['loss'], # type: ignore - runner) - runner.outputs['loss'].backward() # type: ignore + self.detect_anomalous_parameters(runner.outputs['loss'], runner) + runner.outputs['loss'].backward() if self.grad_clip is not None: - grad_norm = self.clip_grads( - runner.model.parameters()) # type: ignore + grad_norm = self.clip_grads(runner.model.parameters()) if grad_norm is not None: # Add grad norm to the logger - runner.log_buffer.update( # type: ignore - {'grad_norm': float(grad_norm)}, - runner.outputs['num_samples']) # type: ignore - runner.optimizer.step() # type: ignore + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() - def detect_anomalous_parameters(self, loss: torch.Tensor, - runner: object) -> None: + def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None: """Detect anomalous parameters that are not included in the graph. Args: loss (torch.Tensor): The loss of current iteration. runner (Runner): The runner of the training process. """ - logger = runner.logger # type: ignore + logger = runner.logger parameters_in_graph = set() visited = set() @@ -125,7 +122,7 @@ class OptimizerHook(Hook): traverse(grad_fn) traverse(loss.grad_fn) - for n, p in runner.model.named_parameters(): # type: ignore + for n, p in runner.model.named_parameters(): if p not in parameters_in_graph and p.requires_grad: logger.log( level=logging.ERROR, diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index fe162d12..095b1bb3 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -5,6 +5,8 @@ from mmengine.data import BaseDataSample from mmengine.registry import HOOKS from .hook import Hook +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] + @HOOKS.register_module() class ParamSchedulerHook(Hook): @@ -15,8 +17,8 @@ class ParamSchedulerHook(Hook): def after_train_iter( self, - runner: object, - data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, + runner, + data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """Call step function for each scheduler after each iteration. @@ -30,16 +32,16 @@ class ParamSchedulerHook(Hook): In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. """ - for scheduler in runner.schedulers: # type: ignore + for scheduler in runner.schedulers: if not scheduler.by_epoch: scheduler.step() - def after_train_epoch(self, runner: object) -> None: + def after_train_epoch(self, runner) -> None: """Call step function for each scheduler after each epoch. Args: runner (Runner): The runner of the training process. """ - for scheduler in runner.schedulers: # type: ignore + for scheduler in runner.schedulers: if scheduler.by_epoch: scheduler.step() diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py index 0896636d..d7d243d8 100644 --- a/mmengine/hooks/sampler_seed_hook.py +++ b/mmengine/hooks/sampler_seed_hook.py @@ -14,18 +14,15 @@ class DistSamplerSeedHook(Hook): priority = 'NORMAL' - def before_epoch(self, runner: object) -> None: + def before_epoch(self, runner) -> None: """Set the seed for sampler and batch_sampler. Args: runner (Runner): The runner of the training process. """ - if hasattr(runner.data_loader.sampler, 'set_epoch'): # type: ignore + if hasattr(runner.data_loader.sampler, 'set_epoch'): # in case the data loader uses `SequentialSampler` in Pytorch - runner.data_loader.sampler.set_epoch(runner.epoch) # type: ignore - elif hasattr( - runner.data_loader.batch_sampler.sampler, # type: ignore - 'set_epoch'): + runner.data_loader.sampler.set_epoch(runner.epoch) + elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'): # batch sampler in pytorch warps the sampler as its attributes. - runner.data_loader.batch_sampler.sampler.set_epoch( # type: ignore - runner.epoch) # type: ignore + runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch) diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py index f62910e8..9aa6402d 100644 --- a/mmengine/hooks/sync_buffer_hook.py +++ b/mmengine/hooks/sync_buffer_hook.py @@ -89,11 +89,11 @@ class SyncBuffersHook(Hook): def __init__(self) -> None: self.distributed = dist.IS_DIST - def after_epoch(self, runner: object) -> None: + def after_epoch(self, runner) -> None: """All-reduce model buffers at the end of each epoch. Args: - runner (object): The runner of the training process. + runner (Runner): The runner of the training process. """ if self.distributed: - allreduce_params(runner.model.buffers()) # type: ignore + allreduce_params(runner.model.buffers()) -- GitLab