Skip to content
Snippets Groups Projects
Commit 55bfc6b1 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fix quadrature changes for master

parent 332c0a90
No related branches found
No related tags found
No related merge requests found
......@@ -16,9 +16,6 @@ from dune.perftool.pdelab.spaces import (lfs_child,
name_lfs_bound,
type_gfs,
)
from dune.perftool.pdelab.quadrature import (name_quadrature_position_in_cell,
quadrature_iname,
)
from dune.perftool.pdelab.geometry import (dimension_iname,
name_dimension,
name_jacobian_inverse_transposed,
......@@ -27,6 +24,7 @@ from dune.perftool.pdelab.geometry import (dimension_iname,
from dune.perftool.pdelab.localoperator import (lop_template_ansatz_gfs,
lop_template_test_gfs,
)
from dune.perftool.tools import get_pymbolic_basename
from dune.perftool.pdelab.driver import FEM_name_mangling
from dune.perftool.pdelab.restriction import restricted_name
from pymbolic.primitives import Product, Subscript, Variable
......@@ -71,11 +69,11 @@ def evaluate_basis(leaf_element, name, restriction):
instruction(inames=get_backend("quad_inames")(),
code='{} = {}.evaluateFunction({}, {}.finiteElement().localBasis());'.format(name,
cache,
qp,
str(qp),
lfs,
),
assignees=frozenset({name}),
read_variables=frozenset({qp}),
read_variables=frozenset({get_pymbolic_basename(qp)}),
)
......@@ -93,15 +91,15 @@ def evaluate_reference_gradient(leaf_element, name, restriction):
lfs = name_leaf_lfs(leaf_element, restriction)
temporary_variable(name, shape=(name_lfs_bound(lfs), 1, name_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian'))
cache = name_localbasis_cache(leaf_element)
qp = name_quadrature_position_in_cell(restriction)
qp = get_backend("qp_in_cell")(restriction)
instruction(inames=get_backend("quad_inames")(),
code='{} = {}.evaluateJacobian({}, {}.finiteElement().localBasis());'.format(name,
cache,
qp,
str(qp),
lfs,
),
assignees=frozenset({name}),
read_variables=frozenset({qp}),
read_variables=frozenset({get_pymbolic_basename(qp)}),
)
......
......@@ -9,12 +9,13 @@ from dune.perftool.generation import (cached,
preamble,
temporary_variable,
)
from dune.perftool.pdelab.quadrature import (name_quadrature_position,
name_quadrature_position_in_cell,
from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_position_in_cell,
quadrature_preamble,
)
from dune.perftool.tools import get_pymbolic_basename
from ufl.algorithms import MultiFunction
from pymbolic.primitives import Variable
from pymbolic.primitives import Expression as PymbolicExpression
@preamble
......@@ -194,25 +195,27 @@ def apply_in_cell_transformation(name, local, restriction):
geo = name_in_cell_geometry(restriction)
return quadrature_preamble("{} = {}.global({});".format(name,
geo,
local,
str(local),
),
assignees=frozenset({name}),
read_variables=frozenset({local}),
read_variables=frozenset({get_pymbolic_basename(local)}),
)
def name_in_cell_coordinates(local, basename, restriction):
def pymbolic_in_cell_coordinates(local, restriction):
basename = get_pymbolic_basename(local)
name = "{}_in_{}side".format(basename, "in" if restriction is Restriction.NEGATIVE else "out")
temporary_variable(name, shape=(name_dimension(),), shape_impl=("fv",))
apply_in_cell_transformation(name, local, restriction)
return name
return Variable(name)
def to_cell_coordinates(local, basename, restriction):
def to_cell_coordinates(local, restriction):
assert isinstance(local, PymbolicExpression)
if restriction == Restriction.NONE:
return local
else:
return name_in_cell_coordinates(local, basename, restriction)
return pymbolic_in_cell_coordinates(local, restriction)
def name_dimension():
......@@ -227,10 +230,10 @@ def name_intersection_dimension():
def evaluate_unit_outer_normal(name):
ig = name_intersection_geometry_wrapper()
qp = name_quadrature_position()
qp = get_backend("quad_pos")()
return quadrature_preamble("{} = {}.unitOuterNormal({});".format(name, ig, qp),
assignees=frozenset({name}),
read_variables=frozenset({qp}),
read_variables=frozenset({get_pymbolic_basename(qp)}),
)
......@@ -285,10 +288,10 @@ def define_jacobian_inverse_transposed(name, restriction):
pos = get_backend("qp_in_cell")(restriction)
return quadrature_preamble("{} = {}.jacobianInverseTransposed({});".format(name,
geo,
pos,
str(pos),
),
assignees=frozenset({name}),
read_variables=frozenset({pos}),
read_variables=frozenset({get_pymbolic_basename(pos)}),
)
......@@ -304,11 +307,11 @@ def define_jacobian_determinant(name):
pos = get_backend("quad_pos")()
code = "{} = {}.integrationElement({});".format(name,
geo,
pos,
str(pos),
)
return quadrature_preamble(code,
assignees=frozenset({name}),
read_variables=frozenset({pos}),
read_variables=frozenset({get_pymbolic_basename(pos)}),
)
......
......@@ -5,6 +5,7 @@ from dune.perftool.generation import (cached,
class_member,
constructor_parameter,
generator_factory,
get_backend,
initializer_list,
preamble,
temporary_variable
......@@ -13,10 +14,11 @@ from dune.perftool.pdelab.geometry import (name_cell,
name_dimension,
name_intersection,
)
from dune.perftool.pdelab.quadrature import (name_quadrature_position,
name_quadrature_position_in_cell,
from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_position,
pymbolic_quadrature_position_in_cell,
quadrature_preamble,
)
from dune.perftool.tools import get_pymbolic_basename
from dune.perftool.cgen.clazz import AccessModifier
from dune.perftool.pdelab.localoperator import (class_type_from_cache,
localoperator_basename,
......@@ -132,12 +134,12 @@ def evaluate_cellwise_constant_parameter_function(name, restriction):
import numpy
valuearg(name, dtype=numpy.float64)
return '{} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
entity,
pos,
)
return 'auto {} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
entity,
pos,
)
@preamble
......@@ -155,26 +157,26 @@ def evaluate_intersectionwise_constant_parameter_function(name):
import numpy
valuearg(name, dtype=numpy.float64)
return '{} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
intersection,
pos,
)
return 'auto {} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
intersection,
pos,
)
def evaluate_cell_parameter_function(name, restriction):
param = name_paramclass()
entity = name_cell(restriction)
pos = name_quadrature_position_in_cell(restriction)
pos = get_backend(interface="qp_in_cell")(restriction)
return quadrature_preamble('{} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
entity,
pos,
str(pos),
),
assignees=frozenset({name}),
read_variables=frozenset({pos}),
read_variables=frozenset({get_pymbolic_basename(pos)}),
)
......@@ -186,15 +188,15 @@ def evaluate_intersection_parameter_function(name):
param = name_paramclass()
intersection = name_intersection()
pos = name_quadrature_position()
pos = get_backend("quad_pos")()
return quadrature_preamble('{} = {}.{}({}, {});'.format(name,
name_paramclass(),
name,
intersection,
pos,
str(pos),
),
assignees=frozenset({name}),
read_variables=frozenset({pos}),
read_variables=frozenset({get_pymbolic_basename(pos)}),
)
......
......@@ -130,17 +130,10 @@ def pymbolic_quadrature_position():
return Subscript(Variable(quad_points), (Variable(quad_iname),))
def name_quadrature_position():
return str(pymbolic_quadrature_position())
@backend(interface="qp_in_cell")
def name_quadrature_position_in_cell(restriction):
if restriction == Restriction.NONE:
return name_quadrature_position()
else:
from dune.perftool.pdelab.geometry import to_cell_coordinates
return to_cell_coordinates(name_quadrature_position(), name_quadrature_point(), restriction)
def pymbolic_quadrature_position_in_cell(restriction):
from dune.perftool.pdelab.geometry import to_cell_coordinates
return to_cell_coordinates(pymbolic_quadrature_position(), restriction)
@preamble
......@@ -158,11 +151,13 @@ def typedef_quadrature_weights(name):
dim = _local_dim()
return "using {} = typename Dune::QuadraturePoint<{}, {}>::Field;".format(name, range_field, dim)
def pymbolic_quadrature_weight():
vec = name_quadrature_weights()
return Subscript(Variable(vec),
tuple(Variable(i) for i in quadrature_inames()))
def type_quadrature_weights(name):
name = name.upper()
typedef_quadrature_weights(name)
......
""" Some grabbag tools """
from pymbolic.primitives import Expression, Variable, Subscript
def get_pymbolic_basename(expr):
assert isinstance(expr, Expression), "Type: {}, expr: {}".format(type(expr), expr)
if isinstance(expr, Variable):
return expr.name
if isinstance(expr, Subscript):
return get_pymbolic_basename(expr.aggregate)
raise NotImplementedError("Cannot determine basename of {}".format(expr))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment