Skip to content
Snippets Groups Projects
Unverified Commit ded73f3a authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhance] Enhance compatibility of `revert_sync_batchnorm` (#695)

* [Enhance] Enhance revert_sync_batchnorm and convert_sync_batchnorm

* [Enhance] Enhance revert_sync_batchnorm and convert_sync_batchnorm

* Fix unit test

* Add comments

* Refine comments

* clean the code

* revert convert_sync_batchnorm

* revert convert_sync_batchnorm

* refine comment

* fix CI

* fix CI
parent 9b4dbb31
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import print_log
from mmengine.utils.dl_utils import mmcv_full_available
......@@ -192,7 +193,17 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
if hasattr(module, 'qconfig'):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, revert_sync_batchnorm(child))
# Some custom modules or 3rd party implemented modules may raise an
# error when calling `add_module`. Therefore, try to catch the error
# and do not raise it. See https://github.com/open-mmlab/mmengine/issues/638 # noqa: E501
# for more details.
try:
module_output.add_module(name, revert_sync_batchnorm(child))
except Exception:
print_log(
F'Failed to convert {child} from SyncBN to BN!',
logger='current',
level=logging.WARNING)
del module
return module_output
......
......@@ -15,6 +15,16 @@ from mmengine.registry import MODEL_WRAPPERS, Registry
from mmengine.utils import is_installed
class ToyModule(nn.Module):
    """A toy ``nn.Module`` whose ``add_module`` always fails.

    Used to emulate a custom/3rd-party module that raises an error when a
    child module is registered, so the error-tolerant path of
    ``revert_sync_batchnorm`` can be exercised.
    """

    def __init__(self):
        super().__init__()
        # A single trivial sub-layer so the module is non-empty.
        self.layer1 = nn.Linear(in_features=1, out_features=1)

    def add_module(self, name, module):
        # Reject every child registration unconditionally.
        raise ValueError()
@pytest.mark.skipif(
torch.__version__ == 'parrots', reason='not supported in parrots now')
def test_revert_syncbn():
......@@ -28,6 +38,12 @@ def test_revert_syncbn():
y = conv(x)
assert y.shape == (1, 8, 9, 9)
# TODO, capsys provided by `pytest` cannot capture the error log produced
# by MMLogger. Test the error log after refactoring the unit test with
# `unittest`
conv = nn.Sequential(ToyModule(), nn.SyncBatchNorm(8))
revert_sync_batchnorm(conv)
@pytest.mark.skipif(
torch.__version__ == 'parrots', reason='not supported in parrots now')
......@@ -41,10 +57,12 @@ def test_convert_syncbn():
# Test convert to mmcv SyncBatchNorm
if is_installed('mmcv'):
# MMCV SyncBatchNorm is only supported on distributed training.
# torch 1.6 will throw an AssertionError, and higher versions will
# throw a RuntimeError
with pytest.raises((RuntimeError, AssertionError)):
convert_sync_batchnorm(conv, implementation='mmcv')
# Test convert to Pytorch SyncBatchNorm
# Test convert BN to Pytorch SyncBatchNorm
# Expect a ValueError prompting that SyncBN is not supported on CPU
converted_conv = convert_sync_batchnorm(conv)
assert isinstance(converted_conv[1], torch.nn.SyncBatchNorm)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment