diff --git a/mmengine/testing/__init__.py b/mmengine/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3318b89c34a2dfee68df90953b28d4f359e1f244 --- /dev/null +++ b/mmengine/testing/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .compare import assert_allclose + +__all__ = ['assert_allclose'] diff --git a/mmengine/testing/compare.py b/mmengine/testing/compare.py new file mode 100644 index 0000000000000000000000000000000000000000..abbf22617063814db8ae2ae521b645d798767f76 --- /dev/null +++ b/mmengine/testing/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Callable, Optional, Union + +from torch.testing import assert_allclose as _assert_allclose + +from mmengine.utils import TORCH_VERSION, digit_version + + +def assert_allclose( + actual: Any, + expected: Any, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = True, + msg: Optional[Union[str, Callable]] = '', +) -> None: + """Asserts that ``actual`` and ``expected`` are close. A wrapper function + of ``torch.testing.assert_allclose``. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must + also be specified. If omitted, default values based on the + :attr:`~torch.Tensor.dtype` are selected with the below table. + atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol` + must also be specified. If omitted, default values based on the + :attr:`~torch.Tensor.dtype` are selected with the below table. + equal_nan (bool): If ``True``, two ``NaN`` values will be considered + equal. + msg (Optional[Union[str, Callable]]): Optional error message to use if + the values of corresponding tensors mismatch. Unused when PyTorch + < 1.6. + """ + if 'parrots' not in TORCH_VERSION and \ + digit_version(TORCH_VERSION) >= digit_version('1.6'): + _assert_allclose( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + msg=msg) + else: + # torch.testing.assert_allclose has no ``msg`` argument + # when PyTorch < 1.6 + _assert_allclose( + actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan) diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index 8d53990b1c7abc581a990016fd78673c6a6200c8..d747b6bddb6fbbf3060c44f0eb080a79998052fd 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -5,11 +5,11 @@ from unittest import TestCase import torch import torch.nn.functional as F import torch.optim as optim -from torch.testing import assert_allclose from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR, LinearLR, MultiStepLR, StepLR, _ParamScheduler) +from mmengine.testing import assert_allclose class ToyModel(torch.nn.Module): diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py index 9b144a0acbe60f839347b80d4ba57bb36e1fdadc..fd63a9b941686783a58680f2d2fc70195fdcca94 100644 --- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py +++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py @@ -5,13 +5,13 @@ from unittest import TestCase import torch import torch.nn.functional as F import torch.optim as optim -from torch.testing import assert_allclose from mmengine.optim.scheduler import (ConstantMomentum, CosineAnnealingMomentum, ExponentialMomentum, LinearMomentum, MultiStepMomentum, StepMomentum, _ParamScheduler) +from mmengine.testing import assert_allclose class ToyModel(torch.nn.Module): diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index 8303d4ce969892a777bfd3ccc0ec0d71d4db0040..d1467828e15a74b2a4f42e850b83c9474b17ed77 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -5,7 +5,6 @@ from unittest import TestCase import torch import torch.nn.functional as F import torch.optim as optim -from torch.testing import assert_allclose from mmengine.optim.scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, @@ -13,6 +12,7 @@ from mmengine.optim.scheduler import (ConstantParamScheduler, LinearParamScheduler, MultiStepParamScheduler, StepParamScheduler, _ParamScheduler) +from mmengine.testing import assert_allclose class ToyModel(torch.nn.Module):