Skip to content
Snippets Groups Projects
Commit 227304ca authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Eliminate trivial reductions

For some reason, the patch refuses to schedule the kernel for sumfact_poisson_dg_symdiff.
parent 0f0d69d0
No related branches found
No related tags found
No related merge requests found
...@@ -308,7 +308,6 @@ def sum_factorization_kernel(a_matrices, buf, stage, ...@@ -308,7 +308,6 @@ def sum_factorization_kernel(a_matrices, buf, stage,
# Get the inames needed for one matrix-matrix multiplication # Get the inames needed for one matrix-matrix multiplication
i = sumfact_iname(out_shape[0], "row") i = sumfact_iname(out_shape[0], "row")
j = sumfact_iname(out_shape[1], "col") j = sumfact_iname(out_shape[1], "col")
k = sumfact_iname(a_matrix.cols, "red")
# Maybe introduce a vectorization iname for this matrix-matrix multiplication # Maybe introduce a vectorization iname for this matrix-matrix multiplication
vec_iname = () vec_iname = ()
...@@ -317,15 +316,27 @@ def sum_factorization_kernel(a_matrices, buf, stage, ...@@ -317,15 +316,27 @@ def sum_factorization_kernel(a_matrices, buf, stage,
vec_iname = (prim.Variable(iname),) vec_iname = (prim.Variable(iname),)
transform(lp.tag_inames, [(iname, "vec")]) transform(lp.tag_inames, [(iname, "vec")])
# Construct the matrix-matrix-multiplication expression a_ik*in_kj if a_matrix.cols == 1:
prod = Product((Subscript(Variable(a_matrix.name), (Variable(i), Variable(k)) + vec_iname), # A trivial reduction is implemented as a product, otherwise we run into
Subscript(Variable(inp), (Variable(k), Variable(j)) + vec_iname) # a code generation corner case producing way too complicated code. This
)) # could be fixed upstream, but the loopy code realizing reductions is not
# trivial and the priority is kind of low.
matprod = Product((Subscript(Variable(a_matrix.name), (Variable(i), 0) + vec_iname),
Subscript(Variable(inp), (0, Variable(j)) + vec_iname)
))
else:
k = sumfact_iname(a_matrix.cols, "red")
# Construct the matrix-matrix-multiplication expression a_ik*in_kj
prod = Product((Subscript(Variable(a_matrix.name), (Variable(i), Variable(k)) + vec_iname),
Subscript(Variable(inp), (Variable(k), Variable(j)) + vec_iname)
))
matprod = Reduction("sum", k, prod)
# Issue the reduction instruction that implements the multiplication # Issue the reduction instruction that implements the multiplication
# at the same time store the instruction ID for the next instruction to depend on # at the same time store the instruction ID for the next instruction to depend on
insn_dep = frozenset({instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j)) + vec_iname), insn_dep = frozenset({instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j)) + vec_iname),
expression=Reduction("sum", k, prod), expression=matprod,
forced_iname_deps=frozenset({i, j}).union(additional_inames), forced_iname_deps=frozenset({i, j}).union(additional_inames),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
depends_on=insn_dep, depends_on=insn_dep,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment