From fd8515641284b332a11f3f6f943aeab3f20887e1 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Sat, 5 Mar 2022 17:44:31 +0800
Subject: [PATCH] fix type hint and format (#88)

---
 mmengine/data/base_data_element.py       |  2 +-
 mmengine/data/base_data_sample.py        |  2 +-
 mmengine/evaluator/base.py               | 11 ++++++-----
 mmengine/evaluator/builder.py            |  6 +++++-
 mmengine/evaluator/composed_evaluator.py | 10 +++++-----
 mmengine/hooks/checkpoint_hook.py        |  2 ++
 mmengine/hooks/empty_cache_hook.py       |  2 ++
 mmengine/hooks/hook.py                   |  2 ++
 mmengine/hooks/iter_timer_hook.py        |  2 ++
 mmengine/hooks/optimizer_hook.py         |  2 ++
 mmengine/hooks/param_scheduler_hook.py   |  2 ++
 mmengine/hooks/sampler_seed_hook.py      |  2 ++
 mmengine/hooks/sync_buffer_hook.py       |  2 ++
 mmengine/model/wrappers/utils.py         |  2 +-
 14 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/mmengine/data/base_data_element.py b/mmengine/data/base_data_element.py
index ac5870bc..609ce538 100644
--- a/mmengine/data/base_data_element.py
+++ b/mmengine/data/base_data_element.py
@@ -406,7 +406,7 @@ class BaseDataElement:
 
     # Tensor-like methods
     def numpy(self) -> 'BaseDataElement':
-        """Convert all tensor to np.narray in metainfo and data."""
+        """Convert all tensor to np.narray in metainfo and data."""
         new_data = self.new()
         for k, v in self.data_items():
             if isinstance(v, torch.Tensor):
diff --git a/mmengine/data/base_data_sample.py b/mmengine/data/base_data_sample.py
index dd6fafd1..659a4a8b 100644
--- a/mmengine/data/base_data_sample.py
+++ b/mmengine/data/base_data_sample.py
@@ -500,7 +500,7 @@ class BaseDataSample:
 
     # Tensor-like methods
     def numpy(self) -> 'BaseDataSample':
-        """Convert all tensor to np.narray in metainfo and data."""
+        """Convert all tensor to np.narray in metainfo and data."""
         new_data = self.new()
         for k, v in self.data_items():
             if isinstance(v, (torch.Tensor, BaseDataElement)):
diff --git a/mmengine/evaluator/base.py b/mmengine/evaluator/base.py
index 287c2fe2..51fbd17e 100644
--- a/mmengine/evaluator/base.py
+++ b/mmengine/evaluator/base.py
@@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
 import torch
 import torch.distributed as dist
 
+from mmengine.data import BaseDataSample
 from mmengine.utils import mkdir_or_exist
 
 
@@ -45,13 +46,13 @@ class BaseEvaluator(metaclass=ABCMeta):
         self._dataset_meta = dataset_meta
 
     @abstractmethod
-    def process(self, data_samples: dict, predictions: dict) -> None:
+    def process(self, data_samples: BaseDataSample, predictions: dict) -> None:
         """Process one batch of data samples and predictions. The processed
         results should be stored in ``self.results``, which will be used to
         compute the metrics when all batches have been processed.
 
         Args:
-            data_samples (dict): The data samples from the dataset.
+            data_samples (BaseDataSample): The data samples from the dataset.
             predictions (dict): The output of the model.
         """
 
@@ -61,6 +62,7 @@ class BaseEvaluator(metaclass=ABCMeta):
 
         Args:
             results (list): The processed results of each batch.
+
         Returns:
             dict: The computed metrics. The keys are the names of the metrics,
             and the values are corresponding results.
@@ -78,9 +80,8 @@ class BaseEvaluator(metaclass=ABCMeta):
                 this size.
 
         Returns:
-            metrics (dict): Evaluation metrics dict on the val dataset. The
-            keys are the names of the metrics, and the values are
-            corresponding results.
+            dict: Evaluation metrics dict on the val dataset. The keys are the
+            names of the metrics, and the values are corresponding results.
         """
         if len(self.results) == 0:
             warnings.warn(
diff --git a/mmengine/evaluator/builder.py b/mmengine/evaluator/builder.py
index 710c6554..2a8fb3d8 100644
--- a/mmengine/evaluator/builder.py
+++ b/mmengine/evaluator/builder.py
@@ -1,9 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
 from ..registry import EVALUATORS
+from .base import BaseEvaluator
 from .composed_evaluator import ComposedEvaluator
 
 
-def build_evaluator(cfg: dict) -> object:
+def build_evaluator(
+        cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
     """Build function of evaluator.
 
     When the evaluator config is a list, it will automatically build composed
diff --git a/mmengine/evaluator/composed_evaluator.py b/mmengine/evaluator/composed_evaluator.py
index 225284e7..c0ba27f9 100644
--- a/mmengine/evaluator/composed_evaluator.py
+++ b/mmengine/evaluator/composed_evaluator.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Optional, Sequence, Union
 
+from mmengine.data import BaseDataSample
 from .base import BaseEvaluator
 
 
@@ -31,11 +32,11 @@ class ComposedEvaluator:
         for evaluator in self.evaluators:
             evaluator.dataset_meta = dataset_meta
 
-    def process(self, data_samples: dict, predictions: dict):
+    def process(self, data_samples: BaseDataSample, predictions: dict):
         """Invoke process method of each wrapped evaluator.
 
         Args:
-            data_samples (dict): The data samples from the dataset.
+            data_samples (BaseDataSample): The data samples from the dataset.
             predictions (dict): The output of the model.
         """
 
@@ -54,9 +55,8 @@ class ComposedEvaluator:
                 this size.
 
         Returns:
-            metrics (dict): Evaluation metrics of all wrapped evaluators. The
-            keys are the names of the metrics, and the values are
-            corresponding results.
+            dict: Evaluation metrics of all wrapped evaluators. The keys are
+            the names of the metrics, and the values are corresponding results.
         """
         metrics = {}
         for evaluator in self.evaluators:
diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 7baa99ed..14a7ab7b 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -43,6 +43,8 @@ class CheckpointHook(Hook):
             Default: None.
     """
 
+    priority = 'VERY_LOW'
+
     def __init__(self,
                  interval: int = -1,
                  by_epoch: bool = True,
diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py
index b457f2c0..44bf53ec 100644
--- a/mmengine/hooks/empty_cache_hook.py
+++ b/mmengine/hooks/empty_cache_hook.py
@@ -22,6 +22,8 @@ class EmptyCacheHook(Hook):
             Defaults to False.
     """
 
+    priority = 'NORMAL'
+
     def __init__(self,
                  before_epoch: bool = False,
                  after_epoch: bool = True,
diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py
index 8321af83..f0ccb1f7 100644
--- a/mmengine/hooks/hook.py
+++ b/mmengine/hooks/hook.py
@@ -10,6 +10,8 @@ class Hook:
     All hooks should inherit from this class.
     """
 
+    priority = 'NORMAL'
+
     def before_run(self, runner: object) -> None:
         """All subclasses should override this method, if they need any
         operations before the training process.
diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py
index ecc84465..3c637056 100644
--- a/mmengine/hooks/iter_timer_hook.py
+++ b/mmengine/hooks/iter_timer_hook.py
@@ -14,6 +14,8 @@ class IterTimerHook(Hook):
     Eg. ``data_time`` for loading data and ``time`` for a model train step.
     """
 
+    priority = 'NORMAL'
+
     def before_epoch(self, runner: object) -> None:
         """Record time flag before start a epoch.
 
diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py
index 8689b9fa..99f010ab 100644
--- a/mmengine/hooks/optimizer_hook.py
+++ b/mmengine/hooks/optimizer_hook.py
@@ -30,6 +30,8 @@ class OptimizerHook(Hook):
             Defaults to False.
     """
 
+    priority = 'HIGH'
+
     def __init__(self,
                  grad_clip: Optional[dict] = None,
                  detect_anomalous_params: bool = False) -> None:
diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py
index 1bbf610f..425ab123 100644
--- a/mmengine/hooks/param_scheduler_hook.py
+++ b/mmengine/hooks/param_scheduler_hook.py
@@ -11,6 +11,8 @@ class ParamSchedulerHook(Hook):
     """A hook to update some hyper-parameters in optimizer, e.g learning rate
     and momentum."""
 
+    priority = 'LOW'
+
     def after_iter(self,
                    runner: object,
                    data_batch: Optional[Sequence[BaseDataSample]] = None,
diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py
index d2c991a9..6d665172 100644
--- a/mmengine/hooks/sampler_seed_hook.py
+++ b/mmengine/hooks/sampler_seed_hook.py
@@ -12,6 +12,8 @@ class DistSamplerSeedHook(Hook):
     purpose with :obj:`IterLoader`.
     """
 
+    priority = 'NORMAL'
+
     def before_epoch(self, runner: object) -> None:
         """Set the seed for sampler and batch_sampler.
 
diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py
index 89edb55d..f62910e8 100644
--- a/mmengine/hooks/sync_buffer_hook.py
+++ b/mmengine/hooks/sync_buffer_hook.py
@@ -84,6 +84,8 @@ class SyncBuffersHook(Hook):
     """Synchronize model buffers such as running_mean and running_var in BN at
     the end of each epoch."""
 
+    priority = 'NORMAL'
+
     def __init__(self) -> None:
         self.distributed = dist.IS_DIST
 
diff --git a/mmengine/model/wrappers/utils.py b/mmengine/model/wrappers/utils.py
index f888f49c..f952e9b1 100644
--- a/mmengine/model/wrappers/utils.py
+++ b/mmengine/model/wrappers/utils.py
@@ -8,7 +8,7 @@ def is_model_wrapper(model):
     The following 4 model in MMEngine (and their subclasses) are regarded as
     model wrappers: DataParallel, DistributedDataParallel,
     MMDataParallel, MMDistributedDataParallel. You may add you own
-    model wrapper by registering it to mmengine.registry.MODEL_WRAPPERS.
+    model wrapper by registering it to ``mmengine.registry.MODEL_WRAPPERS``.
 
     Args:
         model (nn.Module): The model to be checked.
-- 
GitLab