From dd00c10b822816ff5d7bbf49c213867d7e510b9c Mon Sep 17 00:00:00 2001
From: Marcel Koch <marcel.koch@uni-muenster.de>
Date: Thu, 31 Jan 2019 11:17:30 +0100
Subject: [PATCH] use fma for determinant

In 3D the determinant now costs 5 FMAs + 4 multiplications (not counting the multiplications by -1).
---
 python/dune/codegen/pdelab/tensors.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/python/dune/codegen/pdelab/tensors.py b/python/dune/codegen/pdelab/tensors.py
index fa9d52e6..a1a52a4c 100644
--- a/python/dune/codegen/pdelab/tensors.py
+++ b/python/dune/codegen/pdelab/tensors.py
@@ -7,7 +7,7 @@ from dune.codegen.generation import (get_counted_variable,
                                      instruction,
                                      temporary_variable,
                                      )
-
+from dune.codegen.loopy.symbolic import FusedMultiplyAdd as FMA
 from loopy.match import Writes
 
 import pymbolic.primitives as prim
@@ -24,16 +24,16 @@ def define_determinant(name, matrix, shape, visitor):
 
     matrix_entry = [[prim.Subscript(prim.Variable(matrix), (i, j)) for j in range(dim)] for i in range(dim)]
     if dim == 2:
-        expr_determinant = prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1])),
-                                     -1 * prim.Product((matrix_entry[1][0], matrix_entry[0][1]))))
+        expr_determinant = FMA(matrix_entry[0][0], matrix_entry[1][1],
+                               -1 * prim.Product((matrix_entry[1][0], matrix_entry[0][1])))
+
     elif dim == 3:
-        expr_determinant = prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1], matrix_entry[2][2])),
-                                     prim.Product((matrix_entry[0][1], matrix_entry[1][2], matrix_entry[2][0])),
-                                     prim.Product((matrix_entry[0][2], matrix_entry[1][0], matrix_entry[2][1])),
-                                     -1 * prim.Product((matrix_entry[0][2], matrix_entry[1][1], matrix_entry[2][0])),
-                                     -1 * prim.Product((matrix_entry[0][0], matrix_entry[1][2], matrix_entry[2][1])),
-                                     -1 * prim.Product((matrix_entry[0][1], matrix_entry[1][0], matrix_entry[2][2]))
-                                     ))
+        fma_A = FMA(matrix_entry[1][1], matrix_entry[2][2], -1 * matrix_entry[1][2] * matrix_entry[2][1])
+        fma_B = FMA(matrix_entry[1][0], matrix_entry[2][2], -1 * matrix_entry[1][2] * matrix_entry[2][0])
+        fma_C = FMA(matrix_entry[1][0], matrix_entry[2][1], -1 * matrix_entry[1][1] * matrix_entry[2][0])
+
+        expr_determinant = FMA(matrix_entry[0][2], fma_C,
+                               FMA(matrix_entry[0][0], fma_A, -1 * matrix_entry[0][1] * fma_B))
     else:
         raise NotImplementedError()
     instruction(expression=expr_determinant,
-- 
GitLab