From fd8515641284b332a11f3f6f943aeab3f20887e1 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Sat, 5 Mar 2022 17:44:31 +0800
Subject: [PATCH] fix type hint and format (#88)

---
 mmengine/data/base_data_element.py       |  2 +-
 mmengine/data/base_data_sample.py        |  2 +-
 mmengine/evaluator/base.py               | 11 ++++++-----
 mmengine/evaluator/builder.py            |  6 +++++-
 mmengine/evaluator/composed_evaluator.py | 10 +++++-----
 mmengine/hooks/checkpoint_hook.py        |  2 ++
 mmengine/hooks/empty_cache_hook.py       |  2 ++
 mmengine/hooks/hook.py                   |  2 ++
 mmengine/hooks/iter_timer_hook.py        |  2 ++
 mmengine/hooks/optimizer_hook.py         |  2 ++
 mmengine/hooks/param_scheduler_hook.py   |  2 ++
 mmengine/hooks/sampler_seed_hook.py      |  2 ++
 mmengine/hooks/sync_buffer_hook.py       |  2 ++
 mmengine/model/wrappers/utils.py         |  2 +-
 14 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/mmengine/data/base_data_element.py b/mmengine/data/base_data_element.py
index ac5870bc..609ce538 100644
--- a/mmengine/data/base_data_element.py
+++ b/mmengine/data/base_data_element.py
@@ -406,7 +406,7 @@ class BaseDataElement:
 
     # Tensor-like methods
     def numpy(self) -> 'BaseDataElement':
-        """Convert all tensor to np.narray in metainfo and data.""" 
+        """Convert all tensor to np.narray in metainfo and data."""
         new_data = self.new()
         for k, v in self.data_items():
             if isinstance(v, torch.Tensor):
diff --git a/mmengine/data/base_data_sample.py b/mmengine/data/base_data_sample.py
index dd6fafd1..659a4a8b 100644
--- a/mmengine/data/base_data_sample.py
+++ b/mmengine/data/base_data_sample.py
@@ -500,7 +500,7 @@ class BaseDataSample:
 
     # Tensor-like methods
     def numpy(self) -> 'BaseDataSample':
-        """Convert all tensor to np.narray in metainfo and data.""" 
+        """Convert all tensor to np.narray in metainfo and data."""
         new_data = self.new()
         for k, v in self.data_items():
             if isinstance(v, (torch.Tensor, BaseDataElement)):
diff --git a/mmengine/evaluator/base.py b/mmengine/evaluator/base.py
index 287c2fe2..51fbd17e 100644
--- a/mmengine/evaluator/base.py
+++ b/mmengine/evaluator/base.py
@@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
 import torch
 import torch.distributed as dist
 
+from mmengine.data import BaseDataSample
 from mmengine.utils import mkdir_or_exist
 
 
@@ -45,13 +46,13 @@ class BaseEvaluator(metaclass=ABCMeta):
         self._dataset_meta = dataset_meta
 
     @abstractmethod
-    def process(self, data_samples: dict, predictions: dict) -> None:
+    def process(self, data_samples: BaseDataSample, predictions: dict) -> None:
         """Process one batch of data samples and predictions. The processed
         results should be stored in ``self.results``, which will be used to
         compute the metrics when all batches have been processed.
 
         Args:
-            data_samples (dict): The data samples from the dataset.
+            data_samples (BaseDataSample): The data samples from the dataset.
             predictions (dict): The output of the model.
         """
 
@@ -61,6 +62,7 @@ class BaseEvaluator(metaclass=ABCMeta):
 
         Args:
             results (list): The processed results of each batch.
+
         Returns:
             dict: The computed metrics. The keys are the names of the metrics,
             and the values are corresponding results.
@@ -78,9 +80,8 @@ class BaseEvaluator(metaclass=ABCMeta):
             this size.
 
         Returns:
-            metrics (dict): Evaluation metrics dict on the val dataset. The
-                keys are the names of the metrics, and the values are
-                corresponding results.
+            dict: Evaluation metrics dict on the val dataset. The keys are the
+                names of the metrics, and the values are corresponding results.
         """
         if len(self.results) == 0:
             warnings.warn(
diff --git a/mmengine/evaluator/builder.py b/mmengine/evaluator/builder.py
index 710c6554..2a8fb3d8 100644
--- a/mmengine/evaluator/builder.py
+++ b/mmengine/evaluator/builder.py
@@ -1,9 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
 from ..registry import EVALUATORS
+from .base import BaseEvaluator
 from .composed_evaluator import ComposedEvaluator
 
 
-def build_evaluator(cfg: dict) -> object:
+def build_evaluator(
+        cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
     """Build function of evaluator.
 
     When the evaluator config is a list, it will automatically build composed
diff --git a/mmengine/evaluator/composed_evaluator.py b/mmengine/evaluator/composed_evaluator.py
index 225284e7..c0ba27f9 100644
--- a/mmengine/evaluator/composed_evaluator.py
+++ b/mmengine/evaluator/composed_evaluator.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Optional, Sequence, Union
 
+from mmengine.data import BaseDataSample
 from .base import BaseEvaluator
 
 
@@ -31,11 +32,11 @@ class ComposedEvaluator:
         for evaluator in self.evaluators:
             evaluator.dataset_meta = dataset_meta
 
-    def process(self, data_samples: dict, predictions: dict):
+    def process(self, data_samples: BaseDataSample, predictions: dict):
         """Invoke process method of each wrapped evaluator.
 
         Args:
-            data_samples (dict): The data samples from the dataset.
+            data_samples (BaseDataSample): The data samples from the dataset.
             predictions (dict): The output of the model.
         """
 
@@ -54,9 +55,8 @@ class ComposedEvaluator:
             this size.
 
         Returns:
-            metrics (dict): Evaluation metrics of all wrapped evaluators. The
-                keys are the names of the metrics, and the values are
-                corresponding results.
+            dict: Evaluation metrics of all wrapped evaluators. The keys are
+                the names of the metrics, and the values are corresponding results.
         """
         metrics = {}
         for evaluator in self.evaluators:
diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 7baa99ed..14a7ab7b 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -43,6 +43,8 @@ class CheckpointHook(Hook):
             Default: None.
     """
 
+    priority = 'VERY_LOW'
+
     def __init__(self,
                  interval: int = -1,
                  by_epoch: bool = True,
diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py
index b457f2c0..44bf53ec 100644
--- a/mmengine/hooks/empty_cache_hook.py
+++ b/mmengine/hooks/empty_cache_hook.py
@@ -22,6 +22,8 @@ class EmptyCacheHook(Hook):
             Defaults to False.
     """
 
+    priority = 'NORMAL'
+
     def __init__(self,
                  before_epoch: bool = False,
                  after_epoch: bool = True,
diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py
index 8321af83..f0ccb1f7 100644
--- a/mmengine/hooks/hook.py
+++ b/mmengine/hooks/hook.py
@@ -10,6 +10,8 @@ class Hook:
     All hooks should inherit from this class.
     """
 
+    priority = 'NORMAL'
+
     def before_run(self, runner: object) -> None:
         """All subclasses should override this method, if they need any
         operations before the training process.
diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py
index ecc84465..3c637056 100644
--- a/mmengine/hooks/iter_timer_hook.py
+++ b/mmengine/hooks/iter_timer_hook.py
@@ -14,6 +14,8 @@ class IterTimerHook(Hook):
     Eg. ``data_time`` for loading data and ``time`` for a model train step.
     """
 
+    priority = 'NORMAL'
+
     def before_epoch(self, runner: object) -> None:
         """Record time flag before start a epoch.
 
diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py
index 8689b9fa..99f010ab 100644
--- a/mmengine/hooks/optimizer_hook.py
+++ b/mmengine/hooks/optimizer_hook.py
@@ -30,6 +30,8 @@ class OptimizerHook(Hook):
             Defaults to False.
     """
 
+    priority = 'HIGH'
+
     def __init__(self,
                  grad_clip: Optional[dict] = None,
                  detect_anomalous_params: bool = False) -> None:
diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py
index 1bbf610f..425ab123 100644
--- a/mmengine/hooks/param_scheduler_hook.py
+++ b/mmengine/hooks/param_scheduler_hook.py
@@ -11,6 +11,8 @@ class ParamSchedulerHook(Hook):
     """A hook to update some hyper-parameters in optimizer, e.g learning rate
     and momentum."""
 
+    priority = 'LOW'
+
     def after_iter(self,
                    runner: object,
                    data_batch: Optional[Sequence[BaseDataSample]] = None,
diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py
index d2c991a9..6d665172 100644
--- a/mmengine/hooks/sampler_seed_hook.py
+++ b/mmengine/hooks/sampler_seed_hook.py
@@ -12,6 +12,8 @@ class DistSamplerSeedHook(Hook):
     purpose with :obj:`IterLoader`.
     """
 
+    priority = 'NORMAL'
+
     def before_epoch(self, runner: object) -> None:
         """Set the seed for sampler and batch_sampler.
 
diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py
index 89edb55d..f62910e8 100644
--- a/mmengine/hooks/sync_buffer_hook.py
+++ b/mmengine/hooks/sync_buffer_hook.py
@@ -84,6 +84,8 @@ class SyncBuffersHook(Hook):
     """Synchronize model buffers such as running_mean and running_var in BN
     at the end of each epoch."""
 
+    priority = 'NORMAL'
+
     def __init__(self) -> None:
         self.distributed = dist.IS_DIST
 
diff --git a/mmengine/model/wrappers/utils.py b/mmengine/model/wrappers/utils.py
index f888f49c..f952e9b1 100644
--- a/mmengine/model/wrappers/utils.py
+++ b/mmengine/model/wrappers/utils.py
@@ -8,7 +8,7 @@ def is_model_wrapper(model):
 
     The following 4 model in MMEngine (and their subclasses) are regarded as
     model wrappers: DataParallel, DistributedDataParallel,
     MMDataParallel, MMDistributedDataParallel. You may add you own
-    model wrapper by registering it to mmengine.registry.MODEL_WRAPPERS.
+    model wrapper by registering it to ``mmengine.registry.MODEL_WRAPPERS``.
 
     Args:
         model (nn.Module): The model to be checked.
-- 
GitLab
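
Note on the builder change above: the patch tightens ``build_evaluator`` to accept ``Union[dict, list]`` and return ``Union[BaseEvaluator, ComposedEvaluator]``. The following is a minimal sketch of the dispatch that this signature implies, not the MMEngine implementation: the two classes are stand-ins, and the real function builds each entry through the ``EVALUATORS`` registry rather than instantiating classes directly.

# Illustrative sketch only. Stand-in classes and a hand-rolled builder that
# mirrors the Union[dict, list] -> Union[BaseEvaluator, ComposedEvaluator]
# signature introduced by the patch.
from typing import List, Union


class BaseEvaluator:
    """Stand-in for mmengine.evaluator.BaseEvaluator."""

    def __init__(self, **kwargs) -> None:
        self.cfg = kwargs


class ComposedEvaluator:
    """Stand-in for mmengine.evaluator.ComposedEvaluator."""

    def __init__(self, evaluators: List[BaseEvaluator]) -> None:
        self.evaluators = evaluators


def build_evaluator(
        cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
    """Build a single evaluator from a dict, or compose a list of configs."""
    if isinstance(cfg, list):
        # A list config builds every entry and wraps the results together.
        return ComposedEvaluator([build_evaluator(c) for c in cfg])
    # A dict config builds one evaluator; the real implementation would look
    # up cfg['type'] in the EVALUATORS registry here.
    return BaseEvaluator(**{k: v for k, v in cfg.items() if k != 'type'})


if __name__ == '__main__':
    single = build_evaluator(dict(type='Accuracy', topk=1))
    multi = build_evaluator([dict(type='Accuracy'), dict(type='Recall')])
    assert isinstance(single, BaseEvaluator)
    assert isinstance(multi, ComposedEvaluator)

Returning the union type lets callers treat both cases uniformly, since the composed wrapper exposes the same ``process()`` and ``evaluate()`` interface as a single evaluator.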
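
Note on the hook changes above: each hook gains a class-level ``priority`` string ('HIGH', 'NORMAL', 'LOW', 'VERY_LOW'). The sketch below shows one way a runner could order hooks by such strings; the numeric ranks and the ``sort_hooks`` helper are illustrative assumptions, not MMEngine's actual priority table or API.

# Minimal sketch, not MMEngine's runner: ordering hooks by the class-level
# string priority added in the patch. The rank values are assumptions made
# for illustration (lower rank runs earlier).
from typing import List

PRIORITY_RANK = {'HIGH': 30, 'NORMAL': 50, 'LOW': 70, 'VERY_LOW': 90}


class Hook:
    priority = 'NORMAL'


class OptimizerHook(Hook):
    priority = 'HIGH'


class ParamSchedulerHook(Hook):
    priority = 'LOW'


class CheckpointHook(Hook):
    priority = 'VERY_LOW'


def sort_hooks(hooks: List[Hook]) -> List[Hook]:
    """Return hooks sorted so that higher-priority hooks run first."""
    return sorted(hooks, key=lambda h: PRIORITY_RANK[h.priority])


hooks = sort_hooks([CheckpointHook(), ParamSchedulerHook(), OptimizerHook()])
assert [type(h).__name__ for h in hooks] == [
    'OptimizerHook', 'ParamSchedulerHook', 'CheckpointHook'
]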