Skip to content
Snippets Groups Projects
Unverified Commit e8ee1926 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Enhancement] Improve revert_sync_batchnorm to support mmcv SyncBN (#448)

parent e907931f
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,15 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
"""
module_output = module
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
try:
import mmcv
except ImportError:
pass
else:
if hasattr(mmcv, 'ops'):
module_checklist.append(mmcv.ops.SyncBatchNorm)
if isinstance(module, tuple(module_checklist)):
module_output = _BatchNormXd(module.num_features, module.eps,
module.momentum, module.affine,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment