# Provenance (recovered from the git patch envelope this module was shipped in):
#   commit:  b2d3ef49aa96d820ba0f666ab08ee8a8ae7ddefc
#   author:  Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
#   date:    Thu, 15 Dec 2016 17:22:44 +0100
#   subject: Add matchfma.py
#   path:    python/dune/perftool/loopy/transformations/matchfma.py
"""Rewrite sums of the form ``a * b + c`` into fused multiply-add nodes.

The entry point is :func:`match_fused_multiply_add`, which applies the
rewrite to the right-hand side of every assignment of a loopy kernel.
"""

from dune.perftool.loopy.symbolic import FusedMultiplyAdd as FMA
from loopy.symbolic import SubstitutionMapper

import loopy as lp
import pymbolic.primitives as prim


class FMASubstitutionMapper(SubstitutionMapper):
    """Expression mapper that turns matching sums into ``FMA`` nodes.

    Only :meth:`map_sum` is overridden; every other node type is handled
    by ``SubstitutionMapper``.
    """

    def map_sum(self, expr):
        """Map a sum, replacing it with an ``FMA`` node where possible.

        A sum matches when it has exactly two children and at least one of
        them is a ``prim.Product`` with exactly two children.  The result
        is then ``FMA(factor, factor, other_summand)`` with all three
        operands mapped recursively.  The first child is examined before
        the second, so for ``a*b + c*d`` the product ``a*b`` supplies the
        two factors.

        Any sum that does not match is delegated to
        ``SubstitutionMapper.map_sum``.
        """
        summands = expr.children
        if len(summands) == 2:
            first, second = summands
            # Try both roles for the two summands, left operand first.
            for product, addend in ((first, second), (second, first)):
                if not isinstance(product, prim.Product):
                    continue
                if len(product.children) != 2:
                    continue
                factor_a, factor_b = product.children
                return FMA(
                    self.rec(factor_a),
                    self.rec(factor_b),
                    self.rec(addend),
                )

        return SubstitutionMapper.map_sum(self, expr)


def substitute_fma(expr):
    """Return *expr* with matching sums replaced by ``FMA`` nodes.

    The mapper is constructed with the identity as its substitution
    function, so the only rewriting done here is the one implemented in
    ``FMASubstitutionMapper.map_sum``.
    """
    # NOTE(review): what SubstitutionMapper does with the value returned by
    # the identity function (e.g. whether it still descends into subscript
    # indices) is defined in loopy/pymbolic, not here -- confirm if relevant.
    mapper = FMASubstitutionMapper(lambda x: x)
    return mapper(expr)


def match_fused_multiply_add(knl):
    """Return a copy of *knl* with FMA matching applied to its assignments.

    Each ``lp.Assignment`` in ``knl.instructions`` is replaced by a copy
    whose ``expression`` went through :func:`substitute_fma`.  All other
    instructions are carried over as they are, and instruction order is
    preserved.
    """
    rewritten = [
        insn.copy(expression=substitute_fma(insn.expression))
        if isinstance(insn, lp.Assignment)
        else insn
        for insn in knl.instructions
    ]
    return knl.copy(instructions=rewritten)