Skip to content
Snippets Groups Projects
Commit efe49e01 authored by Marcel Koch's avatar Marcel Koch
Browse files

more FMA

parent 15a35f91
No related branches found
No related tags found
No related merge requests found
import pymbolic.primitives as prim
from loopy.match import Writes
from dune.codegen.blockstructured.tools import (sub_element_inames,
name_point_in_macro,
)
from dune.codegen.generation import (geometry_mixin, from dune.codegen.generation import (geometry_mixin,
temporary_variable, temporary_variable,
instruction, instruction,
get_global_context_value, get_global_context_value,
domain) domain)
from dune.codegen.tools import get_pymbolic_basename from dune.codegen.loopy.symbolic import FusedMultiplyAdd as FMA
from dune.codegen.options import get_form_option from dune.codegen.options import get_form_option
from dune.codegen.pdelab.geometry import (AxiparallelGeometryMixin, from dune.codegen.pdelab.geometry import (AxiparallelGeometryMixin,
EquidistantGeometryMixin, EquidistantGeometryMixin,
...@@ -16,12 +22,8 @@ from dune.codegen.pdelab.geometry import (AxiparallelGeometryMixin, ...@@ -16,12 +22,8 @@ from dune.codegen.pdelab.geometry import (AxiparallelGeometryMixin,
name_cell_geometry name_cell_geometry
) )
from dune.codegen.pdelab.tensors import name_matrix_inverse, name_determinant from dune.codegen.pdelab.tensors import name_matrix_inverse, name_determinant
from dune.codegen.blockstructured.tools import (sub_element_inames, from dune.codegen.tools import get_pymbolic_basename
name_point_in_macro,
)
from dune.codegen.ufl.modified_terminals import Restriction from dune.codegen.ufl.modified_terminals import Restriction
import pymbolic.primitives as prim
from loopy.match import Writes
@geometry_mixin("blockstructured_multilinear") @geometry_mixin("blockstructured_multilinear")
...@@ -160,12 +162,12 @@ def compute_jacobian(name, visitor): ...@@ -160,12 +162,12 @@ def compute_jacobian(name, visitor):
a, b, c = coefficients a, b, c = coefficients
expr_jac = [None, None] expr_jac = [None, None]
expr_jac[0] = prim.Sum((prim.Product((prim.Subscript(pymbolic_pos, (1,)), expr_jac[0] = FMA(prim.Subscript(pymbolic_pos, (1,)),
prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)))), prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)),
prim.Subscript(prim.Variable(b), (prim.Variable(jac_iname),)))) prim.Subscript(prim.Variable(b), (prim.Variable(jac_iname),)))
expr_jac[1] = prim.Sum((prim.Product((prim.Subscript(pymbolic_pos, (0,)), expr_jac[1] = FMA(prim.Subscript(pymbolic_pos, (0,)),
prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)))), prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)),
prim.Subscript(prim.Variable(c), (prim.Variable(jac_iname),)))) prim.Subscript(prim.Variable(c), (prim.Variable(jac_iname),)))
elif world_dimension() == 3: elif world_dimension() == 3:
a, b, c, d, e, f, g = coefficients a, b, c, d, e, f, g = coefficients
...@@ -176,14 +178,13 @@ def compute_jacobian(name, visitor): ...@@ -176,14 +178,13 @@ def compute_jacobian(name, visitor):
# with k, l in {0,1,2} != i and k<l and vj = terms[i][j] # with k, l in {0,1,2} != i and k<l and vj = terms[i][j]
for i in range(3): for i in range(3):
k, l = sorted(set(range(3)) - {i}) k, l = sorted(set(range(3)) - {i})
expr_jac[i] = prim.Sum((prim.Product((prim.Subscript(pymbolic_pos, (k,)), prim.Subscript(pymbolic_pos, (l,)), expr_jac[i] = FMA(prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)),
prim.Subscript(prim.Variable(a), (prim.Variable(jac_iname),)))), prim.Subscript(pymbolic_pos, (k,)) * prim.Subscript(pymbolic_pos, (l,)),
prim.Product((prim.Subscript(pymbolic_pos, (k,)), FMA(prim.Subscript(prim.Variable(terms[i][0]), (prim.Variable(jac_iname),)),
prim.Subscript(prim.Variable(terms[i][0]), (prim.Variable(jac_iname),)))), prim.Subscript(pymbolic_pos, (k,)),
prim.Product((prim.Subscript(pymbolic_pos, (l,)), FMA(prim.Subscript(prim.Variable(terms[i][1]), (prim.Variable(jac_iname),)),
prim.Subscript(prim.Variable(terms[i][1]), (prim.Variable(jac_iname),)))), prim.Subscript(pymbolic_pos, (l,)),
prim.Subscript(prim.Variable(terms[i][2]), (prim.Variable(jac_iname),)) prim.Subscript(prim.Variable(terms[i][2]), (prim.Variable(jac_iname),)))))
))
else: else:
raise NotImplementedError() raise NotImplementedError()
...@@ -225,12 +226,15 @@ def compute_multilinear_to_global_transformation(name, local, visitor): ...@@ -225,12 +226,15 @@ def compute_multilinear_to_global_transformation(name, local, visitor):
# global[d] = T(local)[d] # global[d] = T(local)[d]
if dim == 2: if dim == 2:
a_pym, b_pym, c_pym = coeffs_pym a_pym, b_pym, c_pym = coeffs_pym
expr = a_pym * local_pym[0] * local_pym[1] + b_pym * local_pym[0] + c_pym * local_pym[1] + corner_0_pym expr = FMA(a_pym, local_pym[0] * local_pym[1], FMA(b_pym, local_pym[0], FMA(c_pym, local_pym[1], corner_0_pym)))
elif dim == 3: elif dim == 3:
a_pym, b_pym, c_pym, d_pym, e_pym, f_pym, g_pym = coeffs_pym a_pym, b_pym, c_pym, d_pym, e_pym, f_pym, g_pym = coeffs_pym
expr = (a_pym * local_pym[0] * local_pym[1] * local_pym[2] + b_pym * local_pym[0] * local_pym[1] + expr = FMA(a_pym * local_pym[0], local_pym[1] * local_pym[2],
c_pym * local_pym[0] * local_pym[2] + d_pym * local_pym[1] * local_pym[2] + FMA(b_pym, local_pym[0] * local_pym[1],
e_pym * local_pym[0] + f_pym * local_pym[1] + g_pym * local_pym[2] + corner_0_pym) FMA(c_pym, local_pym[0] * local_pym[2],
FMA(d_pym, local_pym[1] * local_pym[2],
FMA(e_pym, local_pym[0],
FMA(f_pym, local_pym[1], FMA(g_pym, local_pym[2], corner_0_pym)))))))
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -254,9 +258,10 @@ def compute_axiparallel_to_global_transformation(name, local, visitor): ...@@ -254,9 +258,10 @@ def compute_axiparallel_to_global_transformation(name, local, visitor):
dim_pym = prim.Variable(component_iname('to_global')) dim_pym = prim.Variable(component_iname('to_global'))
# global[d] = lower_left[d] + local[d] * (upper_right[d] - lower_left[d]) # global[d] = lower_left[d] + local[d] * (upper_right[d] - lower_left[d])
expr = (prim.Subscript(prim.Variable(corners), (0, dim_pym)) + expr = FMA(prim.Subscript(prim.Variable(corners), (2**dim - 1, dim_pym)) -
prim.Subscript(local, (dim_pym,)) * (prim.Subscript(prim.Variable(corners), (2**dim - 1, dim_pym)) - prim.Subscript(prim.Variable(corners), (0, dim_pym)),
prim.Subscript(prim.Variable(corners), (0, dim_pym)))) prim.Subscript(local, (dim_pym,)), prim.Subscript(prim.Variable(corners), (0, dim_pym)))
assignee = prim.Subscript(prim.Variable(name), (dim_pym,)) assignee = prim.Subscript(prim.Variable(name), (dim_pym,))
instruction(assignee=assignee, expression=expr, instruction(assignee=assignee, expression=expr,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment