Skip to content
Snippets Groups Projects
Commit 5acbcba0 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

WIP

parent c9f510b5
No related branches found
No related tags found
No related merge requests found
......@@ -19,7 +19,6 @@ from dune.perftool.generation import (backend,
from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
construct_basis_matrix_sequence,
BasisTabulationMatrix,
quadrature_points_per_direction,
PolynomialLookup,
name_polynomials,
polynomial_degree,
......@@ -242,7 +241,8 @@ def evaluate_basis(element, name, restriction):
help_index = 0
for direction in range(len(inames)):
if direction != facedir:
prod = prod + (BasisTabulationMatrix(basis_size=basis_per_dir[direction]).pymbolic((prim.Variable(quad_inames[help_index]), prim.Variable(inames[direction]))),)
prod = prod + (BasisTabulationMatrix(basis_size=basis_per_dir[direction],
direction=direction).pymbolic((prim.Variable(quad_inames[help_index]), prim.Variable(inames[direction]))),)
help_index += 1
# Add the missing direction on facedirs by evaluating at either 0 or 1
......@@ -295,6 +295,7 @@ def evaluate_reference_gradient(element, name, restriction, index):
for i in range(dim):
if i != facedir:
tab = BasisTabulationMatrix(basis_size=_basis_functions_per_direction(element)[i],
direction=i,
derivative=index == i)
prod.append(tab.pymbolic((prim.Variable(quadinamemapping[i]), prim.Variable(inames[i]))))
......
......@@ -80,7 +80,7 @@ def pymbolic_spatial_coordinate_multilinear(do_predicates, visitor):
# dim-1 matrices!
from dune.perftool.sumfact.tabulation import quadrature_points_per_direction, BasisTabulationMatrix
quadrature_size = quadrature_points_per_direction()
matrix_sequence = (BasisTabulationMatrix(quadrature_size=quadrature_size, basis_size=2),) * local_dimension()
matrix_sequence = tuple(BasisTabulationMatrix(direction=i, basis_size=2) for i in range(world_dimension()) if i != get_facedir(visitor.restriction))
inp = GeoCornersInput(visitor.indices[0])
from dune.perftool.sumfact.symbolic import SumfactKernel
......
......@@ -7,6 +7,7 @@ from dune.perftool.pdelab.geometry import world_dimension, local_dimension
from dune.perftool.generation import (class_member,
domain,
function_mangler,
generator_factory,
get_global_context_value,
iname,
include_file,
......@@ -43,27 +44,21 @@ class BasisTabulationMatrixBase(object):
class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord):
def __init__(self,
quadrature_size=None,
basis_size=None,
transpose=False,
derivative=False,
face=None,
direction=None,
slice_size=None,
slice_index=None,
):
assert(isinstance(basis_size, int))
if quadrature_size is None:
quadrature_size = quadrature_points_per_direction()
assert(qp == quadrature_size[0] for qp in quadrature_size)
quadrature_size = quadrature_size[0]
if slice_size is not None:
quadrature_size = ceildiv(quadrature_size, slice_size)
ImmutableRecord.__init__(self,
quadrature_size=quadrature_size,
basis_size=basis_size,
transpose=transpose,
derivative=derivative,
face=face,
direction=direction,
slice_size=slice_size,
slice_index=slice_index,
)
......@@ -82,6 +77,15 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord):
else:
return self.basis_size
@property
def quadrature_size(self):
size = quadrature_points_per_direction()
size = size[self.direction]
if self.slice_size is not None:
size = ceildiv(size, self.slice_size)
return size
def pymbolic(self, indices):
name = "{}{}Theta{}{}_qp{}_dof{}" \
.format("face{}_".format(self.face) if self.face is not None else "",
......@@ -201,7 +205,17 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase):
return True
@generator_factory(context_tags=("kernel",),
cache_key_generator=lambda q: 0 if q is None else 1)
def set_quadrature_points_per_direction(quad):
return quad
def quadrature_points_per_direction():
custom = set_quadrature_points_per_direction(None)
if custom is not None:
return custom
# Quadrature order per direction
q = quadrature_order()
if isinstance(q, int):
......@@ -403,7 +417,7 @@ def construct_basis_matrix_sequence(transpose=False, derivative=None, facedir=No
if facedir == i:
onface = facemod
result[i] = BasisTabulationMatrix(quadrature_size=quadrature_size[i],
result[i] = BasisTabulationMatrix(direction=i,
basis_size=basis_size[i],
transpose=transpose,
derivative=derivative == i,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment