From fba9a94f52db30f863c83e1f95faa8b07ef98b53 Mon Sep 17 00:00:00 2001
From: wxDai <wxDai2001@gmail.com>
Date: Mon, 29 Aug 2022 22:59:20 +0800
Subject: [PATCH] [Refactor] add testing utils (#475)

* add testing utils

* fix ut

* add blank line betweeen `Args` and `Returns`
---
 mmengine/testing/__init__.py       |  11 +-
 mmengine/testing/compare.py        | 141 ++++++++++++++++++++-
 tests/data/scripts/hello.py        |  25 ++++
 tests/test_testing/test_compare.py | 197 +++++++++++++++++++++++++++++
 4 files changed, 371 insertions(+), 3 deletions(-)
 create mode 100644 tests/data/scripts/hello.py
 create mode 100644 tests/test_testing/test_compare.py

diff --git a/mmengine/testing/__init__.py b/mmengine/testing/__init__.py
index 3318b89c..109b9e23 100644
--- a/mmengine/testing/__init__.py
+++ b/mmengine/testing/__init__.py
@@ -1,4 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .compare import assert_allclose
+from .compare import (assert_allclose, assert_attrs_equal,
+                      assert_dict_contains_subset, assert_dict_has_keys,
+                      assert_is_norm_layer, assert_keys_equal,
+                      assert_params_all_zeros, check_python_script)
 
-__all__ = ['assert_allclose']
+__all__ = [
+    'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal',
+    'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer',
+    'assert_params_all_zeros', 'check_python_script'
+]
diff --git a/mmengine/testing/compare.py b/mmengine/testing/compare.py
index 2f803c84..14c7a97b 100644
--- a/mmengine/testing/compare.py
+++ b/mmengine/testing/compare.py
@@ -1,10 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Callable, Optional, Union
+import sys
+from collections.abc import Iterable
+from runpy import run_path
+from shlex import split
+from typing import Any, Callable, Dict, List, Optional, Union
+from unittest.mock import patch
 
+from torch.nn import GroupNorm, LayerNorm
 from torch.testing import assert_allclose as _assert_allclose
 
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
 
 
 def assert_allclose(
@@ -47,3 +54,135 @@ def assert_allclose(
         # when PyTorch < 1.6
         _assert_allclose(
             actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+
+def check_python_script(cmd):
+    """Run the python cmd script with `__main__`. The difference between
+    `os.system` is that, this function exectues code in the current process, so
+    that it can be tracked by coverage tools. Currently it supports two forms:
+
+    - ./tests/data/scripts/hello.py zz
+    - python tests/data/scripts/hello.py zz
+    """
+    args = split(cmd)
+    if args[0] == 'python':
+        args = args[1:]
+    with patch.object(sys, 'argv', args):
+        run_path(args[0], run_name='__main__')
+
+
+def _any(judge_result):
+    """Since built-in ``any`` works only when the element of iterable is not
+    iterable, implement the function."""
+    if not isinstance(judge_result, Iterable):
+        return judge_result
+
+    try:
+        for element in judge_result:
+            if _any(element):
+                return True
+    except TypeError:
+        # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
+        if judge_result:
+            return True
+    return False
+
+
+def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
+                                expected_subset: Dict[Any, Any]) -> bool:
+    """Check if the dict_obj contains the expected_subset.
+
+    Args:
+        dict_obj (Dict[Any, Any]): Dict object to be checked.
+        expected_subset (Dict[Any, Any]): Subset expected to be contained in
+            dict_obj.
+
+    Returns:
+        bool: Whether the dict_obj contains the expected_subset.
+    """
+
+    for key, value in expected_subset.items():
+        if key not in dict_obj.keys() or _any(dict_obj[key] != value):
+            return False
+    return True
+
+
+def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
+    """Check if attribute of class object is correct.
+
+    Args:
+        obj (object): Class object to be checked.
+        expected_attrs (Dict[str, Any]): Dict of the expected attrs.
+
+    Returns:
+        bool: Whether the attribute of class object is correct.
+    """
+    for attr, value in expected_attrs.items():
+        if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
+            return False
+    return True
+
+
+def assert_dict_has_keys(obj: Dict[str, Any],
+                         expected_keys: List[str]) -> bool:
+    """Check if the obj has all the expected_keys.
+
+    Args:
+        obj (Dict[str, Any]): Object to be checked.
+        expected_keys (List[str]): Keys expected to contained in the keys of
+            the obj.
+
+    Returns:
+        bool: Whether the obj has the expected keys.
+    """
+    return set(expected_keys).issubset(set(obj.keys()))
+
+
+def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
+    """Check if target_keys is equal to result_keys.
+
+    Args:
+        result_keys (List[str]): Result keys to be checked.
+        target_keys (List[str]): Target keys to be checked.
+
+    Returns:
+        bool: Whether target_keys is equal to result_keys.
+    """
+    return set(result_keys) == set(target_keys)
+
+
+def assert_is_norm_layer(module) -> bool:
+    """Check if the module is a norm layer.
+
+    Args:
+        module (nn.Module): The module to be checked.
+
+    Returns:
+        bool: Whether the module is a norm layer.
+    """
+
+    norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
+    return isinstance(module, norm_layer_candidates)
+
+
+def assert_params_all_zeros(module) -> bool:
+    """Check if the parameters of the module is all zeros.
+
+    Args:
+        module (nn.Module): The module to be checked.
+
+    Returns:
+        bool: Whether the parameters of the module is all zeros.
+    """
+    weight_data = module.weight.data
+    is_weight_zero = weight_data.allclose(
+        weight_data.new_zeros(weight_data.size()))
+
+    if hasattr(module, 'bias') and module.bias is not None:
+        bias_data = module.bias.data
+        is_bias_zero = bias_data.allclose(
+            bias_data.new_zeros(bias_data.size()))
+    else:
+        is_bias_zero = True
+
+    return is_weight_zero and is_bias_zero
diff --git a/tests/data/scripts/hello.py b/tests/data/scripts/hello.py
new file mode 100644
index 00000000..2ed1a1e3
--- /dev/null
+++ b/tests/data/scripts/hello.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+#!/usr/bin/env python
+
+import argparse
+import warnings
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Say hello.')
+    parser.add_argument('name', help='To whom.')
+
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+    print(f'hello {args.name}!')
+    if args.name == 'agent':
+        warnings.warn('I have a secret!')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/test_testing/test_compare.py b/tests/test_testing/test_compare.py
new file mode 100644
index 00000000..cd4e79bc
--- /dev/null
+++ b/tests/test_testing/test_compare.py
@@ -0,0 +1,197 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import pytest
+
+import mmengine.testing as testing
+
+try:
+    import torch
+except ImportError:
+    torch = None
+else:
+    import torch.nn as nn
+
+
+def test_assert_dict_contains_subset():
+    dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)}
+
+    # case 1
+    expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)}
+    assert testing.assert_dict_contains_subset(dict_obj, expected_subset)
+
+    # case 2
+    expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)}
+    assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
+
+    # case 3
+    expected_subset = {'a': 'test1', 'b': 2, 'c': None}
+    assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
+
+    # case 4
+    expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)}
+    assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
+
+    # case 5
+    dict_obj = {
+        'a': 'test1',
+        'b': 2,
+        'c': (4, 6),
+        'd': np.array([[5, 3, 5], [1, 2, 3]])
+    }
+    expected_subset = {
+        'a': 'test1',
+        'b': 2,
+        'c': (4, 6),
+        'd': np.array([[5, 3, 5], [6, 2, 3]])
+    }
+    assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
+
+    # case 6
+    dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
+    expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
+    assert testing.assert_dict_contains_subset(dict_obj, expected_subset)
+
+    if torch is not None:
+        dict_obj = {
+            'a': 'test1',
+            'b': 2,
+            'c': (4, 6),
+            'd': torch.tensor([5, 3, 5])
+        }
+
+        # case 7
+        expected_subset = {'d': torch.tensor([5, 5, 5])}
+        assert not testing.assert_dict_contains_subset(dict_obj,
+                                                       expected_subset)
+
+        # case 8
+        expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])}
+        assert not testing.assert_dict_contains_subset(dict_obj,
+                                                       expected_subset)
+
+
+def test_assert_attrs_equal():
+
+    class TestExample:
+        a, b, c = 1, ('wvi', 3), [4.5, 3.14]
+
+        def test_func(self):
+            return self.b
+
+    # case 1
+    assert testing.assert_attrs_equal(TestExample, {
+        'a': 1,
+        'b': ('wvi', 3),
+        'c': [4.5, 3.14]
+    })
+
+    # case 2
+    assert not testing.assert_attrs_equal(TestExample, {
+        'a': 1,
+        'b': ('wvi', 3),
+        'c': [4.5, 3.14, 2]
+    })
+
+    # case 3
+    assert not testing.assert_attrs_equal(TestExample, {
+        'bc': 54,
+        'c': [4.5, 3.14]
+    })
+
+    # case 4
+    assert testing.assert_attrs_equal(TestExample, {
+        'b': ('wvi', 3),
+        'test_func': TestExample.test_func
+    })
+
+    if torch is not None:
+
+        class TestExample:
+            a, b = torch.tensor([1]), torch.tensor([4, 5])
+
+        # case 5
+        assert testing.assert_attrs_equal(TestExample, {
+            'a': torch.tensor([1]),
+            'b': torch.tensor([4, 5])
+        })
+
+        # case 6
+        assert not testing.assert_attrs_equal(TestExample, {
+            'a': torch.tensor([1]),
+            'b': torch.tensor([4, 6])
+        })
+
+
+assert_dict_has_keys_data_1 = [({
+    'res_layer': 1,
+    'norm_layer': 2,
+    'dense_layer': 3
+})]
+assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True),
+                               (['res_layer', 'conv_layer'], False)]
+
+
+@pytest.mark.parametrize('obj', assert_dict_has_keys_data_1)
+@pytest.mark.parametrize('expected_keys, ret_value',
+                         assert_dict_has_keys_data_2)
+def test_assert_dict_has_keys(obj, expected_keys, ret_value):
+    assert testing.assert_dict_has_keys(obj, expected_keys) == ret_value
+
+
+assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])]
+assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True),
+                            (['res_layer', 'dense_layer', 'norm_layer'], True),
+                            (['res_layer', 'norm_layer'], False),
+                            (['res_layer', 'conv_layer', 'norm_layer'], False)]
+
+
+@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1)
+@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2)
+def test_assert_keys_equal(result_keys, target_keys, ret_value):
+    assert testing.assert_keys_equal(result_keys, target_keys) == ret_value
+
+
+@pytest.mark.skipif(torch is None, reason='requires torch library')
+def test_assert_is_norm_layer():
+    # case 1
+    assert not testing.assert_is_norm_layer(nn.Conv3d(3, 64, 3))
+
+    # case 2
+    assert testing.assert_is_norm_layer(nn.BatchNorm3d(128))
+
+    # case 3
+    assert testing.assert_is_norm_layer(nn.GroupNorm(8, 64))
+
+    # case 4
+    assert not testing.assert_is_norm_layer(nn.Sigmoid())
+
+
+@pytest.mark.skipif(torch is None, reason='requires torch library')
+def test_assert_params_all_zeros():
+    demo_module = nn.Conv2d(3, 64, 3)
+    nn.init.constant_(demo_module.weight, 0)
+    nn.init.constant_(demo_module.bias, 0)
+    assert testing.assert_params_all_zeros(demo_module)
+
+    nn.init.xavier_normal_(demo_module.weight)
+    nn.init.constant_(demo_module.bias, 0)
+    assert not testing.assert_params_all_zeros(demo_module)
+
+    demo_module = nn.Linear(2048, 400, bias=False)
+    nn.init.constant_(demo_module.weight, 0)
+    assert testing.assert_params_all_zeros(demo_module)
+
+    nn.init.normal_(demo_module.weight, mean=0, std=0.01)
+    assert not testing.assert_params_all_zeros(demo_module)
+
+
+def test_check_python_script(capsys):
+    testing.check_python_script('./tests/data/scripts/hello.py zz')
+    captured = capsys.readouterr().out
+    assert captured == 'hello zz!\n'
+    testing.check_python_script('./tests/data/scripts/hello.py agent')
+    captured = capsys.readouterr().out
+    assert captured == 'hello agent!\n'
+    # Make sure that wrong cmd raises an error
+    with pytest.raises(SystemExit):
+        testing.check_python_script('./tests/data/scripts/hello.py li zz')
-- 
GitLab