diff --git a/mmengine/testing/__init__.py b/mmengine/testing/__init__.py index 3318b89c34a2dfee68df90953b28d4f359e1f244..109b9e23a19f5eb3b2f634e8b94cdb5119037963 100644 --- a/mmengine/testing/__init__.py +++ b/mmengine/testing/__init__.py @@ -1,4 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .compare import assert_allclose +from .compare import (assert_allclose, assert_attrs_equal, + assert_dict_contains_subset, assert_dict_has_keys, + assert_is_norm_layer, assert_keys_equal, + assert_params_all_zeros, check_python_script) -__all__ = ['assert_allclose'] +__all__ = [ + 'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal', + 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer', + 'assert_params_all_zeros', 'check_python_script' +] diff --git a/mmengine/testing/compare.py b/mmengine/testing/compare.py index 2f803c84c78951b07628d52d5124d1f52a091920..14c7a97ba73ee98600102ab28d649b01aab8f3bc 100644 --- a/mmengine/testing/compare.py +++ b/mmengine/testing/compare.py @@ -1,10 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Callable, Optional, Union +import sys +from collections.abc import Iterable +from runpy import run_path +from shlex import split +from typing import Any, Callable, Dict, List, Optional, Union +from unittest.mock import patch +from torch.nn import GroupNorm, LayerNorm from torch.testing import assert_allclose as _assert_allclose from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm def assert_allclose( @@ -47,3 +54,135 @@ def assert_allclose( # when PyTorch < 1.6 _assert_allclose( actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +def check_python_script(cmd): + """Run the python cmd script with `__main__`. The difference between + `os.system` is that, this function exectues code in the current process, so + that it can be tracked by coverage tools. Currently it supports two forms: + + - ./tests/data/scripts/hello.py zz + - python tests/data/scripts/hello.py zz + """ + args = split(cmd) + if args[0] == 'python': + args = args[1:] + with patch.object(sys, 'argv', args): + run_path(args[0], run_name='__main__') + + +def _any(judge_result): + """Since built-in ``any`` works only when the element of iterable is not + iterable, implement the function.""" + if not isinstance(judge_result, Iterable): + return judge_result + + try: + for element in judge_result: + if _any(element): + return True + except TypeError: + # Maybe encounter the case: torch.tensor(True) | torch.tensor(False) + if judge_result: + return True + return False + + +def assert_dict_contains_subset(dict_obj: Dict[Any, Any], + expected_subset: Dict[Any, Any]) -> bool: + """Check if the dict_obj contains the expected_subset. + + Args: + dict_obj (Dict[Any, Any]): Dict object to be checked. + expected_subset (Dict[Any, Any]): Subset expected to be contained in + dict_obj. + + Returns: + bool: Whether the dict_obj contains the expected_subset. + """ + + for key, value in expected_subset.items(): + if key not in dict_obj.keys() or _any(dict_obj[key] != value): + return False + return True + + +def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool: + """Check if attribute of class object is correct. + + Args: + obj (object): Class object to be checked. + expected_attrs (Dict[str, Any]): Dict of the expected attrs. + + Returns: + bool: Whether the attribute of class object is correct. + """ + for attr, value in expected_attrs.items(): + if not hasattr(obj, attr) or _any(getattr(obj, attr) != value): + return False + return True + + +def assert_dict_has_keys(obj: Dict[str, Any], + expected_keys: List[str]) -> bool: + """Check if the obj has all the expected_keys. + + Args: + obj (Dict[str, Any]): Object to be checked. + expected_keys (List[str]): Keys expected to contained in the keys of + the obj. + + Returns: + bool: Whether the obj has the expected keys. + """ + return set(expected_keys).issubset(set(obj.keys())) + + +def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool: + """Check if target_keys is equal to result_keys. + + Args: + result_keys (List[str]): Result keys to be checked. + target_keys (List[str]): Target keys to be checked. + + Returns: + bool: Whether target_keys is equal to result_keys. + """ + return set(result_keys) == set(target_keys) + + +def assert_is_norm_layer(module) -> bool: + """Check if the module is a norm layer. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the module is a norm layer. + """ + + norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm) + return isinstance(module, norm_layer_candidates) + + +def assert_params_all_zeros(module) -> bool: + """Check if the parameters of the module is all zeros. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the parameters of the module is all zeros. + """ + weight_data = module.weight.data + is_weight_zero = weight_data.allclose( + weight_data.new_zeros(weight_data.size())) + + if hasattr(module, 'bias') and module.bias is not None: + bias_data = module.bias.data + is_bias_zero = bias_data.allclose( + bias_data.new_zeros(bias_data.size())) + else: + is_bias_zero = True + + return is_weight_zero and is_bias_zero diff --git a/tests/data/scripts/hello.py b/tests/data/scripts/hello.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed1a1e319fa36eb11ed3f0fcd365eb43a382d01 --- /dev/null +++ b/tests/data/scripts/hello.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +#!/usr/bin/env python + +import argparse +import warnings + + +def parse_args(): + parser = argparse.ArgumentParser(description='Say hello.') + parser.add_argument('name', help='To whom.') + + args = parser.parse_args() + + return args + + +def main(): + args = parse_args() + print(f'hello {args.name}!') + if args.name == 'agent': + warnings.warn('I have a secret!') + + +if __name__ == '__main__': + main() diff --git a/tests/test_testing/test_compare.py b/tests/test_testing/test_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4e79bc57986ee67068c617cd32e32bf6ee2d33 --- /dev/null +++ b/tests/test_testing/test_compare.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import pytest + +import mmengine.testing as testing + +try: + import torch +except ImportError: + torch = None +else: + import torch.nn as nn + + +def test_assert_dict_contains_subset(): + dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)} + + # case 1 + expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)} + assert testing.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 2 + expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)} + assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 3 + expected_subset = {'a': 'test1', 'b': 2, 'c': None} + assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 4 + expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)} + assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 5 + dict_obj = { + 'a': 'test1', + 'b': 2, + 'c': (4, 6), + 'd': np.array([[5, 3, 5], [1, 2, 3]]) + } + expected_subset = { + 'a': 'test1', + 'b': 2, + 'c': (4, 6), + 'd': np.array([[5, 3, 5], [6, 2, 3]]) + } + assert not testing.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 6 + dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} + expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} + assert testing.assert_dict_contains_subset(dict_obj, expected_subset) + + if torch is not None: + dict_obj = { + 'a': 'test1', + 'b': 2, + 'c': (4, 6), + 'd': torch.tensor([5, 3, 5]) + } + + # case 7 + expected_subset = {'d': torch.tensor([5, 5, 5])} + assert not testing.assert_dict_contains_subset(dict_obj, + expected_subset) + + # case 8 + expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])} + assert not testing.assert_dict_contains_subset(dict_obj, + expected_subset) + + +def test_assert_attrs_equal(): + + class TestExample: + a, b, c = 1, ('wvi', 3), [4.5, 3.14] + + def test_func(self): + return self.b + + # case 1 + assert testing.assert_attrs_equal(TestExample, { + 'a': 1, + 'b': ('wvi', 3), + 'c': [4.5, 3.14] + }) + + # case 2 + assert not testing.assert_attrs_equal(TestExample, { + 'a': 1, + 'b': ('wvi', 3), + 'c': [4.5, 3.14, 2] + }) + + # case 3 + assert not testing.assert_attrs_equal(TestExample, { + 'bc': 54, + 'c': [4.5, 3.14] + }) + + # case 4 + assert testing.assert_attrs_equal(TestExample, { + 'b': ('wvi', 3), + 'test_func': TestExample.test_func + }) + + if torch is not None: + + class TestExample: + a, b = torch.tensor([1]), torch.tensor([4, 5]) + + # case 5 + assert testing.assert_attrs_equal(TestExample, { + 'a': torch.tensor([1]), + 'b': torch.tensor([4, 5]) + }) + + # case 6 + assert not testing.assert_attrs_equal(TestExample, { + 'a': torch.tensor([1]), + 'b': torch.tensor([4, 6]) + }) + + +assert_dict_has_keys_data_1 = [({ + 'res_layer': 1, + 'norm_layer': 2, + 'dense_layer': 3 +})] +assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True), + (['res_layer', 'conv_layer'], False)] + + +@pytest.mark.parametrize('obj', assert_dict_has_keys_data_1) +@pytest.mark.parametrize('expected_keys, ret_value', + assert_dict_has_keys_data_2) +def test_assert_dict_has_keys(obj, expected_keys, ret_value): + assert testing.assert_dict_has_keys(obj, expected_keys) == ret_value + + +assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])] +assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True), + (['res_layer', 'dense_layer', 'norm_layer'], True), + (['res_layer', 'norm_layer'], False), + (['res_layer', 'conv_layer', 'norm_layer'], False)] + + +@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1) +@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2) +def test_assert_keys_equal(result_keys, target_keys, ret_value): + assert testing.assert_keys_equal(result_keys, target_keys) == ret_value + + +@pytest.mark.skipif(torch is None, reason='requires torch library') +def test_assert_is_norm_layer(): + # case 1 + assert not testing.assert_is_norm_layer(nn.Conv3d(3, 64, 3)) + + # case 2 + assert testing.assert_is_norm_layer(nn.BatchNorm3d(128)) + + # case 3 + assert testing.assert_is_norm_layer(nn.GroupNorm(8, 64)) + + # case 4 + assert not testing.assert_is_norm_layer(nn.Sigmoid()) + + +@pytest.mark.skipif(torch is None, reason='requires torch library') +def test_assert_params_all_zeros(): + demo_module = nn.Conv2d(3, 64, 3) + nn.init.constant_(demo_module.weight, 0) + nn.init.constant_(demo_module.bias, 0) + assert testing.assert_params_all_zeros(demo_module) + + nn.init.xavier_normal_(demo_module.weight) + nn.init.constant_(demo_module.bias, 0) + assert not testing.assert_params_all_zeros(demo_module) + + demo_module = nn.Linear(2048, 400, bias=False) + nn.init.constant_(demo_module.weight, 0) + assert testing.assert_params_all_zeros(demo_module) + + nn.init.normal_(demo_module.weight, mean=0, std=0.01) + assert not testing.assert_params_all_zeros(demo_module) + + +def test_check_python_script(capsys): + testing.check_python_script('./tests/data/scripts/hello.py zz') + captured = capsys.readouterr().out + assert captured == 'hello zz!\n' + testing.check_python_script('./tests/data/scripts/hello.py agent') + captured = capsys.readouterr().out + assert captured == 'hello agent!\n' + # Make sure that wrong cmd raises an error + with pytest.raises(SystemExit): + testing.check_python_script('./tests/data/scripts/hello.py li zz')