diff --git a/mmengine/analysis/jit_handles.py b/mmengine/analysis/jit_handles.py
index 4f3dd696f2c4ebf13c56ffab5946953f494e05fd..917509d7e3e4060a70e7ee422b081d3a306f3a15 100644
--- a/mmengine/analysis/jit_handles.py
+++ b/mmengine/analysis/jit_handles.py
@@ -209,13 +209,16 @@ def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
 
 def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
     """Count flops for matmul."""
-    # Inputs should be a list of length 2.
-    # Inputs contains the shapes of two matrices.
-    input_shapes = [get_shape(v) for v in inputs]
-    assert len(input_shapes) == 2, input_shapes
-    assert input_shapes[0][-1] == input_shapes[1][  # type: ignore
-        -2], input_shapes  # type: ignore
-    flop = prod(input_shapes[0]) * input_shapes[-1][-1]  # type: ignore
+    # input_shapes is a list of length 2.
+    input_shapes: list = [get_shape(v) for v in inputs]
+    input1, input2 = input_shapes
+    if len(input1) == 1:
+        input1 = [1, input1[0]]
+    if len(input2) == 1:
+        input2 = [input2[0], 1]
+
+    assert input1[-1] == input2[-2], input_shapes
+    flop = prod(input1) * input2[-1]
     return flop
 
 
diff --git a/tests/test_analysis/test_flop_count.py b/tests/test_analysis/test_flop_count.py
index 0c0e8943cd362c2e251b4357759f8da9aa6c8ac2..20749a0babd17dbd4f958645e5259e18c59943e4 100644
--- a/tests/test_analysis/test_flop_count.py
+++ b/tests/test_analysis/test_flop_count.py
@@ -580,8 +580,6 @@ class TestFlopAnalyzer(unittest.TestCase):
             transpose=True,
             output_padding=output_padding9,
         )
-
-    def test_matmul(self) -> None:
         """Test flop count for operation matmul."""
         m = 20
         n = 10
@@ -596,6 +594,13 @@ class TestFlopAnalyzer(unittest.TestCase):
         self.assertDictEqual(
             flop_dict, gt_dict,
             'Matmul operation failed to pass the flop count test.')
+        # Test with single dimension y
+        y = torch.randn(n)
+        gt_dict['matmul'] = m * n * 1 / 1e9
+        flop_dict, _ = flop_count(m_net, (x, y))
+        self.assertDictEqual(
+            flop_dict, gt_dict,
+            'Matmul operation failed to pass the flop count test.')
 
     def test_matmul_broadcast(self) -> None:
         """Test flop count for operation matmul."""