From e907931fb8df733d416bc934298e5da1d8512ab5 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Sun, 21 Aug 2022 14:54:24 +0800 Subject: [PATCH] Fix unit tests (#449) --- mmengine/runner/loops.py | 3 +-- tests/test_hooks/test_ema_hook.py | 3 ++- tests/test_runner/test_runner.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 7c52a31e..8524269d 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -9,7 +9,6 @@ from torch.utils.data import DataLoader from mmengine.evaluator import Evaluator from mmengine.registry import LOOPS -from mmengine.utils import is_list_of from .amp import autocast from .base_loop import BaseLoop from .utils import calc_dynamic_intervals @@ -389,7 +388,7 @@ class TestLoop(BaseLoop): fp16: bool = False): super().__init__(runner, dataloader) - if isinstance(evaluator, dict) or is_list_of(evaluator, dict): + if isinstance(evaluator, dict) or isinstance(evaluator, list): self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator # type: ignore diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py index 571125bc..f1aeea56 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn from torch.utils.data import Dataset +from mmengine.evaluator import Evaluator from mmengine.hooks import EMAHook from mmengine.model import BaseModel, ExponentialMovingAverage from mmengine.optim import OptimWrapper @@ -81,7 +82,7 @@ class TestEMAHook(TestCase): def test_ema_hook(self): device = 'cuda:0' if torch.cuda.is_available() else 'cpu' model = ToyModel1().to(device) - evaluator = Mock() + evaluator = Evaluator([]) evaluator.evaluate = Mock(return_value=dict(acc=0.5)) runner = Runner( model=model, diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 79896b51..7499ead2 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -558,10 +558,10 @@ class TestRunner(TestCase): param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]), val_cfg=dict(), val_dataloader=val_dataloader, - val_evaluator=ToyMetric1(), + val_evaluator=[ToyMetric1()], test_cfg=dict(), test_dataloader=test_dataloader, - test_evaluator=ToyMetric1(), + test_evaluator=[ToyMetric1()], default_hooks=dict(param_scheduler=toy_hook), custom_hooks=[toy_hook2], experiment_name='test_init14') -- GitLab