From c2c5664fad0ce88e3f661ae23b04008550a7c4e5 Mon Sep 17 00:00:00 2001 From: RangiLyu <lyuchqi@gmail.com> Date: Tue, 1 Mar 2022 11:28:21 +0800 Subject: [PATCH] Fix pt1.5 unit tests. (#65) * Fix pt1.5 unit tests. * move to mmengine.testing --- mmengine/testing/__init__.py | 4 ++ mmengine/testing/compare.py | 48 +++++++++++++++++++ .../test_scheduler/test_lr_scheduler.py | 2 +- .../test_scheduler/test_momentum_scheduler.py | 2 +- .../test_scheduler/test_param_scheduler.py | 2 +- 5 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 mmengine/testing/__init__.py create mode 100644 mmengine/testing/compare.py diff --git a/mmengine/testing/__init__.py b/mmengine/testing/__init__.py new file mode 100644 index 00000000..3318b89c --- /dev/null +++ b/mmengine/testing/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .compare import assert_allclose + +__all__ = ['assert_allclose'] diff --git a/mmengine/testing/compare.py b/mmengine/testing/compare.py new file mode 100644 index 00000000..abbf2261 --- /dev/null +++ b/mmengine/testing/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Callable, Optional, Union + +from torch.testing import assert_allclose as _assert_allclose + +from mmengine.utils import TORCH_VERSION, digit_version + + +def assert_allclose( + actual: Any, + expected: Any, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = True, + msg: Optional[Union[str, Callable]] = '', +) -> None: + """Asserts that ``actual`` and ``expected`` are close. A wrapper function + of ``torch.testing.assert_allclose``. + + Args: + actual (Any): Actual input. + expected (Any): Expected input. + rtol (Optional[float]): Relative tolerance. If specified ``atol`` must + also be specified. If omitted, default values based on the + :attr:`~torch.Tensor.dtype` are selected with the below table. + atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol` + must also be specified. If omitted, default values based on the + :attr:`~torch.Tensor.dtype` are selected with the below table. + equal_nan (bool): If ``True``, two ``NaN`` values will be considered + equal. + msg (Optional[Union[str, Callable]]): Optional error message to use if + the values of corresponding tensors mismatch. Unused when PyTorch + < 1.6. + """ + if 'parrots' not in TORCH_VERSION and \ + digit_version(TORCH_VERSION) >= digit_version('1.6'): + _assert_allclose( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + msg=msg) + else: + # torch.testing.assert_allclose has no ``msg`` argument + # when PyTorch < 1.6 + _assert_allclose( + actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan) diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index 8d53990b..d747b6bd 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -5,11 +5,11 @@ from unittest import TestCase import torch import torch.nn.functional as F import torch.optim as optim -from torch.testing import assert_allclose from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR, LinearLR, MultiStepLR, StepLR, _ParamScheduler) +from mmengine.testing import assert_allclose class ToyModel(torch.nn.Module): diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py index 9b144a0a..fd63a9b9 100644 --- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py +++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py @@ -5,13 +5,13 @@ from unittest import TestCase import torch import torch.nn.functional as F import torch.optim as optim -from torch.testing import assert_allclose from mmengine.optim.scheduler import (ConstantMomentum, CosineAnnealingMomentum, ExponentialMomentum, LinearMomentum, MultiStepMomentum, StepMomentum, _ParamScheduler) +from mmengine.testing import assert_allclose class ToyModel(torch.nn.Module): diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index 8303d4ce..d1467828 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -5,7 +5,6 @@ from unittest import TestCase import torch import torch.nn.functional as F import torch.optim as optim -from torch.testing import assert_allclose from mmengine.optim.scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, @@ -13,6 +12,7 @@ from mmengine.optim.scheduler import (ConstantParamScheduler, LinearParamScheduler, MultiStepParamScheduler, StepParamScheduler, _ParamScheduler) +from mmengine.testing import assert_allclose class ToyModel(torch.nn.Module): -- GitLab