From efe49e016bed4dc822b628a1169ce9f3635a9f8b Mon Sep 17 00:00:00 2001
From: Marcel Koch <marcel.koch@uni-muenster.de>
Date: Thu, 31 Jan 2019 16:41:25 +0100
Subject: [PATCH] more FMA

---
 .../dune/codegen/blockstructured/geometry.py  | 59 ++++++++++---------
 1 file changed, 32 insertions(+), 27 deletions(-)

diff --git a/python/dune/codegen/blockstructured/geometry.py b/python/dune/codegen/blockstructured/geometry.py
index 753dba0f..5ffd0949 100644
--- a/python/dune/codegen/blockstructured/geometry.py
+++ b/python/dune/codegen/blockstructured/geometry.py
@@ -1,9 +1,15 @@
+import pymbolic.primitives as prim
+from loopy.match import Writes
+
+from dune.codegen.blockstructured.tools import (sub_element_inames,
+                                                name_point_in_macro,
+                                                )
 from dune.codegen.generation import (geometry_mixin,
                                      temporary_variable,
                                      instruction,
                                      get_global_context_value,
                                      domain)
-from dune.codegen.tools import get_pymbolic_basename
+from dune.codegen.loopy.symbolic import FusedMultiplyAdd as FMA
 from dune.codegen.options import get_form_option
 from dune.codegen.pdelab.geometry import (AxiparallelGeometryMixin,
                                           EquidistantGeometryMixin,
@@ -16,12 +22,8 @@ from dune.codegen.pdelab.geometry import (AxiparallelGeometryMixin,
                                           name_cell_geometry
                                           )
 from dune.codegen.pdelab.tensors import name_matrix_inverse, name_determinant
-from dune.codegen.blockstructured.tools import (sub_element_inames,
-                                                name_point_in_macro,
-                                                )
+from dune.codegen.tools import get_pymbolic_basename
 from dune.codegen.ufl.modified_terminals import Restriction
-import pymbolic.primitives as prim
-from loopy.match import Writes
 
 
 @geometry_mixin("blockstructured_multilinear")
@@ -160,12 +162,12 @@ def compute_jacobian(name, visitor):
         a, b, c = coefficients
 
         expr_jac = [None, None]
-        expr_jac[0] = prim.Sum((prim.Product((prim.Subscript(pymbolic_pos, (1,)),
-                                              prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)))),
-                                prim.Subscript(prim.Variable(b), (prim.Variable(jac_iname),))))
-        expr_jac[1] = prim.Sum((prim.Product((prim.Subscript(pymbolic_pos, (0,)),
-                                              prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)))),
-                                prim.Subscript(prim.Variable(c), (prim.Variable(jac_iname),))))
+        expr_jac[0] = FMA(prim.Subscript(pymbolic_pos, (1,)),
+                          prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)),
+                          prim.Subscript(prim.Variable(b), (prim.Variable(jac_iname),)))
+        expr_jac[1] = FMA(prim.Subscript(pymbolic_pos, (0,)),
+                          prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)),
+                          prim.Subscript(prim.Variable(c), (prim.Variable(jac_iname),)))
     elif world_dimension() == 3:
         a, b, c, d, e, f, g = coefficients
 
@@ -176,14 +178,13 @@ def compute_jacobian(name, visitor):
         # with k, l in {0,1,2} != i and k<l and vj = terms[i][j]
         for i in range(3):
             k, l = sorted(set(range(3)) - {i})
-            expr_jac[i] = prim.Sum((prim.Product((prim.Subscript(pymbolic_pos, (k,)), prim.Subscript(pymbolic_pos, (l,)),
-                                                  prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)))),
-                                    prim.Product((prim.Subscript(pymbolic_pos, (k,)),
-                                                  prim.Subscript(prim.Variable(terms[i][0]), (prim.Variable(jac_iname),)))),
-                                    prim.Product((prim.Subscript(pymbolic_pos, (l,)),
-                                                  prim.Subscript(prim.Variable(terms[i][1]), (prim.Variable(jac_iname),)))),
-                                    prim.Subscript(prim.Variable(terms[i][2]), (prim.Variable(jac_iname),))
-                                    ))
+            expr_jac[i] = FMA(prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)),
+                              prim.Subscript(pymbolic_pos, (k,)) * prim.Subscript(pymbolic_pos, (l,)),
+                              FMA(prim.Subscript(prim.Variable(terms[i][0]), (prim.Variable(jac_iname),)),
+                                  prim.Subscript(pymbolic_pos, (k,)),
+                                  FMA(prim.Subscript(prim.Variable(terms[i][1]), (prim.Variable(jac_iname),)),
+                                      prim.Subscript(pymbolic_pos, (l,)),
+                                      prim.Subscript(prim.Variable(terms[i][2]), (prim.Variable(jac_iname),)))))
     else:
         raise NotImplementedError()
 
@@ -225,12 +226,15 @@ def compute_multilinear_to_global_transformation(name, local, visitor):
     # global[d] = T(local)[d]
     if dim == 2:
         a_pym, b_pym, c_pym = coeffs_pym
-        expr = a_pym * local_pym[0] * local_pym[1] + b_pym * local_pym[0] + c_pym * local_pym[1] + corner_0_pym
+        expr = FMA(a_pym, local_pym[0] * local_pym[1], FMA(b_pym, local_pym[0], FMA(c_pym, local_pym[1], corner_0_pym)))
     elif dim == 3:
         a_pym, b_pym, c_pym, d_pym, e_pym, f_pym, g_pym = coeffs_pym
-        expr = (a_pym * local_pym[0] * local_pym[1] * local_pym[2] + b_pym * local_pym[0] * local_pym[1] +
-                c_pym * local_pym[0] * local_pym[2] + d_pym * local_pym[1] * local_pym[2] +
-                e_pym * local_pym[0] + f_pym * local_pym[1] + g_pym * local_pym[2] + corner_0_pym)
+        expr = FMA(a_pym * local_pym[0], local_pym[1] * local_pym[2],
+                   FMA(b_pym, local_pym[0] * local_pym[1],
+                       FMA(c_pym, local_pym[0] * local_pym[2],
+                           FMA(d_pym, local_pym[1] * local_pym[2],
+                               FMA(e_pym, local_pym[0],
+                                   FMA(f_pym, local_pym[1], FMA(g_pym, local_pym[2], corner_0_pym)))))))
     else:
         raise NotImplementedError
 
@@ -254,9 +258,10 @@ def compute_axiparallel_to_global_transformation(name, local, visitor):
     dim_pym = prim.Variable(component_iname('to_global'))
 
     # global[d] = lower_left[d] + local[d] * (upper_right[d] - lower_left[d])
-    expr = (prim.Subscript(prim.Variable(corners), (0, dim_pym)) +
-            prim.Subscript(local, (dim_pym,)) * (prim.Subscript(prim.Variable(corners), (2**dim - 1, dim_pym)) -
-                                                 prim.Subscript(prim.Variable(corners), (0, dim_pym))))
+    expr = FMA(prim.Subscript(prim.Variable(corners), (2**dim - 1, dim_pym)) -
+               prim.Subscript(prim.Variable(corners), (0, dim_pym)),
+               prim.Subscript(local, (dim_pym,)), prim.Subscript(prim.Variable(corners), (0, dim_pym)))
+
     assignee = prim.Subscript(prim.Variable(name), (dim_pym,))
 
     instruction(assignee=assignee, expression=expr,
-- 
GitLab