diff --git a/mmengine/analysis/jit_handles.py b/mmengine/analysis/jit_handles.py index 4f3dd696f2c4ebf13c56ffab5946953f494e05fd..917509d7e3e4060a70e7ee422b081d3a306f3a15 100644 --- a/mmengine/analysis/jit_handles.py +++ b/mmengine/analysis/jit_handles.py @@ -209,13 +209,16 @@ def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]: def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]: """Count flops for matmul.""" - # Inputs should be a list of length 2. - # Inputs contains the shapes of two matrices. - input_shapes = [get_shape(v) for v in inputs] - assert len(input_shapes) == 2, input_shapes - assert input_shapes[0][-1] == input_shapes[1][ # type: ignore - -2], input_shapes # type: ignore - flop = prod(input_shapes[0]) * input_shapes[-1][-1] # type: ignore + # input_shapes is a list of length 2. + input_shapes: list = [get_shape(v) for v in inputs] + input1, input2 = input_shapes + if len(input1) == 1: + input1 = [1, input1[0]] + if len(input2) == 1: + input2 = [input2[0], 1] + + assert input1[-1] == input2[-2], input_shapes + flop = prod(input1) * input2[-1] return flop diff --git a/tests/test_analysis/test_flop_count.py b/tests/test_analysis/test_flop_count.py index 0c0e8943cd362c2e251b4357759f8da9aa6c8ac2..20749a0babd17dbd4f958645e5259e18c59943e4 100644 --- a/tests/test_analysis/test_flop_count.py +++ b/tests/test_analysis/test_flop_count.py @@ -580,8 +580,6 @@ class TestFlopAnalyzer(unittest.TestCase): transpose=True, output_padding=output_padding9, ) - - def test_matmul(self) -> None: """Test flop count for operation matmul.""" m = 20 n = 10 @@ -596,6 +594,13 @@ class TestFlopAnalyzer(unittest.TestCase): self.assertDictEqual( flop_dict, gt_dict, 'Matmul operation failed to pass the flop count test.') + # Test with single dimension y + y = torch.randn(n) + gt_dict['matmul'] = m * n * 1 / 1e9 + flop_dict, _ = flop_count(m_net, (x, y)) + self.assertDictEqual( + flop_dict, gt_dict, + 'Matmul operation failed to pass the flop count test.') def test_matmul_broadcast(self) -> None: """Test flop count for operation matmul."""