From f98ba606296f207366429dab2237c4a293a20177 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Mon, 15 Aug 2022 10:57:58 +0800
Subject: [PATCH] [Enhancement] Improve unit tests of mmengine/runner (#182)

* [Enhancement] Add unit test for get_priority

* fix priority ut

* fix typo

Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
---
 mmengine/runner/__init__.py        |  4 +++-
 tests/test_runner/test_priority.py | 29 +++++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_runner/test_priority.py

diff --git a/mmengine/runner/__init__.py b/mmengine/runner/__init__.py
index fce566d9..801347c3 100644
--- a/mmengine/runner/__init__.py
+++ b/mmengine/runner/__init__.py
@@ -7,6 +7,7 @@ from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
                          get_torchvision_models, load_checkpoint,
                          load_state_dict, save_checkpoint, weights_to_cpu)
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
+from .priority import Priority, get_priority
 from .runner import Runner
 
 __all__ = [
@@ -14,5 +15,6 @@ __all__ = [
     'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
     'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
     'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
-    'TestLoop', 'Runner', 'find_latest_checkpoint', 'autocast'
+    'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint',
+    'autocast'
 ]
diff --git a/tests/test_runner/test_priority.py b/tests/test_runner/test_priority.py
new file mode 100644
index 00000000..20658978
--- /dev/null
+++ b/tests/test_runner/test_priority.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+
+from mmengine.runner import Priority, get_priority
+
+
+def test_get_priority():
+    # test `priority` parameter which can be int, str or Priority
+    # `priority` is an integer
+    assert get_priority(10) == 10
+    # `priority` is an integer but it exceeds the valid ranges
+    with pytest.raises(ValueError, match='priority must be between 0 and 100'):
+        get_priority(-1)
+    with pytest.raises(ValueError, match='priority must be between 0 and 100'):
+        get_priority(101)
+
+    # `priority` is a Priority enum value
+    assert get_priority(Priority.HIGHEST) == 0
+    assert get_priority(Priority.LOWEST) == 100
+
+    # `priority` is a string
+    assert get_priority('HIGHEST') == 0
+    assert get_priority('LOWEST') == 100
+
+    # `priority` is an invalid type
+    with pytest.raises(
+            TypeError,
+            match='priority must be an integer or Priority enum value'):
+        get_priority([10])
-- 
GitLab