From 64b1d183b992699d10da409981239301bd2e372c Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Thu, 3 Mar 2022 19:44:36 +0800
Subject: [PATCH] Add runner unit tests. (#68)

* add runner unit tests

* update

* update

* add test custom loop and hook

* add test model wrapper

* add test setup env

* fix typo

* fix launcher

* fix typo

* test default scope

* add logger test

* fix dataloader

* add test loop

* resolve comments

* resolve comments
---
 tests/test_runner/test_loop.py   |  66 ++++
 tests/test_runner/test_runner.py | 570 +++++++++++++++++++++++++++++++
 2 files changed, 636 insertions(+)
 create mode 100644 tests/test_runner/test_loop.py
 create mode 100644 tests/test_runner/test_runner.py

diff --git a/tests/test_runner/test_loop.py b/tests/test_runner/test_loop.py
new file mode 100644
index 00000000..f4a1e1c3
--- /dev/null
+++ b/tests/test_runner/test_loop.py
@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+from unittest.mock import Mock
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+from mmengine.runner.loop import (EpochBasedTrainLoop, IterBasedTrainLoop,
+                                  TestLoop, ValLoop)
+
+
+class ToyDataset(Dataset):
+    META = dict()  # type: ignore
+    data = np.zeros((30, 1, 1, 1))
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def __getitem__(self, index):
+        return torch.from_numpy(self.data[index])
+
+
+class TestLoops(TestCase):
+
+    def setUp(self) -> None:
+        self.runner = Mock()
+        self.runner.call_hooks = Mock()
+        self.runner.model = Mock()
+        self.runner.epoch = 0
+        self.runner.iter = 0
+        self.runner.inner_iter = 0
+        self.runner.model.train_step = Mock()
+        self.runner.model.val_step = Mock()
+
+        self.evaluator = Mock()
+        self.evaluator.process = Mock()
+        self.evaluator.evaluate = Mock()
+
+    def test_epoch_based_train_loop(self):
+        train_loop = EpochBasedTrainLoop(
+            runner=self.runner, loader=DataLoader(ToyDataset()), max_epoch=3)
+        train_loop.run()
+        assert train_loop.runner.epoch == 3
+        assert train_loop.runner.iter == 90
+
+    def test_iter_based_train_loop(self):
+        train_loop = IterBasedTrainLoop(
+            runner=self.runner, loader=DataLoader(ToyDataset()), max_iter=25)
+        train_loop.run()
+        assert train_loop.runner.epoch == 0
+        assert train_loop.runner.iter == 25
+
+    def test_val_loop(self):
+        val_loop = ValLoop(
+            runner=self.runner,
+            loader=DataLoader(ToyDataset()),
+            evaluator=self.evaluator)
+        val_loop.run()
+
+    def test_test_loop(self):
+        test_loop = TestLoop(
+            runner=self.runner,
+            loader=DataLoader(ToyDataset()),
+            evaluator=self.evaluator)
+        test_loop.run()
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
new file mode 100644
index 00000000..d1bbba89
--- /dev/null
+++ b/tests/test_runner/test_runner.py
@@ -0,0 +1,570 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import logging
+import multiprocessing as mp
+import os
+import os.path as osp
+import platform
+import tempfile
+from unittest import TestCase
+from unittest.mock import patch
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.optim import SGD
+from torch.utils.data import DataLoader, Dataset
+
+from mmengine.config import Config
+from mmengine.evaluator import BaseEvaluator
+from mmengine.hooks import Hook
+from mmengine.logging import MessageHub, MMLogger
+from mmengine.model.wrappers import MMDataParallel, MMDistributedDataParallel
+from mmengine.optim.scheduler import MultiStepLR
+from mmengine.registry import (DATASETS, EVALUATORS, HOOKS, LOOPS,
+                               MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
+                               Registry)
+from mmengine.runner import Runner
+from mmengine.runner.loop import EpochBasedTrainLoop, IterBasedTrainLoop
+
+
+@MODELS.register_module()
+class ToyModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 1, 1)
+        self.conv2 = nn.Conv2d(1, 1, 1)
+
+    def forward(self, x):
+        return self.conv2(F.relu(self.conv1(x)))
+
+    def train_step(self, *inputs, **kwargs):
+        pass
+
+    def val_step(self, *inputs, **kwargs):
+        pass
+
+
+@DATASETS.register_module()
+class ToyDataset(Dataset):
+    META = dict()  # type: ignore
+    data = np.zeros((10, 1, 1, 1))
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def __getitem__(self, index):
+        return torch.from_numpy(self.data[index])
+
+
+@EVALUATORS.register_module()
+class ToyEvaluator(BaseEvaluator):
+
+    def __init__(self, collect_device='cpu', dummy_metrics=None):
+        super().__init__(collect_device=collect_device)
+        self.dummy_metrics = dummy_metrics
+
+    def process(self, data_samples, predictions):
+        result = {'acc': 1}
+        self.results.append(result)
+
+    def compute_metrics(self, results):
+        return dict(acc=1)
+
+
+class TestRunner(TestCase):
+
+    def setUp(self):
+        self.temp_dir = tempfile.gettempdir()
+        full_cfg = dict(
+            model=dict(type='ToyModel'),
+            train_dataloader=dict(
+                dataset=dict(type='ToyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=1,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=dict(type='ToyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=1,
+                num_workers=0),
+            test_dataloader=dict(
+                dataset=dict(type='ToyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=1,
+                num_workers=0),
+            optimizer=dict(type='SGD', lr=0.01),
+            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
+            evaluator=dict(type='ToyEvaluator'),
+            train_cfg=dict(by_epoch=True, max_epochs=3),
+            validation_cfg=dict(interval=1),
+            test_cfg=dict(),
+            custom_hooks=[],
+            default_hooks=dict(
+                timer=dict(type='IterTimerHook'),
+                checkpoint=dict(type='CheckpointHook', interval=1),
+                logger=dict(type='TextLoggerHook'),
+                optimizer=dict(type='OptimizerHook', grad_clip=False),
+                param_scheduler=dict(type='ParamSchedulerHook')),
+            env_cfg=dict(dist_params=dict(backend='nccl'), ),
+            log_cfg=dict(log_level='INFO'),
+            work_dir=self.temp_dir)
+        self.full_cfg = Config(full_cfg)
+
+    def tearDown(self):
+        os.removedirs(self.temp_dir)
+
+    def test_build_from_cfg(self):
+        runner = Runner.build_from_cfg(cfg=self.full_cfg)
+        # test env params
+        assert runner.distributed is False
+        assert runner.seed is not None
+        assert runner.work_dir == self.temp_dir
+
+        # model should be full initialized
+        assert isinstance(runner.model, (nn.Module, MMDataParallel))
+        # lazy init
+        assert isinstance(runner.optimzier, dict)
+        assert isinstance(runner.scheduler, list)
+        assert isinstance(runner.train_dataloader, dict)
+        assert isinstance(runner.val_dataloader, dict)
+        assert isinstance(runner.test_dataloader, dict)
+        assert isinstance(runner.evaluator, dict)
+
+        # after runner.train(), train and val loader should be initialized
+        # test loader should still be config
+        runner.train()
+        assert isinstance(runner.test_dataloader, dict)
+        assert isinstance(runner.train_dataloader, DataLoader)
+        assert isinstance(runner.val_dataloader, DataLoader)
+        assert isinstance(runner.optimzier, SGD)
+        assert isinstance(runner.evaluator, ToyEvaluator)
+
+        runner.test()
+        assert isinstance(runner.test_dataloader, DataLoader)
+
+        # cannot run runner.test() without evaluator cfg
+        with self.assertRaisesRegex(AssertionError,
+                                    'evaluator does not exist'):
+            cfg = copy.deepcopy(self.full_cfg)
+            cfg.pop('evaluator')
+            runner = Runner.build_from_cfg(cfg)
+            runner.test()
+
+        # cannot run runner.train() without optimizer cfg
+        with self.assertRaisesRegex(AssertionError,
+                                    'optimizer does not exist'):
+            cfg = copy.deepcopy(self.full_cfg)
+            cfg.pop('optimizer')
+            runner = Runner.build_from_cfg(cfg)
+            runner.train()
+
+        # can run runner.train() without validation
+        cfg = copy.deepcopy(self.full_cfg)
+        cfg.validation_cfg = None
+        cfg.pop('evaluator')
+        cfg.pop('val_dataloader')
+        runner = Runner.build_from_cfg(cfg)
+        runner.train()
+
+    def test_manually_init(self):
+        model = ToyModel()
+        optimizer = SGD(
+            model.parameters(),
+            lr=0.01,
+        )
+
+        class ToyHook(Hook):
+
+            def before_train_epoch(self, runner):
+                pass
+
+        class ToyHook2(Hook):
+
+            def after_train_epoch(self, runner):
+                pass
+
+        toy_hook = ToyHook()
+        toy_hook2 = ToyHook2()
+        runner = Runner(
+            model=model,
+            train_dataloader=DataLoader(dataset=ToyDataset()),
+            val_dataloader=DataLoader(dataset=ToyDataset()),
+            optimzier=optimizer,
+            param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
+            evaluator=ToyEvaluator(),
+            train_cfg=dict(by_epoch=True, max_epochs=3),
+            validation_cfg=dict(interval=1),
+            default_hooks=dict(param_scheduler=toy_hook),
+            custom_hooks=[toy_hook2])
+        runner.train()
+        hook_names = [hook.__class__.__name__ for hook in runner.hooks]
+        # test custom hook registered in runner
+        assert 'ToyHook2' in hook_names
+        # test default hook is replaced
+        assert 'ToyHook' in hook_names
+        # test other default hooks
+        assert 'IterTimerHook' in hook_names
+
+        # cannot run runner.test() when test_dataloader is None
+        with self.assertRaisesRegex(AssertionError,
+                                    'test dataloader does not exist'):
+            runner.test()
+
+        # cannot run runner.train() when optimizer is None
+        with self.assertRaisesRegex(AssertionError,
+                                    'optimizer does not exist'):
+            runner = Runner(
+                model=model,
+                train_dataloader=DataLoader(dataset=ToyDataset()),
+                val_dataloader=DataLoader(dataset=ToyDataset()),
+                param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
+                evaluator=ToyEvaluator(),
+                train_cfg=dict(by_epoch=True, max_epochs=3),
+                validation_cfg=dict(interval=1))
+            runner.train()
+
+        # cannot run runner.train() when validation_cfg is set
+        # but val loader is None
+        with self.assertRaisesRegex(AssertionError,
+                                    'optimizer does not exist'):
+            runner = Runner(
+                model=model,
+                train_dataloader=DataLoader(dataset=ToyDataset()),
+                optimzier=optimizer,
+                param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
+                train_cfg=dict(by_epoch=True, max_epochs=3),
+                validation_cfg=dict(interval=1))
+            runner.train()
+
+        # run runner.train() without validation
+        runner = Runner(
+            model=model,
+            train_dataloader=DataLoader(dataset=ToyDataset()),
+            optimzier=optimizer,
+            param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
+            train_cfg=dict(by_epoch=True, max_epochs=3),
+            validation_cfg=None)
+        runner.train()
+
+    def test_setup_env(self):
+        # temporarily store system setting
+        sys_start_mehod = mp.get_start_method(allow_none=True)
+        # pop and temp save system env vars
+        sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
+        sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
+
+        # test default multi-processing setting when workers > 1
+        cfg = copy.deepcopy(self.full_cfg)
+        cfg.train_dataloader.num_workers = 4
+        cfg.test_dataloader.num_workers = 4
+        cfg.val_dataloader.num_workers = 4
+        Runner.build_from_cfg(cfg)
+        assert os.getenv('OMP_NUM_THREADS') == '1'
+        assert os.getenv('MKL_NUM_THREADS') == '1'
+        if platform.system() != 'Windows':
+            assert mp.get_start_method() == 'fork'
+
+        # test default multi-processing setting when workers <= 1
+        os.environ.pop('OMP_NUM_THREADS')
+        os.environ.pop('MKL_NUM_THREADS')
+        cfg = copy.deepcopy(self.full_cfg)
+        cfg.train_dataloader.num_workers = 0
+        cfg.test_dataloader.num_workers = 0
+        cfg.val_dataloader.num_workers = 0
+        Runner.build_from_cfg(cfg)
+        assert 'OMP_NUM_THREADS' not in os.environ
+        assert 'MKL_NUM_THREADS' not in os.environ
+
+        # test manually set env var
+        os.environ['OMP_NUM_THREADS'] = '3'
+        cfg = copy.deepcopy(self.full_cfg)
+        cfg.train_dataloader.num_workers = 2
+        cfg.test_dataloader.num_workers = 2
+        cfg.val_dataloader.num_workers = 2
+        Runner.build_from_cfg(cfg)
+        assert os.getenv('OMP_NUM_THREADS') == '3'
+
+        # test manually set mp start method
+        cfg = copy.deepcopy(self.full_cfg)
+        cfg.env_cfg.mp_cfg = dict(mp_start_method='spawn')
+        Runner.build_from_cfg(cfg)
+        assert mp.get_start_method() == 'spawn'
+
+        # revert setting to avoid affecting other programs
+        if sys_start_mehod:
+            mp.set_start_method(sys_start_mehod, force=True)
+        if sys_omp_threads:
+            os.environ['OMP_NUM_THREADS'] = sys_omp_threads
+        else:
+            os.environ.pop('OMP_NUM_THREADS')
+        if sys_mkl_threads:
+            os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
+        else:
+            os.environ.pop('MKL_NUM_THREADS')
+
+    def test_logger(self):
+        runner = Runner.build_from_cfg(self.full_cfg)
+        assert isinstance(runner.logger, MMLogger)
+        # test latest logger and runner logger are the same
+        assert runner.logger.level == logging.INFO
+        assert MMLogger.get_instance(
+        ).instance_name == runner.logger.instance_name
+        # test latest message hub and runner message hub are the same
+        assert isinstance(runner.message_hub, MessageHub)
+        assert MessageHub.get_instance(
+        ).instance_name == runner.message_hub.instance_name
+
+        # test set log level in cfg
+        self.full_cfg.log_cfg.log_level = 'DEBUG'
+        runner = Runner.build_from_cfg(self.full_cfg)
+        assert runner.logger.level == logging.DEBUG
+
+    @patch('torch.distributed.get_rank', lambda: 0)
+    @patch('torch.distributed.is_initialized', lambda: True)
+    @patch('torch.distributed.is_available', lambda: True)
+    def test_model_wrapper(self):
+        # non-distributed model build from config
+        runner = Runner.build_from_cfg(self.full_cfg)
+        assert isinstance(runner.model, MMDataParallel)
+
+        # non-distributed model build manually
+        model = ToyModel()
+        runner = Runner(
+            model=model, train_cfg=dict(by_epoch=True, max_epochs=3))
+        assert isinstance(runner.model, MMDataParallel)
+
+        # distributed model build from config
+        cfg = copy.deepcopy(self.full_cfg)
+        cfg.launcher = 'pytorch'
+        runner = Runner.build_from_cfg(cfg)
+        assert isinstance(runner.model, MMDistributedDataParallel)
+
+        # distributed model build manually
+        model = ToyModel()
+        runner = Runner(
+            model=model,
+            train_cfg=dict(by_epoch=True, max_epochs=3),
+            env_cfg=dict(dist_params=dict(backend='nccl')),
+            launcher='pytorch')
+        assert isinstance(runner.model, MMDistributedDataParallel)
+
+        # custom model wrapper
+        @MODEL_WRAPPERS.register_module()
+        class CustomModelWrapper:
+
+            def train_step(self, *inputs, **kwargs):
+                pass
+
+            def val_step(self, *inputs, **kwargs):
+                pass
+
+        cfg = copy.deepcopy(self.full_cfg)
+        cfg.model_wrapper = dict(type='CustomModelWrapper')
+        runner = Runner.build_from_cfg(cfg)
+        assert isinstance(runner.model, CustomModelWrapper)
+
+    def test_default_scope(self):
+        TOY_SCHEDULERS = Registry(
+            'parameter scheduler', parent=PARAM_SCHEDULERS, scope='toy')
+
+        @TOY_SCHEDULERS.register_module()
+        class ToyScheduler(MultiStepLR):
+
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+
+        self.full_cfg.param_scheduler = dict(
+            type='ToyScheduler', milestones=[1, 2])
+        self.full_cfg.default_scope = 'toy'
+
+        runner = Runner.build_from_cfg(self.full_cfg)
+        runner.train()
+        assert isinstance(runner.scheduler[0], ToyScheduler)
+
+    def test_checkpoint(self):
+        runner = Runner.build_from_cfg(self.full_cfg)
+        runner.run()
+        path = osp.join(self.temp_dir, 'epoch_3.pth')
+        runner.save_checkpoint(path)
+        assert osp.exists(path)
+        ckpt = torch.load(path)
+        # scheduler should saved in the checkpoint
+        assert isinstance(ckpt['scheduler'], list)
+
+        # load by a new runner but not resume
+        runner2 = Runner.build_from_cfg(self.full_cfg)
+        runner2.load_checkpoint(path, resume=False)
+        self.assertNotEqual(runner2.epoch, runner.epoch)
+        self.assertNotEqual(runner2.iter, runner.iter)
+
+        # load by a new runner and resume
+        runner3 = Runner.build_from_cfg(self.full_cfg)
+        runner3.load_checkpoint(path, resume=True)
+        self.assertEqual(runner3.epoch, runner.epoch)
+        self.assertEqual(runner3.iter, runner.iter)
+
+    def test_custom_hooks(self):
+        results = []
+        targets = [0, 1, 2]
+
+        @HOOKS.register_module()
+        class ToyHook(Hook):
+
+            def before_train_epoch(self, runner):
+                results.append(runner.epoch)
+
+        self.full_cfg.custom_hooks = [dict(type='ToyHook', priority=50)]
+        runner = Runner.build_from_cfg(self.full_cfg)
+
+        # test hook registered in runner
+        hook_names = [hook.__class__.__name__ for hook in runner.hooks]
+        assert 'ToyHook' in hook_names
+
+        # test hook behavior
+        runner.train()
+        for result, target, in zip(results, targets):
+            self.assertEqual(result, target)
+
+    def test_iter_based(self):
+        self.full_cfg.train_cfg = dict(by_epoch=False, max_iters=30)
+
+        # test iter and epoch counter of IterBasedTrainLoop
+        epoch_results = []
+        iter_results = []
+        inner_iter_results = []
+        iter_targets = [i for i in range(30)]
+
+        @HOOKS.register_module()
+        class TestIterHook(Hook):
+
+            def before_train_epoch(self, runner):
+                epoch_results.append(runner.epoch)
+
+            def before_train_iter(self, runner):
+                iter_results.append(runner.iter)
+                inner_iter_results.append(runner.inner_iter)
+
+        self.full_cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
+        runner = Runner.build_from_cfg(self.full_cfg)
+
+        assert isinstance(runner._train_loop, IterBasedTrainLoop)
+
+        runner.train()
+
+        self.assertEqual(len(epoch_results), 1)
+        self.assertEqual(epoch_results[0], 0)
+        for result, target, in zip(iter_results, iter_targets):
+            self.assertEqual(result, target)
+        for result, target, in zip(inner_iter_results, iter_targets):
+            self.assertEqual(result, target)
+
+    def test_epoch_based(self):
+        self.full_cfg.train_cfg = dict(by_epoch=True, max_epochs=3)
+
+        # test iter and epoch counter of EpochBasedTrainLoop
+        epoch_results = []
+        epoch_targets = [i for i in range(3)]
+        iter_results = []
+        iter_targets = [i for i in range(10 * 3)]
+        inner_iter_results = []
+        inner_iter_targets = [i for i in range(10)] * 3  # train and val
+
+        @HOOKS.register_module()
+        class TestEpochHook(Hook):
+
+            def before_train_epoch(self, runner):
+                epoch_results.append(runner.epoch)
+
+            def before_train_iter(self, runner, data_batch=None):
+                iter_results.append(runner.iter)
+                inner_iter_results.append(runner.inner_iter)
+
+        self.full_cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
+        runner = Runner.build_from_cfg(self.full_cfg)
+
+        assert isinstance(runner._train_loop, EpochBasedTrainLoop)
+
+        runner.train()
+
+        for result, target, in zip(epoch_results, epoch_targets):
+            self.assertEqual(result, target)
+        for result, target, in zip(iter_results, iter_targets):
+            self.assertEqual(result, target)
+        for result, target, in zip(inner_iter_results, inner_iter_targets):
+            self.assertEqual(result, target)
+
+    def test_custom_loop(self):
+        # test custom loop with additional hook
+        @LOOPS.register_module()
+        class CustomTrainLoop(EpochBasedTrainLoop):
+            """custom train loop with additional warmup stage."""
+
+            def __init__(self, runner, loader, max_epochs, warmup_loader,
+                         max_warmup_iters):
+                super().__init__(
+                    runner=runner, loader=loader, max_epochs=max_epochs)
+                self.warmup_loader = self.runner.build_dataloader(
+                    warmup_loader)
+                self.max_warmup_iters = max_warmup_iters
+
+            def run(self):
+                self.runner.call_hooks('before_run')
+                for idx, data_batch in enumerate(self.warmup_loader):
+                    self.warmup_iter(data_batch)
+                    if idx >= self.max_warmup_iters:
+                        break
+
+                self.runner.call_hooks('before_train_epoch')
+                while self.runner.iter < self._max_iter:
+                    data_batch = next(self.loader)
+                    self.run_iter(data_batch)
+                self.runner.call_hooks('after_train_epoch')
+                self.runner.call_hooks('after_run')
+
+            def warmup_iter(self, data_batch):
+                self.runner.call_hooks(
+                    'before_warmup_iter', args=dict(data_batch=data_batch))
+                outputs = self.runner.model.train_step(data_batch)
+                self.runner.call_hooks(
+                    'after_warmup_iter',
+                    args=dict(data_batch=data_batch, outputs=outputs))
+
+        before_warmup_iter_results = []
+        after_warmup_iter_results = []
+
+        @HOOKS.register_module()
+        class TestWarmupHook(Hook):
+            """test custom train loop."""
+
+            def before_warmup_iter(self, data_batch=None):
+                before_warmup_iter_results.append('before')
+
+            def after_warmup_iter(self, data_batch=None, outputs=None):
+                after_warmup_iter_results.append('after')
+
+        self.full_cfg.train_cfg = dict(
+            type='CustomTrainLoop',
+            max_epochs=3,
+            warmup_loader=dict(
+                dataset=dict(type='ToyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=1,
+                num_workers=0),
+            max_warmup_iters=5)
+        self.full_cfg.custom_hooks = [dict(type='TestWarmupHook', priority=50)]
+        runner = Runner.build_from_cfg(self.full_cfg)
+
+        assert isinstance(runner._train_loop, CustomTrainLoop)
+
+        runner.train()
+
+        # test custom hook triggered normally
+        self.assertEqual(len(before_warmup_iter_results), 5)
+        self.assertEqual(len(after_warmup_iter_results), 5)
+        for before, after in zip(before_warmup_iter_results,
+                                 after_warmup_iter_results):
+            self.assertEqual(before, 'before')
+            self.assertEqual(after, 'after')
-- 
GitLab