diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 7c52a31eb8fe0923739d210c7d43ba59692d6ecf..8524269dbfcdcd25ff4095e8c82f494d33aeb8f7 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -9,7 +9,6 @@ from torch.utils.data import DataLoader
 
 from mmengine.evaluator import Evaluator
 from mmengine.registry import LOOPS
-from mmengine.utils import is_list_of
 from .amp import autocast
 from .base_loop import BaseLoop
 from .utils import calc_dynamic_intervals
@@ -389,7 +388,7 @@ class TestLoop(BaseLoop):
                  fp16: bool = False):
         super().__init__(runner, dataloader)
 
-        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
+        if isinstance(evaluator, dict) or isinstance(evaluator, list):
             self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
         else:
             self.evaluator = evaluator  # type: ignore
diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py
index 571125bc4e8bd1a08394ce5814e07874528baa53..f1aeea56e385ebc6fc02fa5eb920d5ed547b0f6c 100644
--- a/tests/test_hooks/test_ema_hook.py
+++ b/tests/test_hooks/test_ema_hook.py
@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import Dataset
 
+from mmengine.evaluator import Evaluator
 from mmengine.hooks import EMAHook
 from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
@@ -81,7 +82,7 @@ class TestEMAHook(TestCase):
     def test_ema_hook(self):
         device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
         model = ToyModel1().to(device)
-        evaluator = Mock()
+        evaluator = Evaluator([])
         evaluator.evaluate = Mock(return_value=dict(acc=0.5))
         runner = Runner(
             model=model,
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 79896b5108d595d98016c7f15fb02564aa45c47e..7499ead2ea1db2cfa488e21382044fb4a14869ca 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -558,10 +558,10 @@ class TestRunner(TestCase):
             param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
             val_cfg=dict(),
             val_dataloader=val_dataloader,
-            val_evaluator=ToyMetric1(),
+            val_evaluator=[ToyMetric1()],
             test_cfg=dict(),
             test_dataloader=test_dataloader,
-            test_evaluator=ToyMetric1(),
+            test_evaluator=[ToyMetric1()],
             default_hooks=dict(param_scheduler=toy_hook),
             custom_hooks=[toy_hook2],
             experiment_name='test_init14')
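Why the check changes (an illustrative sketch, not part of the patch): is_list_of(evaluator, dict) only matches a list of config dicts, so a list of already-built metric instances, such as the [ToyMetric1()] value the tests now pass, fell through to the else branch and was stored unwrapped instead of going through runner.build_evaluator. The snippet below demonstrates the two predicates on such a list; DummyMetric is a hypothetical stand-in for a real metric.

from mmengine.evaluator import BaseMetric
from mmengine.utils import is_list_of


class DummyMetric(BaseMetric):
    """Hypothetical metric used only to illustrate the type check."""

    def process(self, data_batch, data_samples):
        pass

    def compute_metrics(self, results):
        return dict(acc=0.5)


evaluator = [DummyMetric()]

# Old predicate: only true for a list of config dicts, so a list of
# metric instances skipped runner.build_evaluator entirely.
print(is_list_of(evaluator, dict))   # False

# New predicate: true for any list, so build_evaluator can wrap the
# metrics into a proper Evaluator.
print(isinstance(evaluator, list))   # True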