From e907931fb8df733d416bc934298e5da1d8512ab5 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Sun, 21 Aug 2022 14:54:24 +0800
Subject: [PATCH] Fix unit tests (#449)

---
 mmengine/runner/loops.py          | 3 +--
 tests/test_hooks/test_ema_hook.py | 3 ++-
 tests/test_runner/test_runner.py  | 4 ++--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 7c52a31e..8524269d 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -9,7 +9,6 @@ from torch.utils.data import DataLoader
 
 from mmengine.evaluator import Evaluator
 from mmengine.registry import LOOPS
-from mmengine.utils import is_list_of
 from .amp import autocast
 from .base_loop import BaseLoop
 from .utils import calc_dynamic_intervals
@@ -389,7 +388,7 @@ class TestLoop(BaseLoop):
                  fp16: bool = False):
         super().__init__(runner, dataloader)
 
-        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
+        if isinstance(evaluator, dict) or isinstance(evaluator, list):
             self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
         else:
             self.evaluator = evaluator  # type: ignore
diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py
index 571125bc..f1aeea56 100644
--- a/tests/test_hooks/test_ema_hook.py
+++ b/tests/test_hooks/test_ema_hook.py
@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import Dataset
 
+from mmengine.evaluator import Evaluator
 from mmengine.hooks import EMAHook
 from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
@@ -81,7 +82,7 @@ class TestEMAHook(TestCase):
     def test_ema_hook(self):
         device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
         model = ToyModel1().to(device)
-        evaluator = Mock()
+        evaluator = Evaluator([])
         evaluator.evaluate = Mock(return_value=dict(acc=0.5))
         runner = Runner(
             model=model,
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 79896b51..7499ead2 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -558,10 +558,10 @@ class TestRunner(TestCase):
             param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
             val_cfg=dict(),
             val_dataloader=val_dataloader,
-            val_evaluator=ToyMetric1(),
+            val_evaluator=[ToyMetric1()],
             test_cfg=dict(),
             test_dataloader=test_dataloader,
-            test_evaluator=ToyMetric1(),
+            test_evaluator=[ToyMetric1()],
             default_hooks=dict(param_scheduler=toy_hook),
             custom_hooks=[toy_hook2],
             experiment_name='test_init14')
-- 
GitLab