From f98ba606296f207366429dab2237c4a293a20177 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Mon, 15 Aug 2022 10:57:58 +0800 Subject: [PATCH] [Enhancement] Improve unit tests of mmengine/runner (#182) * [Enhancement] Add unit test for get_priority * fix priority ut * fix typo Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> --- mmengine/runner/__init__.py | 4 +++- tests/test_runner/test_priority.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 tests/test_runner/test_priority.py diff --git a/mmengine/runner/__init__.py b/mmengine/runner/__init__.py index fce566d9..801347c3 100644 --- a/mmengine/runner/__init__.py +++ b/mmengine/runner/__init__.py @@ -7,6 +7,7 @@ from .checkpoint import (CheckpointLoader, find_latest_checkpoint, get_torchvision_models, load_checkpoint, load_state_dict, save_checkpoint, weights_to_cpu) from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop +from .priority import Priority, get_priority from .runner import Runner __all__ = [ @@ -14,5 +15,6 @@ __all__ = [ 'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names', 'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict', 'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop', - 'TestLoop', 'Runner', 'find_latest_checkpoint', 'autocast' + 'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint', + 'autocast' ] diff --git a/tests/test_runner/test_priority.py b/tests/test_runner/test_priority.py new file mode 100644 index 00000000..20658978 --- /dev/null +++ b/tests/test_runner/test_priority.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from mmengine.runner import Priority, get_priority + + +def test_get_priority(): + # test `priority` parameter which can be int, str or Priority + # `priority` is an integer + assert get_priority(10) == 10 + # `priority` is an integer but it exceeds the valid ranges + with pytest.raises(ValueError, match='priority must be between 0 and 100'): + get_priority(-1) + with pytest.raises(ValueError, match='priority must be between 0 and 100'): + get_priority(101) + + # `priority` is a Priority enum value + assert get_priority(Priority.HIGHEST) == 0 + assert get_priority(Priority.LOWEST) == 100 + + # `priority` is a string + assert get_priority('HIGHEST') == 0 + assert get_priority('LOWEST') == 100 + + # `priority` is an invalid type + with pytest.raises( + TypeError, + match='priority must be an integer or Priority enum value'): + get_priority([10]) -- GitLab