diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 539d2f698b8c23c5f13b30f66229def5137705be..4205aedb8276c4f1a3bcf61fc57985a47e756935 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -308,7 +308,6 @@ def sum_factorization_kernel(a_matrices, buf, stage, # Get the inames needed for one matrix-matrix multiplication i = sumfact_iname(out_shape[0], "row") j = sumfact_iname(out_shape[1], "col") - k = sumfact_iname(a_matrix.cols, "red") # Maybe introduce a vectorization iname for this matrix-matrix multiplication vec_iname = () @@ -317,15 +316,27 @@ def sum_factorization_kernel(a_matrices, buf, stage, vec_iname = (prim.Variable(iname),) transform(lp.tag_inames, [(iname, "vec")]) - # Construct the matrix-matrix-multiplication expression a_ik*in_kj - prod = Product((Subscript(Variable(a_matrix.name), (Variable(i), Variable(k)) + vec_iname), - Subscript(Variable(inp), (Variable(k), Variable(j)) + vec_iname) - )) + if a_matrix.cols == 1: + # A trivial reduction is implemented as a product, otherwise we run into + # a code generation corner case producing way too complicated code. This + # could be fixed upstream, but the loopy code realizing reductions is not + # trivial and the priority is kind of low. + matprod = Product((Subscript(Variable(a_matrix.name), (Variable(i), 0) + vec_iname), + Subscript(Variable(inp), (0, Variable(j)) + vec_iname) + )) + else: + k = sumfact_iname(a_matrix.cols, "red") + + # Construct the matrix-matrix-multiplication expression a_ik*in_kj + prod = Product((Subscript(Variable(a_matrix.name), (Variable(i), Variable(k)) + vec_iname), + Subscript(Variable(inp), (Variable(k), Variable(j)) + vec_iname) + )) + matprod = Reduction("sum", k, prod) # Issue the reduction instruction that implements the multiplication # at the same time store the instruction ID for the next instruction to depend on insn_dep = frozenset({instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j)) + vec_iname), - expression=Reduction("sum", k, prod), + expression=matprod, forced_iname_deps=frozenset({i, j}).union(additional_inames), forced_iname_deps_is_final=True, depends_on=insn_dep,