Skip to content
Snippets Groups Projects
Commit dd00c10b authored by Marcel Koch
Browse files

use fma for determinant

3D case: 5 FMA + 4 multiplications (not counting the multiplications by -1)
parent 33672520
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,7 @@ from dune.codegen.generation import (get_counted_variable, ...@@ -7,7 +7,7 @@ from dune.codegen.generation import (get_counted_variable,
instruction, instruction,
temporary_variable, temporary_variable,
) )
from dune.codegen.loopy.symbolic import FusedMultiplyAdd as FMA
from loopy.match import Writes from loopy.match import Writes
import pymbolic.primitives as prim import pymbolic.primitives as prim
...@@ -24,16 +24,16 @@ def define_determinant(name, matrix, shape, visitor): ...@@ -24,16 +24,16 @@ def define_determinant(name, matrix, shape, visitor):
matrix_entry = [[prim.Subscript(prim.Variable(matrix), (i, j)) for j in range(dim)] for i in range(dim)] matrix_entry = [[prim.Subscript(prim.Variable(matrix), (i, j)) for j in range(dim)] for i in range(dim)]
if dim == 2: if dim == 2:
expr_determinant = prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1])), expr_determinant = FMA(matrix_entry[0][0], matrix_entry[1][1],
-1 * prim.Product((matrix_entry[1][0], matrix_entry[0][1])))) -1 * prim.Product((matrix_entry[1][0], matrix_entry[0][1])))
elif dim == 3: elif dim == 3:
expr_determinant = prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1], matrix_entry[2][2])), fma_A = FMA(matrix_entry[1][1], matrix_entry[2][2], -1 * matrix_entry[1][2] * matrix_entry[2][1])
prim.Product((matrix_entry[0][1], matrix_entry[1][2], matrix_entry[2][0])), fma_B = FMA(matrix_entry[1][0], matrix_entry[2][2], -1 * matrix_entry[1][2] * matrix_entry[2][0])
prim.Product((matrix_entry[0][2], matrix_entry[1][0], matrix_entry[2][1])), fma_C = FMA(matrix_entry[1][0], matrix_entry[2][1], -1 * matrix_entry[1][1] * matrix_entry[2][0])
-1 * prim.Product((matrix_entry[0][2], matrix_entry[1][1], matrix_entry[2][0])),
-1 * prim.Product((matrix_entry[0][0], matrix_entry[1][2], matrix_entry[2][1])), expr_determinant = FMA(matrix_entry[0][2], fma_C,
-1 * prim.Product((matrix_entry[0][1], matrix_entry[1][0], matrix_entry[2][2])) FMA(matrix_entry[0][0], fma_A, -1 * matrix_entry[0][1] * fma_B))
))
else: else:
raise NotImplementedError() raise NotImplementedError()
instruction(expression=expr_determinant, instruction(expression=expr_determinant,
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment