diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index aa4b0ddec4ef9b96049adb8d2e93fb6adde8b1b5..1c1823df4c486e657ac1e960cc0b258fe2cb6d33 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -19,7 +19,6 @@ from dune.perftool.generation import (backend, from dune.perftool.sumfact.tabulation import (basis_functions_per_direction, construct_basis_matrix_sequence, BasisTabulationMatrix, - quadrature_points_per_direction, PolynomialLookup, name_polynomials, polynomial_degree, @@ -242,7 +241,8 @@ def evaluate_basis(element, name, restriction): help_index = 0 for direction in range(len(inames)): if direction != facedir: - prod = prod + (BasisTabulationMatrix(basis_size=basis_per_dir[direction]).pymbolic((prim.Variable(quad_inames[help_index]), prim.Variable(inames[direction]))),) + prod = prod + (BasisTabulationMatrix(basis_size=basis_per_dir[direction], + direction=direction).pymbolic((prim.Variable(quad_inames[help_index]), prim.Variable(inames[direction]))),) help_index += 1 # Add the missing direction on facedirs by evaluating at either 0 or 1 @@ -295,6 +295,7 @@ def evaluate_reference_gradient(element, name, restriction, index): for i in range(dim): if i != facedir: tab = BasisTabulationMatrix(basis_size=_basis_functions_per_direction(element)[i], + direction=i, derivative=index == i) prod.append(tab.pymbolic((prim.Variable(quadinamemapping[i]), prim.Variable(inames[i])))) diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py index 95f383800ac5aa3276ce37830b153c4cfd23f6a1..77bb00385b5c83b0add2326b1db6d03488ce50bf 100644 --- a/python/dune/perftool/sumfact/geometry.py +++ b/python/dune/perftool/sumfact/geometry.py @@ -80,7 +80,7 @@ def pymbolic_spatial_coordinate_multilinear(do_predicates, visitor): # dim-1 matrices! from dune.perftool.sumfact.tabulation import quadrature_points_per_direction, BasisTabulationMatrix quadrature_size = quadrature_points_per_direction() - matrix_sequence = (BasisTabulationMatrix(quadrature_size=quadrature_size, basis_size=2),) * local_dimension() + matrix_sequence = tuple(BasisTabulationMatrix(direction=i, basis_size=2) for i in range(world_dimension()) if i != get_facedir(visitor.restriction)) inp = GeoCornersInput(visitor.indices[0]) from dune.perftool.sumfact.symbolic import SumfactKernel diff --git a/python/dune/perftool/sumfact/tabulation.py b/python/dune/perftool/sumfact/tabulation.py index e8518668c227c7528ab72712bbd9e2cba92c147f..976a2f4b7f9bc066c3293f70e2e6d3bcd66fca3b 100644 --- a/python/dune/perftool/sumfact/tabulation.py +++ b/python/dune/perftool/sumfact/tabulation.py @@ -7,6 +7,7 @@ from dune.perftool.pdelab.geometry import world_dimension, local_dimension from dune.perftool.generation import (class_member, domain, function_mangler, + generator_factory, get_global_context_value, iname, include_file, @@ -43,27 +44,21 @@ class BasisTabulationMatrixBase(object): class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord): def __init__(self, - quadrature_size=None, basis_size=None, transpose=False, derivative=False, face=None, + direction=None, slice_size=None, slice_index=None, ): assert(isinstance(basis_size, int)) - if quadrature_size is None: - quadrature_size = quadrature_points_per_direction() - assert(qp == quadrature_size[0] for qp in quadrature_size) - quadrature_size = quadrature_size[0] - if slice_size is not None: - quadrature_size = ceildiv(quadrature_size, slice_size) ImmutableRecord.__init__(self, - quadrature_size=quadrature_size, basis_size=basis_size, transpose=transpose, derivative=derivative, face=face, + direction=direction, slice_size=slice_size, slice_index=slice_index, ) @@ -82,6 +77,15 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord): else: return self.basis_size + @property + def quadrature_size(self): + size = quadrature_points_per_direction() + size = size[self.direction] + if self.slice_size is not None: + size = ceildiv(size, self.slice_size) + return size + + def pymbolic(self, indices): name = "{}{}Theta{}{}_qp{}_dof{}" \ .format("face{}_".format(self.face) if self.face is not None else "", @@ -201,7 +205,17 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase): return True +@generator_factory(context_tags=("kernel",), + cache_key_generator=lambda q: 0 if q is None else 1) +def set_quadrature_points_per_direction(quad): + return quad + + def quadrature_points_per_direction(): + custom = set_quadrature_points_per_direction(None) + if custom is not None: + return custom + # Quadrature order per direction q = quadrature_order() if isinstance(q, int): @@ -403,7 +417,7 @@ def construct_basis_matrix_sequence(transpose=False, derivative=None, facedir=No if facedir == i: onface = facemod - result[i] = BasisTabulationMatrix(quadrature_size=quadrature_size[i], + result[i] = BasisTabulationMatrix(direction=i, basis_size=basis_size[i], transpose=transpose, derivative=derivative == i,