Skip to content
Snippets Groups Projects
Commit 31f426a7 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Properly differeniate between worlddim and dim of the integrated entity

parent 1b3115e4
No related branches found
No related tags found
No related merge requests found
...@@ -19,7 +19,7 @@ from dune.perftool.pdelab.spaces import (lfs_child, ...@@ -19,7 +19,7 @@ from dune.perftool.pdelab.spaces import (lfs_child,
type_gfs, type_gfs,
) )
from dune.perftool.pdelab.geometry import (dimension_iname, from dune.perftool.pdelab.geometry import (dimension_iname,
name_dimension, world_dimension,
name_jacobian_inverse_transposed, name_jacobian_inverse_transposed,
to_cell_coordinates, to_cell_coordinates,
) )
...@@ -96,7 +96,7 @@ def pymbolic_basis(leaf_element, restriction, number, context=''): ...@@ -96,7 +96,7 @@ def pymbolic_basis(leaf_element, restriction, number, context=''):
@kernel_cached @kernel_cached
def evaluate_reference_gradient(leaf_element, name, restriction): def evaluate_reference_gradient(leaf_element, name, restriction):
lfs = name_leaf_lfs(leaf_element, restriction) lfs = name_leaf_lfs(leaf_element, restriction)
temporary_variable(name, shape=(name_lfs_bound(lfs), 1, name_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian')) temporary_variable(name, shape=(name_lfs_bound(lfs), 1, world_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian'))
cache = name_localbasis_cache(leaf_element) cache = name_localbasis_cache(leaf_element)
qp = get_backend("qp_in_cell")(restriction) qp = get_backend("qp_in_cell")(restriction)
instruction(inames=get_backend("quad_inames")(), instruction(inames=get_backend("quad_inames")(),
......
...@@ -13,9 +13,7 @@ from dune.perftool.generation import (backend, ...@@ -13,9 +13,7 @@ from dune.perftool.generation import (backend,
valuearg, valuearg,
) )
from dune.perftool.options import option_switch from dune.perftool.options import option_switch
from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_position_in_cell, from dune.perftool.pdelab.quadrature import quadrature_preamble
quadrature_preamble,
)
from dune.perftool.tools import get_pymbolic_basename from dune.perftool.tools import get_pymbolic_basename
from ufl.algorithms import MultiFunction from ufl.algorithms import MultiFunction
from pymbolic.primitives import Variable from pymbolic.primitives import Variable
...@@ -214,7 +212,7 @@ def apply_in_cell_transformation(name, local, restriction): ...@@ -214,7 +212,7 @@ def apply_in_cell_transformation(name, local, restriction):
def pymbolic_in_cell_coordinates(local, restriction): def pymbolic_in_cell_coordinates(local, restriction):
basename = get_pymbolic_basename(local) basename = get_pymbolic_basename(local)
name = "{}_in_{}side".format(basename, "in" if restriction is Restriction.NEGATIVE else "out") name = "{}_in_{}side".format(basename, "in" if restriction is Restriction.NEGATIVE else "out")
temporary_variable(name, shape=(name_dimension(),), shape_impl=("fv",)) temporary_variable(name, shape=(world_dimension(),), shape_impl=("fv",))
apply_in_cell_transformation(name, local, restriction) apply_in_cell_transformation(name, local, restriction)
return Variable(name) return Variable(name)
...@@ -227,18 +225,21 @@ def to_cell_coordinates(local, restriction): ...@@ -227,18 +225,21 @@ def to_cell_coordinates(local, restriction):
return pymbolic_in_cell_coordinates(local, restriction) return pymbolic_in_cell_coordinates(local, restriction)
def name_dimension(): def world_dimension():
formdata = get_global_context_value('formdata') formdata = get_global_context_value('formdata')
return formdata.geometric_dimension return formdata.geometric_dimension
def world_dimension(): def intersection_dimension():
return name_dimension() return world_dimension() - 1
def name_intersection_dimension(): def local_dimension():
formdata = get_global_context_value('formdata') it = get_global_context_value('integral_type')
return formdata.geometric_dimension - 1 if it == "cell":
return world_dimension()
else:
return intersection_dimension()
def evaluate_unit_outer_normal(name): def evaluate_unit_outer_normal(name):
...@@ -259,7 +260,7 @@ def declare_normal(name, shape, shape_impl): ...@@ -259,7 +260,7 @@ def declare_normal(name, shape, shape_impl):
def name_unit_outer_normal(): def name_unit_outer_normal():
name = "outer_normal" name = "outer_normal"
temporary_variable(name, shape=(name_dimension(),), decl_method=declare_normal) temporary_variable(name, shape=(world_dimension(),), decl_method=declare_normal)
evaluate_unit_outer_normal(name) evaluate_unit_outer_normal(name)
return "outer_normal" return "outer_normal"
...@@ -274,7 +275,7 @@ def evaluate_unit_inner_normal(name): ...@@ -274,7 +275,7 @@ def evaluate_unit_inner_normal(name):
def name_unit_inner_normal(): def name_unit_inner_normal():
name = "inner_normal" name = "inner_normal"
temporary_variable(name, shape=(name_dimension(),), decl_method=declare_normal) temporary_variable(name, shape=(world_dimension(),), decl_method=declare_normal)
evaluate_unit_inner_normal(name) evaluate_unit_inner_normal(name)
return "inner_normal" return "inner_normal"
...@@ -300,7 +301,7 @@ def define_jacobian_inverse_transposed_temporary(restriction): ...@@ -300,7 +301,7 @@ def define_jacobian_inverse_transposed_temporary(restriction):
def define_constant_jacobian_inveser_transposed(name, restriction): def define_constant_jacobian_inveser_transposed(name, restriction):
geo = name_cell_geometry(restriction) geo = name_cell_geometry(restriction)
pos = name_localcenter() pos = name_localcenter()
dim = name_dimension() dim = world_dimension()
if restriction: if restriction:
geo_in = name_in_cell_geometry(restriction) geo_in = name_in_cell_geometry(restriction)
...@@ -316,7 +317,7 @@ def define_constant_jacobian_inveser_transposed(name, restriction): ...@@ -316,7 +317,7 @@ def define_constant_jacobian_inveser_transposed(name, restriction):
@backend(interface="define_jit", name="default") @backend(interface="define_jit", name="default")
def define_jacobian_inverse_transposed(name, restriction): def define_jacobian_inverse_transposed(name, restriction):
dim = name_dimension() dim = world_dimension()
temporary_variable(name, decl_method=define_jacobian_inverse_transposed_temporary(restriction), shape=(dim, dim)) temporary_variable(name, decl_method=define_jacobian_inverse_transposed_temporary(restriction), shape=(dim, dim))
geo = name_cell_geometry(restriction) geo = name_cell_geometry(restriction)
pos = get_backend("qp_in_cell")(restriction) pos = get_backend("qp_in_cell")(restriction)
......
...@@ -11,7 +11,6 @@ from dune.perftool.generation import (class_basename, ...@@ -11,7 +11,6 @@ from dune.perftool.generation import (class_basename,
temporary_variable temporary_variable
) )
from dune.perftool.pdelab.geometry import (name_cell, from dune.perftool.pdelab.geometry import (name_cell,
name_dimension,
name_intersection, name_intersection,
) )
from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_position, from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_position,
......
...@@ -72,19 +72,6 @@ def name_quadrature_point(): ...@@ -72,19 +72,6 @@ def name_quadrature_point():
return "qp" return "qp"
def _local_dim():
# To determine the shape, I do query global information here for lack of good alternatives
from dune.perftool.generation import get_global_context_value
it = get_global_context_value("integral_type")
from dune.perftool.pdelab.geometry import name_dimension, name_intersection_dimension
if it == 'cell':
dim = name_dimension()
else:
dim = name_intersection_dimension()
return dim
@preamble @preamble
def fill_quadrature_points_cache(name): def fill_quadrature_points_cache(name):
from dune.perftool.pdelab.geometry import name_geometry from dune.perftool.pdelab.geometry import name_geometry
...@@ -97,7 +84,8 @@ def fill_quadrature_points_cache(name): ...@@ -97,7 +84,8 @@ def fill_quadrature_points_cache(name):
@class_member(classtag="operator") @class_member(classtag="operator")
def typedef_quadrature_points(name): def typedef_quadrature_points(name):
range_field = lop_template_range_field() range_field = lop_template_range_field()
dim = _local_dim() from dune.perftool.pdelab.geometry import local_dimension
dim = local_dimension()
return "using {} = typename Dune::QuadraturePoint<{}, {}>::Vector;".format(name, range_field, dim) return "using {} = typename Dune::QuadraturePoint<{}, {}>::Vector;".format(name, range_field, dim)
...@@ -115,7 +103,8 @@ def define_quadrature_points(name): ...@@ -115,7 +103,8 @@ def define_quadrature_points(name):
def name_quadrature_points(): def name_quadrature_points():
"""Name of vector storing quadrature points as class member""" """Name of vector storing quadrature points as class member"""
dim = _local_dim() from dune.perftool.pdelab.geometry import local_dimension
dim = local_dimension()
name = "qp_order" + str(dim) name = "qp_order" + str(dim)
shape = (name_quadrature_bound(), dim) shape = (name_quadrature_bound(), dim)
globalarg(name, shape=shape, dtype=numpy.float64, managed=False) globalarg(name, shape=shape, dtype=numpy.float64, managed=False)
...@@ -150,7 +139,8 @@ def fill_quadrature_weights_cache(name): ...@@ -150,7 +139,8 @@ def fill_quadrature_weights_cache(name):
@class_member(classtag="operator") @class_member(classtag="operator")
def typedef_quadrature_weights(name): def typedef_quadrature_weights(name):
range_field = lop_template_range_field() range_field = lop_template_range_field()
dim = _local_dim() from dune.perftool.pdelab.geometry import local_dimension
dim = local_dimension()
return "using {} = typename Dune::QuadraturePoint<{}, {}>::Field;".format(name, range_field, dim) return "using {} = typename Dune::QuadraturePoint<{}, {}>::Field;".format(name, range_field, dim)
...@@ -174,7 +164,8 @@ def define_quadrature_weights(name): ...@@ -174,7 +164,8 @@ def define_quadrature_weights(name):
def name_quadrature_weights(): def name_quadrature_weights():
""""Name of vector storing quadrature weights as class member""" """"Name of vector storing quadrature weights as class member"""
dim = _local_dim() from dune.perftool.pdelab.geometry import local_dimension
dim = local_dimension()
name = "qw_order" + str(dim) name = "qw_order" + str(dim)
define_quadrature_weights(name) define_quadrature_weights(name)
fill_quadrature_weights_cache(name) fill_quadrature_weights_cache(name)
......
...@@ -235,8 +235,6 @@ def define_theta(name, shape, transpose, derivative, additional_indices=()): ...@@ -235,8 +235,6 @@ def define_theta(name, shape, transpose, derivative, additional_indices=()):
potentially_vectorized=True, potentially_vectorized=True,
) )
# TODO Enforce the alignment here!
i = theta_iname("i", shape[0]) i = theta_iname("i", shape[0])
j = theta_iname("j", shape[1]) j = theta_iname("j", shape[1])
......
...@@ -25,6 +25,7 @@ from dune.perftool.sumfact.sumfact import (setup_theta, ...@@ -25,6 +25,7 @@ from dune.perftool.sumfact.sumfact import (setup_theta,
sum_factorization_kernel, sum_factorization_kernel,
) )
from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.pdelab.geometry import world_dimension
from dune.perftool.loopy.buffer import initialize_buffer from dune.perftool.loopy.buffer import initialize_buffer
from dune.perftool.pdelab.driver import FEM_name_mangling from dune.perftool.pdelab.driver import FEM_name_mangling
from dune.perftool.pdelab.restriction import restricted_name from dune.perftool.pdelab.restriction import restricted_name
...@@ -52,16 +53,11 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): ...@@ -52,16 +53,11 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
from ufl.functionview import select_subelement from ufl.functionview import select_subelement
sub_element = select_subelement(element, component) sub_element = select_subelement(element, component)
rank = len(sub_element.value_shape()) + 1 rank = len(sub_element.value_shape()) + 1
shape = sub_element.value_shape() + (element.cell().geometric_dimension(),) shape = sub_element.value_shape() + (world_dimension(),)
shape_impl = ('arr',) * rank shape_impl = ('arr',) * rank
temporary_variable(name, shape=shape, shape_impl=shape_impl) temporary_variable(name, shape=shape, shape_impl=shape_impl)
# TODO: dim = world_dimension()
# - This only covers rank 1
# - Avoid setting up whole gradient if only one component is needed?
# Get geometric dimension
formdata = get_global_context_value('formdata')
dim = formdata.geometric_dimension
buffers = [] buffers = []
insn_dep = None insn_dep = None
for i in range(dim): for i in range(dim):
...@@ -127,8 +123,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): ...@@ -127,8 +123,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
@kernel_cached @kernel_cached
def pymbolic_trialfunction(element, restriction, component, visitor): def pymbolic_trialfunction(element, restriction, component, visitor):
# Get geometric dimension # Get geometric dimension
formdata = get_global_context_value('formdata') dim = world_dimension()
dim = formdata.geometric_dimension
# Construct the matrix sequence for this sum factorization # Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence() a_matrices = construct_amatrix_sequence()
...@@ -180,8 +175,7 @@ def sumfact_lfs_iname(bound, dim): ...@@ -180,8 +175,7 @@ def sumfact_lfs_iname(bound, dim):
@backend(interface="lfs_inames", name="sumfact") @backend(interface="lfs_inames", name="sumfact")
def lfs_inames(element, restriction, number=1, context=''): def lfs_inames(element, restriction, number=1, context=''):
assert number == 1 assert number == 1
formdata = get_global_context_value('formdata') dim = world_dimension()
dim = formdata.geometric_dimension
return tuple(sumfact_lfs_iname(basis_functions_per_direction(), d) for d in range(dim)) return tuple(sumfact_lfs_iname(basis_functions_per_direction(), d) for d in range(dim))
...@@ -220,10 +214,8 @@ def pymbolic_basis(element, restriction, number): ...@@ -220,10 +214,8 @@ def pymbolic_basis(element, restriction, number):
@backend(interface="evaluate_grad") @backend(interface="evaluate_grad")
@kernel_cached @kernel_cached
def evaluate_reference_gradient(element, name, restriction): def evaluate_reference_gradient(element, name, restriction):
from dune.perftool.pdelab.geometry import name_dimension dim = world_dimension()
temporary_variable( temporary_variable(name, shape=(dim,))
name,
shape=(name_dimension(),))
quad_inames = quadrature_inames() quad_inames = quadrature_inames()
inames = lfs_inames(element, restriction) inames = lfs_inames(element, restriction)
assert(len(quad_inames) == len(inames)) assert(len(quad_inames) == len(inames))
...@@ -232,10 +224,6 @@ def evaluate_reference_gradient(element, name, restriction): ...@@ -232,10 +224,6 @@ def evaluate_reference_gradient(element, name, restriction):
theta = name_theta() theta = name_theta()
dtheta = name_theta(derivative=True) dtheta = name_theta(derivative=True)
# Get geometric dimension
formdata = get_global_context_value('formdata')
dim = formdata.geometric_dimension
for i in range(dim): for i in range(dim):
calls = [prim.Subscript(prim.Variable(theta), (prim.Variable(m), prim.Variable(n))) calls = [prim.Subscript(prim.Variable(theta), (prim.Variable(m), prim.Variable(n)))
for (m, n) in zip(quad_inames, inames)] for (m, n) in zip(quad_inames, inames)]
......
...@@ -12,7 +12,9 @@ from dune.perftool.sumfact.amatrix import (quadrature_points_per_direction, ...@@ -12,7 +12,9 @@ from dune.perftool.sumfact.amatrix import (quadrature_points_per_direction,
name_oned_quadrature_weights, name_oned_quadrature_weights,
) )
from dune.perftool.pdelab.argument import name_accumulation_variable from dune.perftool.pdelab.argument import name_accumulation_variable
from dune.perftool.pdelab.geometry import dimension_iname from dune.perftool.pdelab.geometry import (dimension_iname,
local_dimension,
)
from loopy import CallMangleInfo from loopy import CallMangleInfo
from loopy.symbolic import FunctionIdentifier from loopy.symbolic import FunctionIdentifier
...@@ -77,9 +79,7 @@ def sumfact_quad_iname(d, context): ...@@ -77,9 +79,7 @@ def sumfact_quad_iname(d, context):
@backend(interface="quad_inames", name="sumfact") @backend(interface="quad_inames", name="sumfact")
def quadrature_inames(context=''): def quadrature_inames(context=''):
formdata = get_global_context_value('formdata') return tuple(sumfact_quad_iname(d, context) for d in range(local_dimension()))
dim = formdata.geometric_dimension
return tuple(sumfact_quad_iname(d, context) for d in range(dim))
def define_recursive_quadrature_weight(name, dir): def define_recursive_quadrature_weight(name, dir):
...@@ -98,9 +98,7 @@ def define_recursive_quadrature_weight(name, dir): ...@@ -98,9 +98,7 @@ def define_recursive_quadrature_weight(name, dir):
def recursive_quadrature_weight(dir=0): def recursive_quadrature_weight(dir=0):
formdata = get_global_context_value('formdata') if dir == local_dimension():
dim = formdata.geometric_dimension
if dir == dim:
return pymbolic_base_weight() return pymbolic_base_weight()
else: else:
name = 'weight_{}'.format(dir) name = 'weight_{}'.format(dir)
...@@ -113,9 +111,7 @@ def quadrature_weight(): ...@@ -113,9 +111,7 @@ def quadrature_weight():
def define_quadrature_position(name): def define_quadrature_position(name):
formdata = get_global_context_value('formdata') for i in range(local_dimension()):
dim = formdata.geometric_dimension
for i in range(dim):
instruction(expression=Subscript(Variable(name_oned_quadrature_points()), (Variable(quadrature_inames()[i]),)), instruction(expression=Subscript(Variable(name_oned_quadrature_points()), (Variable(quadrature_inames()[i]),)),
assignee=Subscript(Variable(name), (i,)), assignee=Subscript(Variable(name), (i,)),
forced_iname_deps=frozenset(quadrature_inames()), forced_iname_deps=frozenset(quadrature_inames()),
...@@ -126,10 +122,8 @@ def define_quadrature_position(name): ...@@ -126,10 +122,8 @@ def define_quadrature_position(name):
@backend(interface="quad_pos", name="sumfact") @backend(interface="quad_pos", name="sumfact")
def pymbolic_quadrature_position(): def pymbolic_quadrature_position():
formdata = get_global_context_value('formdata')
dim = formdata.geometric_dimension
name = 'pos' name = 'pos'
temporary_variable(name, shape=(dim,), shape_impl=("fv",)) temporary_variable(name, shape=(local_dimension(),), shape_impl=("fv",))
define_quadrature_position(name) define_quadrature_position(name)
return Variable(name) return Variable(name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment