From 54f1a482f2791a843f94533e18fbe9183899300e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Fri, 8 Feb 2019 15:19:34 +0100 Subject: [PATCH] Loopy transformation replacinrd reductions with inplace accumulation --- .../transformations/remove_reductions.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 python/dune/codegen/loopy/transformations/remove_reductions.py diff --git a/python/dune/codegen/loopy/transformations/remove_reductions.py b/python/dune/codegen/loopy/transformations/remove_reductions.py new file mode 100644 index 00000000..c49cb86f --- /dev/null +++ b/python/dune/codegen/loopy/transformations/remove_reductions.py @@ -0,0 +1,58 @@ +import loopy as lp + +import pymbolic.primitives as prim + +def remove_reduction(knl, match): + """Removes all matching reductions and do direct accumulation in assignee instead""" + instructions = [] + + # Find reductions + for instr in lp.find_instructions(knl, match): + if isinstance(instr.expression, lp.symbolic.Reduction): + # Remove the instruction from the kernel + knl = lp.remove_instructions(knl, set([instr.id])) + + # Add instruction that sets assignee to zero + id_zero = instr.id + '_set_zero' + instructions.append(lp.Assignment(instr.assignee, + 0, + within_inames=instr.within_inames, + id=id_zero, + tags=('set_zero',) + )) + + # Add instruction that accumulates directly in assignee + assignee = instr.assignee + expression = prim.Sum((assignee, instr.expression.expr)) + within_inames = frozenset(tuple(instr.within_inames) + instr.expression.inames) + id_accum = instr.id + '_accum' + instructions.append(lp.Assignment(assignee, + expression, + within_inames=within_inames, + id=id_accum, + depends_on=frozenset((id_zero,)), + tags=('assignement',))) + + + knl = knl.copy(instructions=knl.instructions + instructions) + return knl + + +def remove_all_reductions(knl): + """Remove all reductions from loopy kernel + + This removes all reductions by instead setting the assignee to zero and + directly accumulating in the assignee. + """ + # Find ids of all reductions + ids = [] + for instr in knl.instructions: + if isinstance(instr.expression, lp.symbolic.Reduction): + ids.append(instr.id) + + # Remove reductions + for id in ids: + match = lp.match.Id(id) + knl = remove_reduction(knl, match) + + return knl -- GitLab