Commit 80a46c48 authored by Mashiro, committed by GitHub

[Fix] fix build optimizer wrapper without type (#272)

* fix build optimizer wrapper without type

* refine logic

* fix as comment

* fix optim_wrapper config error in docstring and unit test

* refine docstring of build_optim_wrapper
parent 3e3866c1
@@ -89,7 +89,7 @@ class DefaultOptimWrapperConstructor:
     Example 1:
         >>> model = torch.nn.modules.Conv1d(1, 1, 1)
         >>> optim_wrapper_cfg = dict(
-        >>>     dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01,
+        >>>     dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
         >>>         momentum=0.9, weight_decay=0.0001))
         >>> paramwise_cfg = dict(norm_decay_mult=0.)
         >>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
@@ -98,7 +98,7 @@ class DefaultOptimWrapperConstructor:
     Example 2:
         >>> # assume the model has attributes model.backbone and model.cls_head
-        >>> optim_wrapper_cfg = dict(type=OptimWrapper, optimizer=dict(
+        >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
         >>>     type='SGD', lr=0.01, weight_decay=0.95))
         >>> paramwise_cfg = dict(custom_keys={
         >>>     '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
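
For readers skimming the diff, Example 1 can be run end to end roughly as follows. This is a minimal sketch, assuming DefaultOptimWrapperConstructor is importable from mmengine.optim:

    import torch
    from mmengine.optim import DefaultOptimWrapperConstructor

    # Same config as Example 1 above; norm_decay_mult=0. disables weight
    # decay on normalization layers.
    model = torch.nn.Conv1d(1, 1, 1)
    optim_wrapper_cfg = dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
    paramwise_cfg = dict(norm_decay_mult=0.)

    builder = DefaultOptimWrapperConstructor(optim_wrapper_cfg, paramwise_cfg)
    optim_wrapper = builder(model)  # an OptimWrapper wrapping torch SGD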
@@ -818,6 +818,21 @@ class Runner:
     ) -> Union[OptimWrapper, OptimWrapperDict]:
         """Build optimizer wrapper.

+        If ``optim_wrapper`` is a config dict for a single optimizer, it
+        must contain the key ``optimizer``, while ``type`` is optional and
+        defaults to :obj:`OptimWrapper`.
+
+        If ``optim_wrapper`` is a config dict for multiple optimizers, i.e.,
+        it has multiple keys and each key maps to an optimizer wrapper
+        config, a ``constructor`` must be specified, since
+        :obj:`DefaultOptimWrapperConstructor` cannot handle building
+        multiple optimizers.
+
+        If ``optim_wrapper`` is a dict of pre-built optimizer wrappers, i.e.,
+        each value of ``optim_wrapper`` is an ``OptimWrapper`` instance,
+        ``build_optim_wrapper`` will directly build an
+        :obj:`OptimWrapperDict` instance from ``optim_wrapper``.
+
         Args:
             optim_wrapper (OptimWrapper or dict): An OptimWrapper object or a
                 dict to build OptimWrapper objects. If ``optim_wrapper`` is an
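
Concretely, the three accepted forms of ``optim_wrapper`` look like the sketches below. The submodule names ``generator`` and ``discriminator``, the constructor name, and the ``model`` variable are illustrative placeholders, not fixed API names:

    from torch.optim import SGD
    from mmengine.optim import OptimWrapper

    # 1) Single optimizer config; `type` is optional and defaults to
    #    OptimWrapper.
    single = dict(optimizer=dict(type='SGD', lr=0.01))

    # 2) Multiple optimizer configs; a custom constructor is required.
    multiple = dict(
        generator=dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
        discriminator=dict(
            type='OptimWrapper', optimizer=dict(type='Adam', lr=0.001)),
        constructor='CustomMultiOptimizerConstructor')

    # 3) Pre-built wrappers; composed directly into an OptimWrapperDict.
    prebuilt = dict(
        generator=OptimWrapper(SGD(model.generator.parameters(), lr=0.01)),
        discriminator=OptimWrapper(
            SGD(model.discriminator.parameters(), lr=0.01)))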
@@ -840,7 +855,22 @@ class Runner:
                 nesterov: False
                 weight_decay: 0
             )
+            >>> # build optimizer wrapper without `type`
+            >>> optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
+            >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+            >>> optim_wrapper
+            Type: OptimWrapper
+            accumulative_iters: 1
+            optimizer:
+            SGD (
+            Parameter Group 0
+                dampening: 0
+                lr: 0.01
+                maximize: False
+                momentum: 0
+                nesterov: False
+                weight_decay: 0
+            )
             >>> # build multiple optimizers
             >>> optim_wrapper_cfg = dict(
             ...     generator=dict(type='OptimWrapper', optimizer=dict(
@@ -848,7 +878,7 @@ class Runner:
             ...     discriminator=dict(type='OptimWrapper', optimizer=dict(
             ...         type='Adam', lr=0.001)),
             ...     # need to customize a multiple optimizer constructor
-            ...     constructor='CustomizedMultipleOptimizersConstructor',
+            ...     constructor='CustomMultiOptimizerConstructor',
             ... )
             >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
             >>> optim_wrapper
@@ -879,7 +909,7 @@ class Runner:
         Important:
             If you need to build multiple optimizers, you should implement a
-            MultipleOptimizerConstructor which gets parameters passed to
+            MultiOptimWrapperConstructor which gets the parameters passed to
             the corresponding optimizers and composes the ``OptimWrapperDict``.
             More details about how to customize OptimizerConstructor can be
             found at `optimizer-docs`_.
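
To make that "Important" note concrete, a multi-optimizer constructor could look roughly like the following. This is a minimal sketch, not mmengine's shipped implementation; the OPTIM_WRAPPER_CONSTRUCTORS registry import and the name-based submodule dispatch are assumptions:

    # Sketch only: registry path and dispatch strategy are assumptions.
    from mmengine.optim import OptimWrapperDict, build_optim_wrapper
    from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS


    @OPTIM_WRAPPER_CONSTRUCTORS.register_module()
    class MultiOptimWrapperConstructor:

        def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
            # One top-level key per optimizer wrapper (e.g. `generator`,
            # `discriminator`); each value is a single-optimizer config.
            self.optim_wrapper_cfg = optim_wrapper_cfg

        def __call__(self, model):
            optim_wrappers = {}
            for name, cfg in self.optim_wrapper_cfg.items():
                # Route each config to the submodule with the same name.
                submodule = getattr(model, name)
                optim_wrappers[name] = build_optim_wrapper(submodule, cfg)
            return OptimWrapperDict(**optim_wrappers)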
@@ -893,17 +923,36 @@ class Runner:
         if isinstance(optim_wrapper, OptimWrapper):
             return optim_wrapper
         elif isinstance(optim_wrapper, (dict, ConfigDict, Config)):
-            if 'type' not in optim_wrapper and ('constructor'
-                                                not in optim_wrapper):
+            # If `optim_wrapper` is a config dict with only one optimizer,
+            # the config dict must contain `optimizer`:
+            # optim_wrapper = dict(optimizer=dict(type='SGD', lr=0.1))
+            # `type` is optional and defaults to `OptimWrapper`.
+            # `optim_wrapper` could also be defined as:
+            # optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.1))  # noqa: E501
+            # to build a specific optimizer wrapper.
+            if 'type' in optim_wrapper or 'optimizer' in optim_wrapper:
+                optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
+                return optim_wrapper
+            elif 'constructor' not in optim_wrapper:
+                # If `type` and `optimizer` are not defined in
+                # `optim_wrapper`, it should be the case of training with
+                # multiple optimizers. If `constructor` is not defined
+                # either, each value of `optim_wrapper` must be an
+                # `OptimWrapper` instance, since
+                # `DefaultOptimWrapperConstructor` will not handle the case
+                # of training with multiple optimizers. `build_optim_wrapper`
+                # will directly build the `OptimWrapperDict` instance from
+                # `optim_wrapper`.
                 optim_wrappers = OrderedDict()
                 for name, optim in optim_wrapper.items():
                     if not isinstance(optim, OptimWrapper):
                         raise ValueError(
-                            'each item mush be an optimizer object when "type"'
-                            ' and "constructor" are not in optimizer, '
-                            f'but got {name}={optim}')
+                            'each item must be an optimizer object when '
+                            '"type" and "constructor" are not in '
+                            f'optimizer, but got {name}={optim}')
                     optim_wrappers[name] = optim
                 return OptimWrapperDict(**optim_wrappers)
+            # If `constructor` is defined, directly build the optimizer
+            # wrapper instance from the config dict.
             else:
                 optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
                 return optim_wrapper
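
For reference, the pre-built branch above can be exercised as in this sketch, where ``linear1`` follows the submodule name used in the unit test below, ``linear2`` is a hypothetical second submodule, and ``runner`` is an already-built Runner:

    from torch.optim import SGD
    from mmengine.optim import OptimWrapper, OptimWrapperDict

    opt1 = OptimWrapper(SGD(runner.model.linear1.parameters(), lr=0.01))
    opt2 = OptimWrapper(SGD(runner.model.linear2.parameters(), lr=0.02))
    # No `type`, `optimizer`, or `constructor` keys: the values are composed
    # into an OptimWrapperDict instead of being re-built.
    wrappers = runner.build_optim_wrapper(dict(opt1=opt1, opt2=opt2))
    assert isinstance(wrappers, OptimWrapperDict)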
@@ -635,6 +635,11 @@ class TestRunner(TestCase):
             dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))
         self.assertIsInstance(optim_wrapper, OptimWrapper)

+        # 1.3 use default OptimWrapper type.
+        optim_wrapper = runner.build_optim_wrapper(
+            dict(optimizer=dict(type='SGD', lr=0.01)))
+        self.assertIsInstance(optim_wrapper, OptimWrapper)
+
         # 2. test multiple optimizers
         # 2.1 input is a dict which contains multiple optimizer objects
         optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01)
@@ -679,14 +684,16 @@ class TestRunner(TestCase):
         # `build_param_scheduler`
         cfg = dict(type='MultiStepLR', milestones=[1, 2])
         runner.optim_wrapper = dict(
-            key1=dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)),
-            key2=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.02)),
+            key1=dict(
+                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
+            key2=dict(
+                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
         )
         with self.assertRaisesRegex(AssertionError, 'should be called before'):
             runner.build_param_scheduler(cfg)
         runner.optim_wrapper = runner.build_optim_wrapper(
-            dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))
+            dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))
         param_schedulers = runner.build_param_scheduler(cfg)
         self.assertIsInstance(param_schedulers, list)
         self.assertEqual(len(param_schedulers), 1)
@@ -755,7 +762,7 @@ class TestRunner(TestCase):
         # 5. test converting epoch-based scheduler to iter-based
         runner.optim_wrapper = runner.build_optim_wrapper(
-            dict(type=OptimWrapper, optimizer=dict(type='SGD', lr=0.01)))
+            dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))
         # 5.1 train loop should be built before converting scheduler
         cfg = dict(