Skip to content
Snippets Groups Projects
Unverified Commit a9ad09bd authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Fix] Fix utils ut (#458)

parent 6c607bd2
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@ from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .setup_env import set_multi_processing
from .sync_bn import revert_sync_batchnorm
from .timer import Timer, check_time
from .timer import Timer, TimerError, check_time
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing
from .version_utils import digit_version, get_git_hash
......@@ -38,7 +38,7 @@ __all__ = [
'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
'is_abs', 'is_installed', 'call_command', 'get_installed_path',
'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env',
'Timer', 'check_time', 'ProgressBar', 'track_iter_progress',
'Timer', 'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress', 'torch_meshgrid',
'is_jit_tracing'
]
......@@ -4,7 +4,7 @@ import time
from io import StringIO
from unittest.mock import patch
import mmcv
import mmengine
def reset_string_io(io):
......@@ -18,20 +18,21 @@ class TestProgressBar:
out = StringIO()
bar_width = 20
# without total task num
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
prog_bar = mmengine.ProgressBar(bar_width=bar_width, file=out)
assert out.getvalue() == 'completed: 0, elapsed: 0s'
reset_string_io(out)
prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out)
prog_bar = mmengine.ProgressBar(
bar_width=bar_width, start=False, file=out)
assert out.getvalue() == ''
reset_string_io(out)
prog_bar.start()
assert out.getvalue() == 'completed: 0, elapsed: 0s'
# with total task num
reset_string_io(out)
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out)
assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
reset_string_io(out)
prog_bar = mmcv.ProgressBar(
prog_bar = mmengine.ProgressBar(
10, bar_width=bar_width, start=False, file=out)
assert out.getvalue() == ''
reset_string_io(out)
......@@ -42,14 +43,14 @@ class TestProgressBar:
out = StringIO()
bar_width = 20
# without total task num
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
prog_bar = mmengine.ProgressBar(bar_width=bar_width, file=out)
time.sleep(1)
reset_string_io(out)
prog_bar.update()
assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
reset_string_io(out)
# with total task num
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out)
time.sleep(1)
reset_string_io(out)
prog_bar.update()
......@@ -60,7 +61,7 @@ class TestProgressBar:
with patch.dict('os.environ', {'COLUMNS': '80'}):
out = StringIO()
bar_width = 20
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out)
time.sleep(1)
reset_string_io(out)
prog_bar.update()
......@@ -84,7 +85,7 @@ def sleep_1s(num):
def test_track_progress_list():
out = StringIO()
ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
ret = mmengine.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
......@@ -95,7 +96,7 @@ def test_track_progress_list():
def test_track_progress_iterator():
out = StringIO()
ret = mmcv.track_progress(
ret = mmengine.track_progress(
sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
......@@ -108,7 +109,7 @@ def test_track_progress_iterator():
def test_track_iter_progress():
out = StringIO()
ret = []
for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out):
for num in mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out):
ret.append(sleep_1s(num))
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
......@@ -123,7 +124,7 @@ def test_track_enum_progress():
ret = []
count = []
for i, num in enumerate(
mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
ret.append(sleep_1s(num))
count.append(i)
assert out.getvalue() == (
......@@ -137,7 +138,7 @@ def test_track_enum_progress():
def test_track_parallel_progress_list():
out = StringIO()
results = mmcv.track_parallel_progress(
results = mmengine.track_parallel_progress(
sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
# The following cannot pass CI on Github Action
# assert out.getvalue() == (
......@@ -151,7 +152,7 @@ def test_track_parallel_progress_list():
def test_track_parallel_progress_iterator():
out = StringIO()
results = mmcv.track_parallel_progress(
results = mmengine.track_parallel_progress(
sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
# The following cannot pass CI on Github Action
# assert out.getvalue() == (
......
# Copyright (c) OpenMMLab. All rights reserved.
import time
import mmcv
import pytest
import mmengine
def test_timer_init():
timer = mmcv.Timer(start=False)
timer = mmengine.Timer(start=False)
assert not timer.is_running
timer.start()
assert timer.is_running
timer = mmcv.Timer()
timer = mmengine.Timer()
assert timer.is_running
def test_timer_run():
timer = mmcv.Timer()
timer = mmengine.Timer()
time.sleep(1)
assert abs(timer.since_start() - 1) < 1e-2
time.sleep(1)
assert abs(timer.since_last_check() - 1) < 1e-2
assert abs(timer.since_start() - 2) < 1e-2
timer = mmcv.Timer(False)
with pytest.raises(mmcv.TimerError):
timer = mmengine.Timer(False)
with pytest.raises(mmengine.TimerError):
timer.since_start()
with pytest.raises(mmcv.TimerError):
with pytest.raises(mmengine.TimerError):
timer.since_last_check()
def test_timer_context(capsys):
with mmcv.Timer():
with mmengine.Timer():
time.sleep(1)
out, _ = capsys.readouterr()
assert abs(float(out) - 1) < 1e-2
with mmcv.Timer(print_tmpl='time: {:.1f}s'):
with mmengine.Timer(print_tmpl='time: {:.1f}s'):
time.sleep(1)
out, _ = capsys.readouterr()
assert out == 'time: 1.0s\n'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment