Skip to content
Snippets Groups Projects
Unverified Commit fba9a94f authored by wxDai's avatar wxDai Committed by GitHub
Browse files

[Refactor] add testing utils (#475)

* add testing utils

* fix ut

* add blank line betweeen `Args` and `Returns`
parent 7e423cf2
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
from .compare import assert_allclose
from .compare import (assert_allclose, assert_attrs_equal,
assert_dict_contains_subset, assert_dict_has_keys,
assert_is_norm_layer, assert_keys_equal,
assert_params_all_zeros, check_python_script)
__all__ = ['assert_allclose']
__all__ = [
'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal',
'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Callable, Optional, Union
import sys
from collections.abc import Iterable
from runpy import run_path
from shlex import split
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import patch
from torch.nn import GroupNorm, LayerNorm
from torch.testing import assert_allclose as _assert_allclose
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
def assert_allclose(
......@@ -47,3 +54,135 @@ def assert_allclose(
# when PyTorch < 1.6
_assert_allclose(
actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan)
def check_python_script(cmd):
"""Run the python cmd script with `__main__`. The difference between
`os.system` is that, this function exectues code in the current process, so
that it can be tracked by coverage tools. Currently it supports two forms:
- ./tests/data/scripts/hello.py zz
- python tests/data/scripts/hello.py zz
"""
args = split(cmd)
if args[0] == 'python':
args = args[1:]
with patch.object(sys, 'argv', args):
run_path(args[0], run_name='__main__')
def _any(judge_result):
"""Since built-in ``any`` works only when the element of iterable is not
iterable, implement the function."""
if not isinstance(judge_result, Iterable):
return judge_result
try:
for element in judge_result:
if _any(element):
return True
except TypeError:
# Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
if judge_result:
return True
return False
def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
expected_subset: Dict[Any, Any]) -> bool:
"""Check if the dict_obj contains the expected_subset.
Args:
dict_obj (Dict[Any, Any]): Dict object to be checked.
expected_subset (Dict[Any, Any]): Subset expected to be contained in
dict_obj.
Returns:
bool: Whether the dict_obj contains the expected_subset.
"""
for key, value in expected_subset.items():
if key not in dict_obj.keys() or _any(dict_obj[key] != value):
return False
return True
def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
"""Check if attribute of class object is correct.
Args:
obj (object): Class object to be checked.
expected_attrs (Dict[str, Any]): Dict of the expected attrs.
Returns:
bool: Whether the attribute of class object is correct.
"""
for attr, value in expected_attrs.items():
if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
return False
return True
def assert_dict_has_keys(obj: Dict[str, Any],
expected_keys: List[str]) -> bool:
"""Check if the obj has all the expected_keys.
Args:
obj (Dict[str, Any]): Object to be checked.
expected_keys (List[str]): Keys expected to contained in the keys of
the obj.
Returns:
bool: Whether the obj has the expected keys.
"""
return set(expected_keys).issubset(set(obj.keys()))
def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
"""Check if target_keys is equal to result_keys.
Args:
result_keys (List[str]): Result keys to be checked.
target_keys (List[str]): Target keys to be checked.
Returns:
bool: Whether target_keys is equal to result_keys.
"""
return set(result_keys) == set(target_keys)
def assert_is_norm_layer(module) -> bool:
"""Check if the module is a norm layer.
Args:
module (nn.Module): The module to be checked.
Returns:
bool: Whether the module is a norm layer.
"""
norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
return isinstance(module, norm_layer_candidates)
def assert_params_all_zeros(module) -> bool:
"""Check if the parameters of the module is all zeros.
Args:
module (nn.Module): The module to be checked.
Returns:
bool: Whether the parameters of the module is all zeros.
"""
weight_data = module.weight.data
is_weight_zero = weight_data.allclose(
weight_data.new_zeros(weight_data.size()))
if hasattr(module, 'bias') and module.bias is not None:
bias_data = module.bias.data
is_bias_zero = bias_data.allclose(
bias_data.new_zeros(bias_data.size()))
else:
is_bias_zero = True
return is_weight_zero and is_bias_zero
# Copyright (c) OpenMMLab. All rights reserved.
#!/usr/bin/env python
import argparse
import warnings
def parse_args():
parser = argparse.ArgumentParser(description='Say hello.')
parser.add_argument('name', help='To whom.')
args = parser.parse_args()
return args
def main():
args = parse_args()
print(f'hello {args.name}!')
if args.name == 'agent':
warnings.warn('I have a secret!')
if __name__ == '__main__':
main()
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import mmengine.testing as testing
try:
import torch
except ImportError:
torch = None
else:
import torch.nn as nn
def test_assert_dict_contains_subset():
dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)}
# case 1
expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)}
assert testing.assert_dict_contains_subset(dict_obj, expected_subset)
# case 2
expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)}
assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
# case 3
expected_subset = {'a': 'test1', 'b': 2, 'c': None}
assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
# case 4
expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)}
assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
# case 5
dict_obj = {
'a': 'test1',
'b': 2,
'c': (4, 6),
'd': np.array([[5, 3, 5], [1, 2, 3]])
}
expected_subset = {
'a': 'test1',
'b': 2,
'c': (4, 6),
'd': np.array([[5, 3, 5], [6, 2, 3]])
}
assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
# case 6
dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
assert testing.assert_dict_contains_subset(dict_obj, expected_subset)
if torch is not None:
dict_obj = {
'a': 'test1',
'b': 2,
'c': (4, 6),
'd': torch.tensor([5, 3, 5])
}
# case 7
expected_subset = {'d': torch.tensor([5, 5, 5])}
assert not testing.assert_dict_contains_subset(dict_obj,
expected_subset)
# case 8
expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])}
assert not testing.assert_dict_contains_subset(dict_obj,
expected_subset)
def test_assert_attrs_equal():
class TestExample:
a, b, c = 1, ('wvi', 3), [4.5, 3.14]
def test_func(self):
return self.b
# case 1
assert testing.assert_attrs_equal(TestExample, {
'a': 1,
'b': ('wvi', 3),
'c': [4.5, 3.14]
})
# case 2
assert not testing.assert_attrs_equal(TestExample, {
'a': 1,
'b': ('wvi', 3),
'c': [4.5, 3.14, 2]
})
# case 3
assert not testing.assert_attrs_equal(TestExample, {
'bc': 54,
'c': [4.5, 3.14]
})
# case 4
assert testing.assert_attrs_equal(TestExample, {
'b': ('wvi', 3),
'test_func': TestExample.test_func
})
if torch is not None:
class TestExample:
a, b = torch.tensor([1]), torch.tensor([4, 5])
# case 5
assert testing.assert_attrs_equal(TestExample, {
'a': torch.tensor([1]),
'b': torch.tensor([4, 5])
})
# case 6
assert not testing.assert_attrs_equal(TestExample, {
'a': torch.tensor([1]),
'b': torch.tensor([4, 6])
})
assert_dict_has_keys_data_1 = [({
'res_layer': 1,
'norm_layer': 2,
'dense_layer': 3
})]
assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True),
(['res_layer', 'conv_layer'], False)]
@pytest.mark.parametrize('obj', assert_dict_has_keys_data_1)
@pytest.mark.parametrize('expected_keys, ret_value',
assert_dict_has_keys_data_2)
def test_assert_dict_has_keys(obj, expected_keys, ret_value):
assert testing.assert_dict_has_keys(obj, expected_keys) == ret_value
assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])]
assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True),
(['res_layer', 'dense_layer', 'norm_layer'], True),
(['res_layer', 'norm_layer'], False),
(['res_layer', 'conv_layer', 'norm_layer'], False)]
@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1)
@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2)
def test_assert_keys_equal(result_keys, target_keys, ret_value):
assert testing.assert_keys_equal(result_keys, target_keys) == ret_value
@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_assert_is_norm_layer():
# case 1
assert not testing.assert_is_norm_layer(nn.Conv3d(3, 64, 3))
# case 2
assert testing.assert_is_norm_layer(nn.BatchNorm3d(128))
# case 3
assert testing.assert_is_norm_layer(nn.GroupNorm(8, 64))
# case 4
assert not testing.assert_is_norm_layer(nn.Sigmoid())
@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_assert_params_all_zeros():
demo_module = nn.Conv2d(3, 64, 3)
nn.init.constant_(demo_module.weight, 0)
nn.init.constant_(demo_module.bias, 0)
assert testing.assert_params_all_zeros(demo_module)
nn.init.xavier_normal_(demo_module.weight)
nn.init.constant_(demo_module.bias, 0)
assert not testing.assert_params_all_zeros(demo_module)
demo_module = nn.Linear(2048, 400, bias=False)
nn.init.constant_(demo_module.weight, 0)
assert testing.assert_params_all_zeros(demo_module)
nn.init.normal_(demo_module.weight, mean=0, std=0.01)
assert not testing.assert_params_all_zeros(demo_module)
def test_check_python_script(capsys):
testing.check_python_script('./tests/data/scripts/hello.py zz')
captured = capsys.readouterr().out
assert captured == 'hello zz!\n'
testing.check_python_script('./tests/data/scripts/hello.py agent')
captured = capsys.readouterr().out
assert captured == 'hello agent!\n'
# Make sure that wrong cmd raises an error
with pytest.raises(SystemExit):
testing.check_python_script('./tests/data/scripts/hello.py li zz')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment