From b2d3ef49aa96d820ba0f666ab08ee8a8ae7ddefc Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 15 Dec 2016 17:22:44 +0100
Subject: [PATCH] Add matchfma.py

---
 .../loopy/transformations/matchfma.py         | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)
 create mode 100644 python/dune/perftool/loopy/transformations/matchfma.py

diff --git a/python/dune/perftool/loopy/transformations/matchfma.py b/python/dune/perftool/loopy/transformations/matchfma.py
new file mode 100644
index 00000000..ee1d81eb
--- /dev/null
+++ b/python/dune/perftool/loopy/transformations/matchfma.py
@@ -0,0 +1,34 @@
+""" Match FMA in expressions! """
+
+from dune.perftool.loopy.symbolic import FusedMultiplyAdd as FMA
+from loopy.symbolic import SubstitutionMapper
+
+import loopy as lp
+import pymbolic.primitives as prim
+
+
+class FMASubstitutionMapper(SubstitutionMapper):
+    def map_sum(self, expr):
+        if len(expr.children) == 2:
+            c1, c2 = expr.children
+            if isinstance(c1, prim.Product) and len(c1.children) == 2:
+                return FMA(self.rec(c1.children[0]), self.rec(c1.children[1]), self.rec(c2))
+            if isinstance(c2, prim.Product) and len(c2.children) == 2:
+                return FMA(self.rec(c2.children[0]), self.rec(c2.children[1]), self.rec(c1))
+        return SubstitutionMapper.map_sum(self, expr)
+
+
+def substitute_fma(expr):  # Run expr through an FMASubstitutionMapper and return the result
+    return FMASubstitutionMapper(lambda x: x)(expr)  # the substitution function is the identity
+
+
+def match_fused_multiply_add(knl):
+    new_insns = []
+
+    for insn in knl.instructions:
+        if isinstance(insn, lp.Assignment):
+            new_insns.append(insn.copy(expression=substitute_fma(insn.expression)))
+        else:
+            new_insns.append(insn)
+
+    return knl.copy(instructions=new_insns)
-- 
GitLab