From a9ad09bded3e573ee2f7811ee7052cfa2cf0b6c7 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Tue, 23 Aug 2022 16:56:47 +0800 Subject: [PATCH] [Fix] Fix utils ut (#458) --- mmengine/utils/__init__.py | 4 ++-- tests/test_utils/test_progressbar.py | 29 ++++++++++++++-------------- tests/test_utils/test_timer.py | 19 +++++++++--------- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index f193a22e..8e53cb7c 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -18,7 +18,7 @@ from .progressbar import (ProgressBar, track_iter_progress, track_parallel_progress, track_progress) from .setup_env import set_multi_processing from .sync_bn import revert_sync_batchnorm -from .timer import Timer, check_time +from .timer import Timer, TimerError, check_time from .torch_ops import torch_meshgrid from .trace import is_jit_tracing from .version_utils import digit_version, get_git_hash @@ -38,7 +38,7 @@ __all__ = [ 'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm', 'is_abs', 'is_installed', 'call_command', 'get_installed_path', 'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env', - 'Timer', 'check_time', 'ProgressBar', 'track_iter_progress', + 'Timer', 'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress', 'track_parallel_progress', 'track_progress', 'torch_meshgrid', 'is_jit_tracing' ] diff --git a/tests/test_utils/test_progressbar.py b/tests/test_utils/test_progressbar.py index 982aa247..04e7cb84 100644 --- a/tests/test_utils/test_progressbar.py +++ b/tests/test_utils/test_progressbar.py @@ -4,7 +4,7 @@ import time from io import StringIO from unittest.mock import patch -import mmcv +import mmengine def reset_string_io(io): @@ -18,20 +18,21 @@ class TestProgressBar: out = StringIO() bar_width = 20 # without total task num - prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out) + prog_bar = mmengine.ProgressBar(bar_width=bar_width, file=out) assert out.getvalue() == 'completed: 0, elapsed: 0s' reset_string_io(out) - prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out) + prog_bar = mmengine.ProgressBar( + bar_width=bar_width, start=False, file=out) assert out.getvalue() == '' reset_string_io(out) prog_bar.start() assert out.getvalue() == 'completed: 0, elapsed: 0s' # with total task num reset_string_io(out) - prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) + prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out) assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' reset_string_io(out) - prog_bar = mmcv.ProgressBar( + prog_bar = mmengine.ProgressBar( 10, bar_width=bar_width, start=False, file=out) assert out.getvalue() == '' reset_string_io(out) @@ -42,14 +43,14 @@ class TestProgressBar: out = StringIO() bar_width = 20 # without total task num - prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out) + prog_bar = mmengine.ProgressBar(bar_width=bar_width, file=out) time.sleep(1) reset_string_io(out) prog_bar.update() assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s' reset_string_io(out) # with total task num - prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) + prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out) time.sleep(1) reset_string_io(out) prog_bar.update() @@ -60,7 +61,7 @@ class TestProgressBar: with patch.dict('os.environ', {'COLUMNS': '80'}): out = StringIO() bar_width = 20 - prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) + prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out) time.sleep(1) reset_string_io(out) prog_bar.update() @@ -84,7 +85,7 @@ def sleep_1s(num): def test_track_progress_list(): out = StringIO() - ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out) + ret = mmengine.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out) assert out.getvalue() == ( '[ ] 0/3, elapsed: 0s, ETA:' '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' @@ -95,7 +96,7 @@ def test_track_progress_list(): def test_track_progress_iterator(): out = StringIO() - ret = mmcv.track_progress( + ret = mmengine.track_progress( sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out) assert out.getvalue() == ( '[ ] 0/3, elapsed: 0s, ETA:' @@ -108,7 +109,7 @@ def test_track_progress_iterator(): def test_track_iter_progress(): out = StringIO() ret = [] - for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out): + for num in mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out): ret.append(sleep_1s(num)) assert out.getvalue() == ( '[ ] 0/3, elapsed: 0s, ETA:' @@ -123,7 +124,7 @@ def test_track_enum_progress(): ret = [] count = [] for i, num in enumerate( - mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)): + mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out)): ret.append(sleep_1s(num)) count.append(i) assert out.getvalue() == ( @@ -137,7 +138,7 @@ def test_track_enum_progress(): def test_track_parallel_progress_list(): out = StringIO() - results = mmcv.track_parallel_progress( + results = mmengine.track_parallel_progress( sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out) # The following cannot pass CI on Github Action # assert out.getvalue() == ( @@ -151,7 +152,7 @@ def test_track_parallel_progress_list(): def test_track_parallel_progress_iterator(): out = StringIO() - results = mmcv.track_parallel_progress( + results = mmengine.track_parallel_progress( sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out) # The following cannot pass CI on Github Action # assert out.getvalue() == ( diff --git a/tests/test_utils/test_timer.py b/tests/test_utils/test_timer.py index e9f59135..4051775a 100644 --- a/tests/test_utils/test_timer.py +++ b/tests/test_utils/test_timer.py @@ -1,39 +1,40 @@ # Copyright (c) OpenMMLab. All rights reserved. import time -import mmcv import pytest +import mmengine + def test_timer_init(): - timer = mmcv.Timer(start=False) + timer = mmengine.Timer(start=False) assert not timer.is_running timer.start() assert timer.is_running - timer = mmcv.Timer() + timer = mmengine.Timer() assert timer.is_running def test_timer_run(): - timer = mmcv.Timer() + timer = mmengine.Timer() time.sleep(1) assert abs(timer.since_start() - 1) < 1e-2 time.sleep(1) assert abs(timer.since_last_check() - 1) < 1e-2 assert abs(timer.since_start() - 2) < 1e-2 - timer = mmcv.Timer(False) - with pytest.raises(mmcv.TimerError): + timer = mmengine.Timer(False) + with pytest.raises(mmengine.TimerError): timer.since_start() - with pytest.raises(mmcv.TimerError): + with pytest.raises(mmengine.TimerError): timer.since_last_check() def test_timer_context(capsys): - with mmcv.Timer(): + with mmengine.Timer(): time.sleep(1) out, _ = capsys.readouterr() assert abs(float(out) - 1) < 1e-2 - with mmcv.Timer(print_tmpl='time: {:.1f}s'): + with mmengine.Timer(print_tmpl='time: {:.1f}s'): time.sleep(1) out, _ = capsys.readouterr() assert out == 'time: 1.0s\n' -- GitLab