Skip to content
Snippets Groups Projects
Unverified Commit 64b1d183 authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

Add runner unit tests. (#68)

* add runner unit tests

* update

* update

* add test custom loop and hook

* add test model wrapper

* add test setup env

* fix typo

* fix launcher

* fix typo

* test default scope

* add logger test

* fix dataloader

* add test loop

* resolve comments

* resolve comments
parent c87adc66
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from mmengine.runner.loop import (EpochBasedTrainLoop, IterBasedTrainLoop,
TestLoop, ValLoop)
class ToyDataset(Dataset):
META = dict() # type: ignore
data = np.zeros((30, 1, 1, 1))
def __len__(self):
return self.data.shape[0]
def __getitem__(self, index):
return torch.from_numpy(self.data[index])
class TestLoops(TestCase):
def setUp(self) -> None:
self.runner = Mock()
self.runner.call_hooks = Mock()
self.runner.model = Mock()
self.runner.epoch = 0
self.runner.iter = 0
self.runner.inner_iter = 0
self.runner.model.train_step = Mock()
self.runner.model.val_step = Mock()
self.evaluator = Mock()
self.evaluator.process = Mock()
self.evaluator.evaluate = Mock()
def test_epoch_based_train_loop(self):
train_loop = EpochBasedTrainLoop(
runner=self.runner, loader=DataLoader(ToyDataset()), max_epoch=3)
train_loop.run()
assert train_loop.runner.epoch == 3
assert train_loop.runner.iter == 90
def test_iter_based_train_loop(self):
train_loop = IterBasedTrainLoop(
runner=self.runner, loader=DataLoader(ToyDataset()), max_iter=25)
train_loop.run()
assert train_loop.runner.epoch == 0
assert train_loop.runner.iter == 25
def test_val_loop(self):
val_loop = ValLoop(
runner=self.runner,
loader=DataLoader(ToyDataset()),
evaluator=self.evaluator)
val_loop.run()
def test_test_loop(self):
test_loop = TestLoop(
runner=self.runner,
loader=DataLoader(ToyDataset()),
evaluator=self.evaluator)
test_loop.run()
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment