Skip to content
Snippets Groups Projects
Unverified Commit e907931f authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

Fix unit tests (#449)

parent 429bb279
No related branches found
No related tags found
No related merge requests found
...@@ -9,7 +9,6 @@ from torch.utils.data import DataLoader ...@@ -9,7 +9,6 @@ from torch.utils.data import DataLoader
from mmengine.evaluator import Evaluator from mmengine.evaluator import Evaluator
from mmengine.registry import LOOPS from mmengine.registry import LOOPS
from mmengine.utils import is_list_of
from .amp import autocast from .amp import autocast
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals from .utils import calc_dynamic_intervals
...@@ -389,7 +388,7 @@ class TestLoop(BaseLoop): ...@@ -389,7 +388,7 @@ class TestLoop(BaseLoop):
fp16: bool = False): fp16: bool = False):
super().__init__(runner, dataloader) super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict): if isinstance(evaluator, dict) or isinstance(evaluator, list):
self.evaluator = runner.build_evaluator(evaluator) # type: ignore self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else: else:
self.evaluator = evaluator # type: ignore self.evaluator = evaluator # type: ignore
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import Dataset from torch.utils.data import Dataset
from mmengine.evaluator import Evaluator
from mmengine.hooks import EMAHook from mmengine.hooks import EMAHook
from mmengine.model import BaseModel, ExponentialMovingAverage from mmengine.model import BaseModel, ExponentialMovingAverage
from mmengine.optim import OptimWrapper from mmengine.optim import OptimWrapper
...@@ -81,7 +82,7 @@ class TestEMAHook(TestCase): ...@@ -81,7 +82,7 @@ class TestEMAHook(TestCase):
def test_ema_hook(self): def test_ema_hook(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = ToyModel1().to(device) model = ToyModel1().to(device)
evaluator = Mock() evaluator = Evaluator([])
evaluator.evaluate = Mock(return_value=dict(acc=0.5)) evaluator.evaluate = Mock(return_value=dict(acc=0.5))
runner = Runner( runner = Runner(
model=model, model=model,
......
...@@ -558,10 +558,10 @@ class TestRunner(TestCase): ...@@ -558,10 +558,10 @@ class TestRunner(TestCase):
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]), param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
val_cfg=dict(), val_cfg=dict(),
val_dataloader=val_dataloader, val_dataloader=val_dataloader,
val_evaluator=ToyMetric1(), val_evaluator=[ToyMetric1()],
test_cfg=dict(), test_cfg=dict(),
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
test_evaluator=ToyMetric1(), test_evaluator=[ToyMetric1()],
default_hooks=dict(param_scheduler=toy_hook), default_hooks=dict(param_scheduler=toy_hook),
custom_hooks=[toy_hook2], custom_hooks=[toy_hook2],
experiment_name='test_init14') experiment_name='test_init14')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment