From 8beacd3b58379f9315cff47fab5757ba4dc5385e Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Thu, 9 Mar 2023 17:29:26 +0800 Subject: [PATCH] [Fix] Support calculate the flops of `matmul` with single dimension matrix (#970) * Support calculate the flops of matmul * Remove unnecessary type ignore * Update mmengine/analysis/jit_handles.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/analysis/jit_handles.py | 17 ++++++++++------- tests/test_analysis/test_flop_count.py | 9 +++++++-- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/mmengine/analysis/jit_handles.py b/mmengine/analysis/jit_handles.py index 4f3dd696..917509d7 100644 --- a/mmengine/analysis/jit_handles.py +++ b/mmengine/analysis/jit_handles.py @@ -209,13 +209,16 @@ def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]: def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]: """Count flops for matmul.""" - # Inputs should be a list of length 2. - # Inputs contains the shapes of two matrices. - input_shapes = [get_shape(v) for v in inputs] - assert len(input_shapes) == 2, input_shapes - assert input_shapes[0][-1] == input_shapes[1][ # type: ignore - -2], input_shapes # type: ignore - flop = prod(input_shapes[0]) * input_shapes[-1][-1] # type: ignore + # input_shapes is a list of length 2. + input_shapes: list = [get_shape(v) for v in inputs] + input1, input2 = input_shapes + if len(input1) == 1: + input1 = [1, input1[0]] + if len(input2) == 1: + input2 = [input2[0], 1] + + assert input1[-1] == input2[-2], input_shapes + flop = prod(input1) * input2[-1] return flop diff --git a/tests/test_analysis/test_flop_count.py b/tests/test_analysis/test_flop_count.py index 0c0e8943..20749a0b 100644 --- a/tests/test_analysis/test_flop_count.py +++ b/tests/test_analysis/test_flop_count.py @@ -580,8 +580,6 @@ class TestFlopAnalyzer(unittest.TestCase): transpose=True, output_padding=output_padding9, ) - - def test_matmul(self) -> None: """Test flop count for operation matmul.""" m = 20 n = 10 @@ -596,6 +594,13 @@ class TestFlopAnalyzer(unittest.TestCase): self.assertDictEqual( flop_dict, gt_dict, 'Matmul operation failed to pass the flop count test.') + # Test with single dimension y + y = torch.randn(n) + gt_dict['matmul'] = m * n * 1 / 1e9 + flop_dict, _ = flop_count(m_net, (x, y)) + self.assertDictEqual( + flop_dict, gt_dict, + 'Matmul operation failed to pass the flop count test.') def test_matmul_broadcast(self) -> None: """Test flop count for operation matmul.""" -- GitLab