From bee8a5f4b5cb5d02d96754f10224a6fdd7477fcb Mon Sep 17 00:00:00 2001 From: Marcel Koch <marcel.koch@uni-muenster.de> Date: Thu, 31 Jan 2019 14:19:25 +0100 Subject: [PATCH] use FMS for determinant and inverse --- python/dune/codegen/pdelab/tensors.py | 52 +++++++++++++-------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/python/dune/codegen/pdelab/tensors.py b/python/dune/codegen/pdelab/tensors.py index fcb139cf..3a8c377b 100644 --- a/python/dune/codegen/pdelab/tensors.py +++ b/python/dune/codegen/pdelab/tensors.py @@ -8,6 +8,7 @@ from dune.codegen.generation import (get_counted_variable, temporary_variable, ) from dune.codegen.loopy.symbolic import FusedMultiplyAdd as FMA +from dune.codegen.loopy.symbolic import FusedMultiplySub as FMS from loopy.match import Writes import pymbolic.primitives as prim @@ -24,15 +25,14 @@ def define_determinant(name, matrix, shape, visitor): matrix_entry = [[prim.Subscript(prim.Variable(matrix), (i, j)) for j in range(dim)] for i in range(dim)] if dim == 2: - expr_determinant = FMA(matrix_entry[0][0], matrix_entry[1][1], -1 * matrix_entry[1][0] * matrix_entry[0][1]) + expr_determinant = FMS(matrix_entry[0][0], matrix_entry[1][1], matrix_entry[1][0] * matrix_entry[0][1]) elif dim == 3: - fma_A = FMA(matrix_entry[1][1], matrix_entry[2][2], -1 * matrix_entry[1][2] * matrix_entry[2][1]) - fma_B = FMA(matrix_entry[1][0], matrix_entry[2][2], -1 * matrix_entry[1][2] * matrix_entry[2][0]) - fma_C = FMA(matrix_entry[1][0], matrix_entry[2][1], -1 * matrix_entry[1][1] * matrix_entry[2][0]) + fma_A = FMS(matrix_entry[1][1], matrix_entry[2][2], matrix_entry[1][2] * matrix_entry[2][1]) + fma_B = FMS(matrix_entry[1][0], matrix_entry[2][2], matrix_entry[1][2] * matrix_entry[2][0]) + fma_C = FMS(matrix_entry[1][0], matrix_entry[2][1], matrix_entry[1][1] * matrix_entry[2][0]) - expr_determinant = FMA(matrix_entry[0][2], fma_C, - FMA(matrix_entry[0][0], fma_A, -1 * matrix_entry[0][1] * fma_B)) + expr_determinant = FMA(matrix_entry[0][2], fma_C, FMS(matrix_entry[0][0], fma_A, matrix_entry[0][1] * fma_B)) else: raise NotImplementedError() instruction(expression=expr_determinant, @@ -72,26 +72,26 @@ def define_matrix_inverse(name, name_inv, shape, visitor): sign = 1. if i == j else -1. exprs[i][j] = prim.Product((sign, prim.Variable(det_inv), matrix_entry[1 - i][1 - j])) elif dim == 3: - exprs[0][0] = prim.Variable(det_inv) * FMA(matrix_entry[1][1], matrix_entry[2][2], - -1 * matrix_entry[1][2] * matrix_entry[2][1]) - exprs[1][0] = prim.Variable(det_inv) * FMA(matrix_entry[0][1], matrix_entry[2][2], - -1 * matrix_entry[0][2] * matrix_entry[2][1]) * -1 - exprs[2][0] = prim.Variable(det_inv) * FMA(matrix_entry[0][1], matrix_entry[1][2], - -1 * matrix_entry[0][2] * matrix_entry[1][1]) - - exprs[0][1] = prim.Variable(det_inv) * FMA(matrix_entry[1][0], matrix_entry[2][2], - -1 * matrix_entry[1][2] * matrix_entry[2][0]) * -1 - exprs[1][1] = prim.Variable(det_inv) * FMA(matrix_entry[0][0], matrix_entry[2][2], - -1 * matrix_entry[0][2] * matrix_entry[2][0]) - exprs[2][1] = prim.Variable(det_inv) * FMA(matrix_entry[0][0], matrix_entry[1][2], - -1 * matrix_entry[0][2] * matrix_entry[1][0]) * -1 - - exprs[0][2] = prim.Variable(det_inv) * FMA(matrix_entry[1][0], matrix_entry[2][1], - -1 * matrix_entry[1][1] * matrix_entry[2][0]) - exprs[1][2] = prim.Variable(det_inv) * FMA(matrix_entry[0][0], matrix_entry[2][1], - -1 * matrix_entry[0][1] * matrix_entry[2][0]) * -1 - exprs[2][2] = prim.Variable(det_inv) * FMA(matrix_entry[0][0], matrix_entry[1][1], - -1 * matrix_entry[0][1] * matrix_entry[1][0]) + exprs[0][0] = prim.Variable(det_inv) * FMS(matrix_entry[1][1], matrix_entry[2][2], + matrix_entry[1][2] * matrix_entry[2][1]) + exprs[1][0] = prim.Variable(det_inv) * FMS(matrix_entry[0][1], matrix_entry[2][2], + matrix_entry[0][2] * matrix_entry[2][1]) * -1 + exprs[2][0] = prim.Variable(det_inv) * FMS(matrix_entry[0][1], matrix_entry[1][2], + matrix_entry[0][2] * matrix_entry[1][1]) + + exprs[0][1] = prim.Variable(det_inv) * FMS(matrix_entry[1][0], matrix_entry[2][2], + matrix_entry[1][2] * matrix_entry[2][0]) * -1 + exprs[1][1] = prim.Variable(det_inv) * FMS(matrix_entry[0][0], matrix_entry[2][2], + matrix_entry[0][2] * matrix_entry[2][0]) + exprs[2][1] = prim.Variable(det_inv) * FMS(matrix_entry[0][0], matrix_entry[1][2], + matrix_entry[0][2] * matrix_entry[1][0]) * -1 + + exprs[0][2] = prim.Variable(det_inv) * FMS(matrix_entry[1][0], matrix_entry[2][1], + matrix_entry[1][1] * matrix_entry[2][0]) + exprs[1][2] = prim.Variable(det_inv) * FMS(matrix_entry[0][0], matrix_entry[2][1], + matrix_entry[0][1] * matrix_entry[2][0]) * -1 + exprs[2][2] = prim.Variable(det_inv) * FMS(matrix_entry[0][0], matrix_entry[1][1], + matrix_entry[0][1] * matrix_entry[1][0]) else: raise NotImplementedError for j in range(dim): -- GitLab