From 1860994911d8382fc22af4a01c485a797ccfbc74 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 15 Dec 2016 16:56:13 +0100
Subject: [PATCH] FMA working!

---
 python/dune/perftool/loopy/symbolic.py        |  8 ++++++--
 python/dune/perftool/loopy/target.py          | 11 ++++++++---
 python/dune/perftool/loopy/vcl.py             | 10 ++++++++++
 python/dune/perftool/pdelab/localoperator.py  |  3 +++
 python/dune/perftool/sumfact/vectorization.py |  2 +-
 5 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py
index 58768edf..131bed1c 100644
--- a/python/dune/perftool/loopy/symbolic.py
+++ b/python/dune/perftool/loopy/symbolic.py
@@ -58,7 +58,7 @@ class FusedMultiplyAdd(prim.Expression):
         return (self.mul_op1, self.mul_op2, self.add_op)
 
     def stringifier(self):
-        return StringifyMapper
+        return lp.symbolic.StringifyMapper
 
     mapper_method = intern("map_fused_multiply_add")
 
@@ -122,6 +122,10 @@ def type_inference_fused_multiply_add(self, expr):
     return self.rec(expr.mul_op1)
 
 
+def vectorizability_map_fused_multiply_add(self, expr):
+    return all([self.rec(op) for op in (expr.mul_op1, expr.mul_op2, expr.add_op)])  # list: visit every operand
+
+
 #
 # Do the actual monkey patching!!!
 #
@@ -141,7 +145,7 @@ lp.symbolic.WalkMapper.map_fused_multiply_add = walk_map_fused_multiply_add
 lp.symbolic.StringifyMapper.map_fused_multiply_add = stringify_map_fused_multiply_add
 lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multiply_add
 lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add
-
+lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add
 
 #
 # Some helper functions!
diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py
index 18ebddbf..0d46b7e2 100644
--- a/python/dune/perftool/loopy/target.py
+++ b/python/dune/perftool/loopy/target.py
@@ -83,9 +83,14 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
         return ret
 
     def map_fused_multiply_add(self, expr, type_context):
-        # Default implementation that discards the node in favor of the resp.
-        # additions and multiplications.
-        return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
+        if self.codegen_state.vectorization_info:
+            # Vectorized code calls the VCL function mul_add; operands keep the type context.
+            operands = (expr.mul_op1, expr.mul_op2, expr.add_op)
+            return prim.Call(prim.Variable("mul_add"),
+                             tuple(self.rec(op, type_context) for op in operands))
+        else:
+            # Scalar fallback: discard the node in favor of plain multiplication and addition.
+            return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
 
 
 class DuneCExpressionToCodeMapper(CExpressionToCodeMapper):
diff --git a/python/dune/perftool/loopy/vcl.py b/python/dune/perftool/loopy/vcl.py
index 78dea3cf..2d8bf6cc 100644
--- a/python/dune/perftool/loopy/vcl.py
+++ b/python/dune/perftool/loopy/vcl.py
@@ -59,3 +59,13 @@ def get_vcl_type(nptype, register_size=None, vec_size=None):
         vec_size = register_size // (np.dtype(nptype).itemsize * 8)
 
     return VCLTypeRegistry.types[np.dtype(nptype), vec_size]
+
+
+@function_mangler
+def vcl_mul_add(knl, func, arg_dtypes):
+    """Mangle ``mul_add`` as a call on 256-bit float64 VCL vectors; other names yield None."""
+    if func == "mul_add":
+        # Not 100% within the loopy philosophy: loopy assumes calls are read-only, while this
+        # is said to pass registers by reference and change them. NOTE(review): confirm for mul_add.
+        vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256))
+        return lp.CallMangleInfo("mul_add", (vcl,), (vcl, vcl, vcl))
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 4707d96f..e214b523 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -500,6 +500,9 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
     # *REALLY* ignore boostability. This is - so far - necessary due to a mystery bug.
     kernel = kernel.copy(instructions=[i.copy(boostable=False, boostable_into=frozenset()) for i in kernel.instructions])
 
+    from dune.perftool.loopy.transformations.matchfma import match_fused_multiply_add
+    kernel = match_fused_multiply_add(kernel)
+
     if wrap_in_cgen:
         # Wrap the kernel in something which can generate code
         from dune.perftool.pdelab.signatures import assembly_routine_signature
diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
index 6af1c14d..100f3b21 100644
--- a/python/dune/perftool/sumfact/vectorization.py
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -136,4 +136,4 @@ class HasSumfactMapper(lp.symbolic.CombineMapper):
 
 
 def find_sumfact(expr):
-    return HasSumfactMapper()(expr)
\ No newline at end of file
+    return HasSumfactMapper()(expr)
-- 
GitLab