From a9ad09bded3e573ee2f7811ee7052cfa2cf0b6c7 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Tue, 23 Aug 2022 16:56:47 +0800
Subject: [PATCH] [Fix] Fix utils ut (#458)

---
 mmengine/utils/__init__.py           |  4 ++--
 tests/test_utils/test_progressbar.py | 29 ++++++++++++++--------------
 tests/test_utils/test_timer.py       | 19 +++++++++---------
 3 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py
index f193a22e..8e53cb7c 100644
--- a/mmengine/utils/__init__.py
+++ b/mmengine/utils/__init__.py
@@ -18,7 +18,7 @@ from .progressbar import (ProgressBar, track_iter_progress,
                           track_parallel_progress, track_progress)
 from .setup_env import set_multi_processing
 from .sync_bn import revert_sync_batchnorm
-from .timer import Timer, check_time
+from .timer import Timer, TimerError, check_time
 from .torch_ops import torch_meshgrid
 from .trace import is_jit_tracing
 from .version_utils import digit_version, get_git_hash
@@ -38,7 +38,7 @@ __all__ = [
     'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
     'is_abs', 'is_installed', 'call_command', 'get_installed_path',
     'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env',
-    'Timer', 'check_time', 'ProgressBar', 'track_iter_progress',
+    'Timer', 'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress',
     'track_parallel_progress', 'track_progress', 'torch_meshgrid',
     'is_jit_tracing'
 ]
diff --git a/tests/test_utils/test_progressbar.py b/tests/test_utils/test_progressbar.py
index 982aa247..04e7cb84 100644
--- a/tests/test_utils/test_progressbar.py
+++ b/tests/test_utils/test_progressbar.py
@@ -4,7 +4,7 @@ import time
 from io import StringIO
 from unittest.mock import patch
 
-import mmcv
+import mmengine
 
 
 def reset_string_io(io):
@@ -18,20 +18,21 @@ class TestProgressBar:
         out = StringIO()
         bar_width = 20
         # without total task num
-        prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
+        prog_bar = mmengine.ProgressBar(bar_width=bar_width, file=out)
         assert out.getvalue() == 'completed: 0, elapsed: 0s'
         reset_string_io(out)
-        prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out)
+        prog_bar = mmengine.ProgressBar(
+            bar_width=bar_width, start=False, file=out)
         assert out.getvalue() == ''
         reset_string_io(out)
         prog_bar.start()
         assert out.getvalue() == 'completed: 0, elapsed: 0s'
         # with total task num
         reset_string_io(out)
-        prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
+        prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out)
         assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
         reset_string_io(out)
-        prog_bar = mmcv.ProgressBar(
+        prog_bar = mmengine.ProgressBar(
             10, bar_width=bar_width, start=False, file=out)
         assert out.getvalue() == ''
         reset_string_io(out)
@@ -42,14 +43,14 @@ class TestProgressBar:
         out = StringIO()
         bar_width = 20
         # without total task num
-        prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
+        prog_bar = mmengine.ProgressBar(bar_width=bar_width, file=out)
         time.sleep(1)
         reset_string_io(out)
         prog_bar.update()
         assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
         reset_string_io(out)
         # with total task num
-        prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
+        prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out)
         time.sleep(1)
         reset_string_io(out)
         prog_bar.update()
@@ -60,7 +61,7 @@ class TestProgressBar:
         with patch.dict('os.environ', {'COLUMNS': '80'}):
             out = StringIO()
             bar_width = 20
-            prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
+            prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out)
             time.sleep(1)
             reset_string_io(out)
             prog_bar.update()
@@ -84,7 +85,7 @@ def sleep_1s(num):
 
 def test_track_progress_list():
     out = StringIO()
-    ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
+    ret = mmengine.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
     assert out.getvalue() == (
         '[   ] 0/3, elapsed: 0s, ETA:'
         '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s'
@@ -95,7 +96,7 @@ def test_track_progress_list():
 
 def test_track_progress_iterator():
     out = StringIO()
-    ret = mmcv.track_progress(
+    ret = mmengine.track_progress(
         sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
     assert out.getvalue() == (
         '[   ] 0/3, elapsed: 0s, ETA:'
@@ -108,7 +109,7 @@ def test_track_progress_iterator():
 def test_track_iter_progress():
     out = StringIO()
     ret = []
-    for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out):
+    for num in mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out):
         ret.append(sleep_1s(num))
     assert out.getvalue() == (
         '[   ] 0/3, elapsed: 0s, ETA:'
@@ -123,7 +124,7 @@ def test_track_enum_progress():
     ret = []
     count = []
     for i, num in enumerate(
-            mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
+            mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
         ret.append(sleep_1s(num))
         count.append(i)
     assert out.getvalue() == (
@@ -137,7 +138,7 @@ def test_track_enum_progress():
 
 def test_track_parallel_progress_list():
     out = StringIO()
-    results = mmcv.track_parallel_progress(
+    results = mmengine.track_parallel_progress(
         sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
     # The following cannot pass CI on Github Action
     # assert out.getvalue() == (
@@ -151,7 +152,7 @@ def test_track_parallel_progress_list():
 
 def test_track_parallel_progress_iterator():
     out = StringIO()
-    results = mmcv.track_parallel_progress(
+    results = mmengine.track_parallel_progress(
         sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
     # The following cannot pass CI on Github Action
     # assert out.getvalue() == (
diff --git a/tests/test_utils/test_timer.py b/tests/test_utils/test_timer.py
index e9f59135..4051775a 100644
--- a/tests/test_utils/test_timer.py
+++ b/tests/test_utils/test_timer.py
@@ -1,39 +1,40 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import time
 
-import mmcv
 import pytest
 
+import mmengine
+
 
 def test_timer_init():
-    timer = mmcv.Timer(start=False)
+    timer = mmengine.Timer(start=False)
     assert not timer.is_running
     timer.start()
     assert timer.is_running
-    timer = mmcv.Timer()
+    timer = mmengine.Timer()
     assert timer.is_running
 
 
 def test_timer_run():
-    timer = mmcv.Timer()
+    timer = mmengine.Timer()
     time.sleep(1)
     assert abs(timer.since_start() - 1) < 1e-2
     time.sleep(1)
     assert abs(timer.since_last_check() - 1) < 1e-2
     assert abs(timer.since_start() - 2) < 1e-2
-    timer = mmcv.Timer(False)
-    with pytest.raises(mmcv.TimerError):
+    timer = mmengine.Timer(False)
+    with pytest.raises(mmengine.TimerError):
         timer.since_start()
-    with pytest.raises(mmcv.TimerError):
+    with pytest.raises(mmengine.TimerError):
         timer.since_last_check()
 
 
 def test_timer_context(capsys):
-    with mmcv.Timer():
+    with mmengine.Timer():
         time.sleep(1)
     out, _ = capsys.readouterr()
     assert abs(float(out) - 1) < 1e-2
-    with mmcv.Timer(print_tmpl='time: {:.1f}s'):
+    with mmengine.Timer(print_tmpl='time: {:.1f}s'):
         time.sleep(1)
     out, _ = capsys.readouterr()
     assert out == 'time: 1.0s\n'
-- 
GitLab