Skip to content
Snippets Groups Projects
Unverified Commit f98ba606 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Enhancement] Improve unit tests of mmengine/runner (#182)


* [Enhancement] Add unit test for get_priority

* fix priority ut

* fix typo

Co-authored-by: default avatarWenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
parent 2708b7ed
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,7 @@ from .checkpoint import (CheckpointLoader, find_latest_checkpoint, ...@@ -7,6 +7,7 @@ from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
get_torchvision_models, load_checkpoint, get_torchvision_models, load_checkpoint,
load_state_dict, save_checkpoint, weights_to_cpu) load_state_dict, save_checkpoint, weights_to_cpu)
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
from .runner import Runner from .runner import Runner
__all__ = [ __all__ = [
...@@ -14,5 +15,6 @@ __all__ = [ ...@@ -14,5 +15,6 @@ __all__ = [
'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names', 'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict', 'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop', 'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
'TestLoop', 'Runner', 'find_latest_checkpoint', 'autocast' 'TestLoop', 'Runner', 'get_priority', 'Priority', 'find_latest_checkpoint',
'autocast'
] ]
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
from mmengine.runner import Priority, get_priority
def test_get_priority():
# test `priority` parameter which can be int, str or Priority
# `priority` is an integer
assert get_priority(10) == 10
# `priority` is an integer but it exceeds the valid ranges
with pytest.raises(ValueError, match='priority must be between 0 and 100'):
get_priority(-1)
with pytest.raises(ValueError, match='priority must be between 0 and 100'):
get_priority(101)
# `priority` is a Priority enum value
assert get_priority(Priority.HIGHEST) == 0
assert get_priority(Priority.LOWEST) == 100
# `priority` is a string
assert get_priority('HIGHEST') == 0
assert get_priority('LOWEST') == 100
# `priority` is an invalid type
with pytest.raises(
TypeError,
match='priority must be an integer or Priority enum value'):
get_priority([10])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment