Skip to content
Snippets Groups Projects
Unverified Commit 8beacd3b authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Support calculate the flops of `matmul` with single dimension matrix (#970)


* Support calculate the flops of matmul

* Remove unnecessary type ignore

* Update mmengine/analysis/jit_handles.py

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

---------

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 6e58c0d2
No related branches found
No related tags found
No related merge requests found
......@@ -209,13 +209,16 @@ def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
"""Count flops for matmul."""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two matrices.
input_shapes = [get_shape(v) for v in inputs]
assert len(input_shapes) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][ # type: ignore
-2], input_shapes # type: ignore
flop = prod(input_shapes[0]) * input_shapes[-1][-1] # type: ignore
# input_shapes is a list of length 2.
input_shapes: list = [get_shape(v) for v in inputs]
input1, input2 = input_shapes
if len(input1) == 1:
input1 = [1, input1[0]]
if len(input2) == 1:
input2 = [input2[0], 1]
assert input1[-1] == input2[-2], input_shapes
flop = prod(input1) * input2[-1]
return flop
......
......@@ -580,8 +580,6 @@ class TestFlopAnalyzer(unittest.TestCase):
transpose=True,
output_padding=output_padding9,
)
def test_matmul(self) -> None:
"""Test flop count for operation matmul."""
m = 20
n = 10
......@@ -596,6 +594,13 @@ class TestFlopAnalyzer(unittest.TestCase):
self.assertDictEqual(
flop_dict, gt_dict,
'Matmul operation failed to pass the flop count test.')
# Test with single dimension y
y = torch.randn(n)
gt_dict['matmul'] = m * n * 1 / 1e9
flop_dict, _ = flop_count(m_net, (x, y))
self.assertDictEqual(
flop_dict, gt_dict,
'Matmul operation failed to pass the flop count test.')
def test_matmul_broadcast(self) -> None:
"""Test flop count for operation matmul."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment