# Copyright (c) OpenMMLab. All rights reserved.
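"""Unit tests for ``mmengine.runner.Runner``.

The toy models, datasets, metrics, hooks, loops and log processors defined
below are minimal stand-ins registered into MMEngine's registries so that
``Runner`` can build them from config dicts, just as it would build real
components.
"""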
import copy
import os.path as osp
import shutil
import tempfile
from unittest import TestCase

import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, Dataset
from mmengine.config import Config
from mmengine.data import DefaultSampler
from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, ParamSchedulerHook,
RuntimeInfoHook)
from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
OptimWrapper, OptimWrapperDict, StepLR)
from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
Runner, TestLoop, ValLoop)
from mmengine.runner.priority import Priority, get_priority
from mmengine.utils import is_list_of
from mmengine.visualization import Visualizer
@MODELS.register_module()
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(2, 2)
self.linear2 = nn.Linear(2, 1)
def forward(self, data_batch, return_loss=False):
inputs, labels = [], []
for x in data_batch:
inputs.append(x['inputs'])
labels.append(x['data_sample'])
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
inputs = torch.stack(inputs).to(device)
labels = torch.stack(labels).to(device)
outputs = self.linear1(inputs)
outputs = self.linear2(outputs)
if return_loss:
loss = (labels - outputs).sum()
outputs = dict(loss=loss, log_vars=dict(loss=loss.item()))
return outputs
else:
outputs = dict(log_vars=dict(a=1, b=0.5))
return outputs
@MODELS.register_module()
class ToyModel1(ToyModel):
def __init__(self):
super().__init__()
@MODEL_WRAPPERS.register_module()
class CustomModelWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ToyMultipleOptimizerConstructor:
def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
if not isinstance(optim_wrapper_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict, '
                            f'but got {type(optim_wrapper_cfg)}')
        assert paramwise_cfg is None, (
            'paramwise_cfg should be set in each optimizer separately')
        self.optim_wrapper_cfg = optim_wrapper_cfg
        self.constructors = {}
        for key, cfg in self.optim_wrapper_cfg.items():
            _cfg = cfg.copy()
            paramwise_cfg_ = _cfg.pop('paramwise_cfg', None)
            self.constructors[key] = DefaultOptimWrapperConstructor(
                _cfg, paramwise_cfg_)
def __call__(self, model: nn.Module) -> OptimWrapperDict:
optimizers = {}
while hasattr(model, 'module'):
model = model.module
for key, constructor in self.constructors.items():
module = getattr(model, key)
            optimizers[key] = constructor(module)
        return OptimWrapperDict(**optimizers)
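# The constructor above expects one optim wrapper config per submodule key,
# e.g. (matching its use later in this file):
#   dict(linear1=dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
#        linear2=dict(type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
#        constructor='ToyMultipleOptimizerConstructor')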
@DATASETS.register_module()
class ToyDataset(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
label = torch.ones(12)
    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])
@METRICS.register_module()
class ToyMetric1(BaseMetric):
def __init__(self, collect_device='cpu', dummy_metrics=None):
super().__init__(collect_device=collect_device)
self.dummy_metrics = dummy_metrics
def process(self, data_samples, predictions):
result = {'acc': 1}
self.results.append(result)
def compute_metrics(self, results):
return dict(acc=1)
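# ToyMetric2 mirrors ToyMetric1; registering two metrics lets the tests build
# a composed Evaluator from a list of metric configs.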
@METRICS.register_module()
class ToyMetric2(BaseMetric):
def __init__(self, collect_device='cpu', dummy_metrics=None):
super().__init__(collect_device=collect_device)
self.dummy_metrics = dummy_metrics
def process(self, data_samples, predictions):
result = {'acc': 1}
self.results.append(result)
def compute_metrics(self, results):
return dict(acc=1)
@HOOKS.register_module()
class ToyHook(Hook):
priority = 'Lowest'
def before_train_epoch(self, runner):
pass
@HOOKS.register_module()
class ToyHook2(Hook):
priority = 'Lowest'
def after_train_epoch(self, runner):
pass
@LOOPS.register_module()
class CustomTrainLoop(BaseLoop):
def __init__(self, runner, dataloader, max_epochs):
super().__init__(runner, dataloader)
self._max_epochs = max_epochs
def run(self) -> None:
pass
@LOOPS.register_module()
class CustomValLoop(BaseLoop):
def __init__(self, runner, dataloader, evaluator, interval=1):
super().__init__(runner, dataloader)
self._runner = runner
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
self.evaluator = evaluator
def run(self) -> None:
pass
@LOOPS.register_module()
class CustomTestLoop(BaseLoop):
def __init__(self, runner, dataloader, evaluator):
super().__init__(runner, dataloader)
self._runner = runner
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
self.evaluator = evaluator
def run(self) -> None:
pass
@LOG_PROCESSORS.register_module()
class CustomLogProcessor(LogProcessor):
def __init__(self, window_size=10, by_epoch=True, custom_cfg=None):
self.window_size = window_size
self.by_epoch = by_epoch
self.custom_cfg = custom_cfg if custom_cfg else []
self._check_custom_cfg()
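# Identity collate function: keeps the batch as a plain list of samples so
# that ToyModel.forward can stack the inputs itself.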
def collate_fn(data_batch):
return data_batch
class TestRunner(TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
        epoch_based_cfg = dict(
            model=dict(type='ToyModel'),
            work_dir=self.temp_dir,
            train_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            train_cfg=dict(by_epoch=True, max_epochs=3),
            optim_wrapper=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
val_evaluator=dict(type='ToyMetric1'),
test_evaluator=dict(type='ToyMetric1'),
val_cfg=dict(interval=1, begin=1),
test_cfg=dict(),
custom_hooks=[],
default_hooks=dict(
runtime_info=dict(type='RuntimeInfoHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', interval=1, by_epoch=True),
sampler_seed=dict(type='DistSamplerSeedHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
)
self.epoch_based_cfg = Config(epoch_based_cfg)
self.iter_based_cfg = copy.deepcopy(self.epoch_based_cfg)
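        # iter-based training pairs with an InfiniteSampler: the dataloader
        # never exhausts, and the loop stops at max_iters rather than at
        # epoch boundaries.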
self.iter_based_cfg.train_dataloader = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='InfiniteSampler', shuffle=True),
batch_size=3,
num_workers=0)
self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
self.iter_based_cfg.default_hooks = dict(
runtime_info=dict(type='RuntimeInfoHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None),
logger=dict(type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
sampler_seed=dict(type='DistSamplerSeedHook'))
    def tearDown(self):
        shutil.rmtree(self.temp_dir)
def test_init(self):
# 1. test arguments
# 1.1 train_dataloader, train_cfg, optimizer and param_scheduler
cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_init1'
        cfg.pop('train_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
Runner(**cfg)
# all of training related configs are None and param_scheduler should
# also be None
cfg.experiment_name = 'test_init2'
cfg.pop('param_scheduler')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# all of training related configs are not None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init3'
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# all of training related configs are not None and param_scheduler
# can be None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init4'
cfg.pop('param_scheduler')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
self.assertEqual(runner.param_schedulers, [])
# param_scheduler should be None when optimizer is None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init5'
cfg.pop('train_cfg')
cfg.pop('train_dataloader')
with self.assertRaisesRegex(ValueError, 'should be None'):
runner = Runner(**cfg)
# 1.2 val_dataloader, val_evaluator, val_cfg
cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_init6'
        cfg.pop('val_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
Runner(**cfg)
        cfg.experiment_name = 'test_init7'
        cfg.pop('val_dataloader')
cfg.pop('val_evaluator')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_init8'
        runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# 1.3 test_dataloader, test_evaluator and test_cfg
cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_init9'
        cfg.pop('test_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
runner = Runner(**cfg)
        cfg.experiment_name = 'test_init10'
        cfg.pop('test_dataloader')
cfg.pop('test_evaluator')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# 1.4 test env params
cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_init11'
        runner = Runner(**cfg)
self.assertFalse(runner.distributed)
self.assertFalse(runner.deterministic)
# they are all not specified
cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_init12'
        runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger)
self.assertIsInstance(runner.message_hub, MessageHub)
self.assertIsInstance(runner.visualizer, Visualizer)
# they are all specified
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init13'
cfg.log_level = 'INFO'
runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger)
self.assertIsInstance(runner.message_hub, MessageHub)
self.assertIsInstance(runner.visualizer, Visualizer)
assert runner.distributed is False
assert runner.seed is not None
assert runner.work_dir == self.temp_dir
# 2 model should be initialized
self.assertIsInstance(runner.model,
(nn.Module, DistributedDataParallel))
self.assertEqual(runner.model_name, 'ToyModel')
# 3. test lazy initialization
self.assertIsInstance(runner._train_dataloader, dict)
self.assertIsInstance(runner._val_dataloader, dict)
self.assertIsInstance(runner._test_dataloader, dict)
self.assertIsInstance(runner.optim_wrapper, dict)
self.assertIsInstance(runner.param_schedulers[0], dict)
# After calling runner.train(),
# train_dataloader and val_loader should be initialized but
# test_dataloader should also be dict
        runner.train()
        self.assertIsInstance(runner._train_loop, BaseLoop)
self.assertIsInstance(runner.train_dataloader, DataLoader)
self.assertIsInstance(runner.optim_wrapper, OptimWrapper)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
self.assertIsInstance(runner._val_loop, BaseLoop)
self.assertIsInstance(runner._val_loop.dataloader, DataLoader)
self.assertIsInstance(runner._val_loop.evaluator, Evaluator)
# After calling runner.test(), test_dataloader should be initialized
self.assertIsInstance(runner._test_loop, dict)
        runner.test()
        self.assertIsInstance(runner._test_loop, BaseLoop)
self.assertIsInstance(runner._test_loop.dataloader, DataLoader)
self.assertIsInstance(runner._test_loop.evaluator, Evaluator)
# 4. initialize runner with objects rather than config
        model = ToyModel()
        optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.01))
        train_dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn)
        val_dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn)
        test_dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn)
        runner = Runner(
            model=model,
            work_dir=self.temp_dir,
            train_cfg=dict(by_epoch=True, max_epochs=3),
            train_dataloader=train_dataloader,
            optim_wrapper=optim_wrapper,
            param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
            val_cfg=dict(interval=1, begin=1),
            val_dataloader=val_dataloader,
            val_evaluator=ToyMetric1(),
            test_cfg=dict(),
            test_dataloader=test_dataloader,
            test_evaluator=ToyMetric1(),
            custom_hooks=[ToyHook2()],
            experiment_name='test_init14')
def test_dump_config(self):
# dump config from dict.
cfg = copy.deepcopy(self.epoch_based_cfg)
for idx, cfg in enumerate((cfg, cfg._cfg_dict)):
cfg.experiment_name = f'test_dump{idx}'
runner = Runner.from_cfg(cfg=cfg)
assert osp.exists(
osp.join(runner.work_dir, f'{runner.timestamp}.py'))
# dump config from file.
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix='.py')
file_cfg = Config(
self.epoch_based_cfg._cfg_dict,
filename=temp_config_file.name)
file_cfg.experiment_name = f'test_dump2{idx}'
runner = Runner.from_cfg(cfg=file_cfg)
assert osp.exists(
osp.join(runner.work_dir,
osp.basename(temp_config_file.name)))
def test_from_cfg(self):
        runner = Runner.from_cfg(cfg=self.epoch_based_cfg)
        self.assertIsInstance(runner, Runner)
def test_build_logger(self):
self.epoch_based_cfg.experiment_name = 'test_build_logger1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.logger, MMLogger)
self.assertEqual(runner.experiment_name, runner.logger.instance_name)
# input is a dict
logger = runner.build_logger(name='test_build_logger2')
self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'test_build_logger2')
# input is a dict but does not contain name key
runner._experiment_name = 'test_build_logger3'
logger = runner.build_logger()
self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'test_build_logger3')
    def test_build_message_hub(self):
        self.epoch_based_cfg.experiment_name = 'test_build_message_hub1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.message_hub, MessageHub)
self.assertEqual(runner.message_hub.instance_name,
runner.experiment_name)
# input is a dict
message_hub_cfg = dict(name='test_build_message_hub2')
message_hub = runner.build_message_hub(message_hub_cfg)
self.assertIsInstance(message_hub, MessageHub)
self.assertEqual(message_hub.instance_name, 'test_build_message_hub2')
# input is a dict but does not contain name key
runner._experiment_name = 'test_build_message_hub3'
message_hub_cfg = dict()
message_hub = runner.build_message_hub(message_hub_cfg)
self.assertIsInstance(message_hub, MessageHub)
self.assertEqual(message_hub.instance_name, 'test_build_message_hub3')
# input is not a valid type
with self.assertRaisesRegex(TypeError, 'message_hub should be'):
runner.build_message_hub('invalid-type')
def test_build_visualizer(self):
self.epoch_based_cfg.experiment_name = 'test_build_visualizer1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.visualizer, Visualizer)
self.assertEqual(runner.experiment_name,
runner.visualizer.instance_name)
        # input is a Visualizer object
        self.assertEqual(
            id(runner.build_visualizer(runner.visualizer)),
            id(runner.visualizer))
        # input is a dict
        visualizer_cfg = dict(type='Visualizer', name='test_build_visualizer2')
visualizer = runner.build_visualizer(visualizer_cfg)
self.assertIsInstance(visualizer, Visualizer)
self.assertEqual(visualizer.instance_name, 'test_build_visualizer2')
        # input is None
        runner._experiment_name = 'test_build_visualizer3'
        visualizer_cfg = None
visualizer = runner.build_visualizer(visualizer_cfg)
self.assertIsInstance(visualizer, Visualizer)
self.assertEqual(visualizer.instance_name, 'test_build_visualizer3')
        # input is not a valid type
        with self.assertRaisesRegex(TypeError, 'visualizer should be'):
runner.build_visualizer('invalid-type')
def test_default_scope(self):
TOY_SCHEDULERS = Registry(
'parameter scheduler', parent=PARAM_SCHEDULERS, scope='toy')
@TOY_SCHEDULERS.register_module()
class ToyScheduler(MultiStepLR):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        self.epoch_based_cfg.param_scheduler = dict(
            type='ToyScheduler', milestones=[1, 2])
        self.epoch_based_cfg.default_scope = 'toy'
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_default_scope'
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.param_schedulers[0], ToyScheduler)
    def test_build_model(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_model'
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, ToyModel)
# input should be a nn.Module object or dict
with self.assertRaisesRegex(TypeError, 'model should be'):
runner.build_model('invalid-type')
# input is a nn.Module object
_model = ToyModel1()
model = runner.build_model(_model)
self.assertEqual(id(model), id(_model))
# input is a dict
model = runner.build_model(dict(type='ToyModel1'))
self.assertIsInstance(model, ToyModel1)
# test init weights
@MODELS.register_module()
class ToyModel2(ToyModel):
def __init__(self):
super().__init__()
                self.initialized = False

            def init_weights(self):
                self.initialized = True
model = runner.build_model(dict(type='ToyModel2'))
        self.assertTrue(model.initialized)
# test init weights with model object
_model = ToyModel2()
model = runner.build_model(_model)
        self.assertFalse(model.initialized)
def test_wrap_model(self):
# TODO: test on distributed environment
# custom model wrapper
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_wrap_model'
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
        runner = Runner.from_cfg(cfg)
        self.assertIsInstance(runner.model, CustomModelWrapper)
    def test_build_optim_wrapper(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_optim_wrapper'
        runner = Runner.from_cfg(cfg)
        # input should be an OptimWrapper object or dict
with self.assertRaisesRegex(TypeError, 'optimizer wrapper should be'):
runner.build_optim_wrapper('invalid-type')
# 1. test one optimizer
        # 1.1 input is an OptimWrapper object
optimizer = SGD(runner.model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer)
optim_wrapper = runner.build_optim_wrapper(optim_wrapper)
self.assertEqual(id(optimizer), id(optim_wrapper.optimizer))
        # 1.2 input is a dict
        optim_wrapper = runner.build_optim_wrapper(
dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))
self.assertIsInstance(optim_wrapper, OptimWrapper)
# 2. test multiple optmizers
# 2.1 input is a dict which contains multiple optimizer objects
optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
optim_wrapper1 = OptimWrapper(optimizer1)
optimizer2 = Adam(runner.model.linear2.parameters(), lr=0.02)
optim_wrapper2 = OptimWrapper(optimizer2)
optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper, OptimWrapperDict)
self.assertIsInstance(optim_wrapper['key1'].optimizer, SGD)
self.assertIsInstance(optim_wrapper['key2'].optimizer, Adam)
# 2.2 each item mush be an optimizer object when "type" and
# "constructor" are not in optimizer
optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
optim_wrapper1 = OptimWrapper(optimizer1)
optim_wrapper2 = dict(
type='OptimWrapper', optimizer=dict(type='Adam', lr=0.01))
optim_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
with self.assertRaisesRegex(ValueError,
'each item mush be an optimizer object'):
# 2.3 input is a dict which contains multiple configs
optim_wrapper_cfg = dict(
linear1=dict(
type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
linear2=dict(
type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
constructor='ToyMultipleOptimizerConstructor')
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper, OptimWrapperDict)
self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD)
self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam)
    def test_build_param_scheduler(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_param_scheduler'
runner = Runner.from_cfg(cfg)
# `build_optim_wrapper` should be called before
# `build_param_scheduler`
cfg = dict(type='MultiStepLR', milestones=[1, 2])
        runner.optim_wrapper = dict(
            key1=dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)),
            key2=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.02)))
        with self.assertRaisesRegex(AssertionError, 'should be called before'):
            runner.build_param_scheduler(cfg)
        runner.optim_wrapper = runner.build_optim_wrapper(
            dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))
        param_schedulers = runner.build_param_scheduler(cfg)
self.assertIsInstance(param_schedulers, list)
self.assertEqual(len(param_schedulers), 1)
self.assertIsInstance(param_schedulers[0], MultiStepLR)
# 1. test one optimizer and one parameter scheduler
# 1.1 input is a ParamScheduler object
param_scheduler = MultiStepLR(runner.optim_wrapper, milestones=[1, 2])
param_schedulers = runner.build_param_scheduler(param_scheduler)
self.assertEqual(len(param_schedulers), 1)
self.assertEqual(id(param_schedulers[0]), id(param_scheduler))
# 1.2 input is a dict
        param_schedulers = runner.build_param_scheduler(
            dict(type='MultiStepLR', milestones=[1, 2]))
self.assertEqual(len(param_schedulers), 1)
self.assertIsInstance(param_schedulers[0], MultiStepLR)
# 2. test one optimizer and list of parameter schedulers
# 2.1 input is a list of dict
cfg = [
dict(type='MultiStepLR', milestones=[1, 2]),
dict(type='StepLR', step_size=1)
]
param_schedulers = runner.build_param_scheduler(cfg)
self.assertEqual(len(param_schedulers), 2)
self.assertIsInstance(param_schedulers[0], MultiStepLR)
self.assertIsInstance(param_schedulers[1], StepLR)
# 2.2 input is a list and some items are ParamScheduler objects
cfg = [param_scheduler, dict(type='StepLR', step_size=1)]
param_schedulers = runner.build_param_scheduler(cfg)
self.assertEqual(len(param_schedulers), 2)
self.assertIsInstance(param_schedulers[0], MultiStepLR)
self.assertIsInstance(param_schedulers[1], StepLR)
# 3. test multiple optimizers and list of parameter schedulers
optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
optim_wrapper1 = OptimWrapper(optimizer1)
optimizer2 = Adam(runner.model.linear2.parameters(), lr=0.02)
optim_wrapper2 = OptimWrapper(optimizer2)
optim_wrapper_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2)
runner.optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
cfg = [
dict(type='MultiStepLR', milestones=[1, 2]),
dict(type='StepLR', step_size=1)
]
param_schedulers = runner.build_param_scheduler(cfg)
self.assertIsInstance(param_schedulers, dict)
self.assertEqual(len(param_schedulers), 2)
self.assertEqual(len(param_schedulers['key1']), 2)
self.assertEqual(len(param_schedulers['key2']), 2)
        # 4. test multiple optimizers and multiple parameter schedulers
cfg = dict(
key1=dict(type='MultiStepLR', milestones=[1, 2]),
key2=[
dict(type='MultiStepLR', milestones=[1, 2]),
dict(type='StepLR', step_size=1)
])
param_schedulers = runner.build_param_scheduler(cfg)
self.assertIsInstance(param_schedulers, dict)
self.assertEqual(len(param_schedulers), 2)
self.assertEqual(len(param_schedulers['key1']), 1)
self.assertEqual(len(param_schedulers['key2']), 2)
# 5. test converting epoch-based scheduler to iter-based
runner.optim_wrapper = runner.build_optim_wrapper(
dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))
# 5.1 train loop should be built before converting scheduler
cfg = dict(
type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True)
with self.assertRaisesRegex(
AssertionError,
'Scheduler can only be converted to iter-based when '
                'train loop is built.'):
            runner.build_param_scheduler(cfg)
        # 5.2 convert epoch-based to iter-based scheduler
cfg = dict(
type='MultiStepLR',
milestones=[1, 2],
begin=1,
end=7,
convert_to_iter_based=True)
runner._train_loop = runner.build_train_loop(runner.train_loop)
param_schedulers = runner.build_param_scheduler(cfg)
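        # ToyDataset has 12 samples and the train batch size is 3, so one
        # epoch is 4 iterations; epoch-based begin=1/end=7 convert to
        # iterations 4 and 28.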
self.assertFalse(param_schedulers[0].by_epoch)
self.assertEqual(param_schedulers[0].begin, 4)
self.assertEqual(param_schedulers[0].end, 28)
    def test_build_evaluator(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_evaluator'
runner = Runner.from_cfg(cfg)
        # input is an Evaluator object
evaluator = Evaluator(ToyMetric1())
self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))
evaluator = Evaluator([ToyMetric1(), ToyMetric2()])
self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))
# input is a dict
evaluator = dict(type='ToyMetric1')
self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)
# input is a list of dict
evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')]
self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)
# test collect device
evaluator = [
dict(type='ToyMetric1', collect_device='cpu'),
dict(type='ToyMetric2', collect_device='gpu')
]
_evaluator = runner.build_evaluator(evaluator)
self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')
    def test_build_dataloader(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_dataloader'
runner = Runner.from_cfg(cfg)
cfg = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=1,
num_workers=0)
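        # the seed passed to build_dataloader is forwarded to the sampler,
        # making shuffling reproducible (checked below).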
seed = np.random.randint(2**31)
dataloader = runner.build_dataloader(cfg, seed=seed)
self.assertIsInstance(dataloader, DataLoader)
self.assertIsInstance(dataloader.dataset, ToyDataset)
self.assertIsInstance(dataloader.sampler, DefaultSampler)
self.assertEqual(dataloader.sampler.seed, seed)
    def test_build_train_loop(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_train_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict
with self.assertRaisesRegex(TypeError, 'should be'):
runner.build_train_loop('invalid-type')
# Only one of type or by_epoch can exist in cfg
cfg = dict(type='EpochBasedTrainLoop', by_epoch=True, max_epochs=3)
with self.assertRaisesRegex(RuntimeError, 'Only one'):
runner.build_train_loop(cfg)
# input is a dict and contains type key
cfg = dict(type='EpochBasedTrainLoop', max_epochs=3)
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, EpochBasedTrainLoop)
cfg = dict(type='IterBasedTrainLoop', max_iters=3)
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, IterBasedTrainLoop)
# input is a dict and does not contain type key
cfg = dict(by_epoch=True, max_epochs=3)
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, EpochBasedTrainLoop)
cfg = dict(by_epoch=False, max_iters=3)
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, IterBasedTrainLoop)
# input is a Loop object
self.assertEqual(id(runner.build_train_loop(loop)), id(loop))
# param_schedulers can be []
cfg = dict(type='EpochBasedTrainLoop', max_epochs=3)
runner.param_schedulers = []
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, EpochBasedTrainLoop)
# test custom training loop
cfg = dict(type='CustomTrainLoop', max_epochs=3)
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, CustomTrainLoop)
def test_build_val_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_val_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict
with self.assertRaisesRegex(TypeError, 'should be'):
            runner.build_val_loop('invalid-type')
# input is a dict and contains type key
cfg = dict(type='ValLoop', interval=1, begin=1)
        loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, ValLoop)
# input is a dict but does not contain type key
cfg = dict(interval=1, begin=1)
loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, ValLoop)
# input is a Loop object
self.assertEqual(id(runner.build_val_loop(loop)), id(loop))
# test custom validation loop
cfg = dict(type='CustomValLoop', interval=1)
loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, CustomValLoop)
def test_build_test_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_test_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict
with self.assertRaisesRegex(TypeError, 'should be'):
runner.build_test_loop('invalid-type')
# input is a dict and contains type key
cfg = dict(type='TestLoop')
loop = runner.build_test_loop(cfg)
self.assertIsInstance(loop, TestLoop)
# input is a dict but does not contain type key
cfg = dict()
loop = runner.build_test_loop(cfg)
self.assertIsInstance(loop, TestLoop)
# input is a Loop object
self.assertEqual(id(runner.build_test_loop(loop)), id(loop))
        # test custom test loop
        cfg = dict(type='CustomTestLoop')
        loop = runner.build_test_loop(cfg)
self.assertIsInstance(loop, CustomTestLoop)
def test_build_log_processor(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_log_processor'
runner = Runner.from_cfg(cfg)
# input should be a LogProcessor object or dict
with self.assertRaisesRegex(TypeError, 'should be'):
runner.build_log_processor('invalid-type')
# input is a dict and contains type key
cfg = dict(type='LogProcessor')
log_processor = runner.build_log_processor(cfg)
self.assertIsInstance(log_processor, LogProcessor)
# input is a dict but does not contain type key
cfg = dict()
log_processor = runner.build_log_processor(cfg)
self.assertIsInstance(log_processor, LogProcessor)
# input is a LogProcessor object
self.assertEqual(
id(runner.build_log_processor(log_processor)), id(log_processor))
        # test custom log processor
cfg = dict(type='CustomLogProcessor')
log_processor = runner.build_log_processor(cfg)
self.assertIsInstance(log_processor, CustomLogProcessor)
def test_train(self):
# 1. test `self.train_loop` is None
cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train1'
        cfg.pop('train_dataloader')
        cfg.pop('train_cfg')
        cfg.pop('optim_wrapper')
        cfg.pop('param_scheduler')
        runner = Runner(**cfg)
        with self.assertRaisesRegex(RuntimeError, 'should not be None'):
            runner.train()
# 2. test iter and epoch counter of EpochBasedTrainLoop and timing of
# running ValLoop
        epoch_results = []
        epoch_targets = [i for i in range(3)]
        iter_results = []
        iter_targets = [i for i in range(4 * 3)]
        batch_idx_results = []
        batch_idx_targets = [i for i in range(4)] * 3  # train and val
        val_epoch_results = []
        val_epoch_targets = [i for i in range(2, 4)]

        @HOOKS.register_module()
        class TestEpochHook(Hook):

            def before_train_epoch(self, runner):
                epoch_results.append(runner.epoch)

            def before_train_iter(self, runner, batch_idx, data_batch=None):
                iter_results.append(runner.iter)
                batch_idx_results.append(batch_idx)

            def before_val_epoch(self, runner):
                val_epoch_results.append(runner.epoch)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train2'
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
cfg.val_cfg = dict(begin=2)
        runner = Runner.from_cfg(cfg)
        runner.train()

        assert isinstance(runner.train_loop, EpochBasedTrainLoop)