diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 7274e58a73c36eb1ebf28cffcf9b289fc0e88a77..0965bdce0bef66127b56e144085f282180ead61c 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -26,6 +26,9 @@ from dune.perftool.sumfact.sumfact import (get_facedir,
                                            sum_factorization_kernel,
                                            )
 from dune.perftool.sumfact.quadrature import quadrature_inames
+from dune.perftool.sumfact.switch import (get_facedir,
+                                          get_facemod,
+                                          )
 from dune.perftool.pdelab.geometry import world_dimension
 from dune.perftool.loopy.buffer import initialize_buffer
 from dune.perftool.pdelab.driver import FEM_name_mangling
@@ -188,14 +191,26 @@ def evaluate_basis(element, name, restriction):
     theta = name_theta()
     quad_inames = quadrature_inames()
     inames = lfs_inames(element, restriction)
-    assert(len(quad_inames) == len(inames))
-
-    instruction(expression=prim.Product(tuple(prim.Subscript(prim.Variable(theta),
-                                                             (prim.Variable(i), prim.Variable(j))
-                                                             )
-                                              for (i, j) in zip(quad_inames, inames)
-                                              )
-                                        ),
+    facedir = get_facedir(restriction)
+
+    # Collect the pairs of lfs/quad inames that are in use
+    # On facets, the normal direction of the facet is excluded
+    prod = tuple(prim.Subscript(prim.Variable(theta),
+                                (prim.Variable(i), prim.Variable(j))
+                                )
+                 for (i, j) in zip(quad_inames, tuple(iname for i, iname in enumerate(inames) if i != facedir))
+                 )
+
+    # Add the missing direction on facedirs by evaluating at either 0 or 1
+    # facedir indexes into inames, so 0 is a valid direction: compare against None
+    if facedir is not None:
+        facemod = get_facemod(restriction)
+        from dune.perftool.sumfact.amatrix import PolynomialLookup, name_polynomials
+        prod = prod + (prim.Call(PolynomialLookup(name_polynomials(), False),
+                                 (prim.Variable(inames[facedir]), facemod)),)
+
+    # Issue the product
+    instruction(expression=prim.Product(prod),
                 assignee=prim.Variable(name),
                 forced_iname_deps=frozenset(quad_inames + inames),
                 forced_iname_deps_is_final=True,
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index a437781717308f4b8d202760f8284817172ce9eb..ed21964794496bc431ac99c3886672bdfe57d6e1 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -43,6 +43,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
                                            basis_functions_per_direction,
                                            construct_amatrix_sequence,
                                            )
+from dune.perftool.sumfact.switch import get_facedir
 from dune.perftool.loopy.symbolic import SumfactKernel
 from dune.perftool.tools import get_pymbolic_basename
 from dune.perftool.error import PerftoolError
@@ -62,17 +63,6 @@ import pymbolic.primitives as prim
 from pytools import product
 
 
-def get_facedir(restriction):
-    from dune.perftool.pdelab.restriction import Restriction
-    if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet":
-        return get_global_context_value("facedir_s")
-    if restriction == Restriction.POSITIVE:
-        return get_global_context_value("facedir_n")
-    if restriction == Restriction.NONE:
-        return None
-    assert False
-
-
 @iname
 def _sumfact_iname(bound, _type, count):
     name = "sf_{}_{}".format(_type, str(count))
diff --git a/python/dune/perftool/sumfact/switch.py b/python/dune/perftool/sumfact/switch.py
index 78b31590f8ea3656b19254ec77dde955c4607982..d46d5d7599cf05df323112464f35a66e5a39b781 100644
--- a/python/dune/perftool/sumfact/switch.py
+++ b/python/dune/perftool/sumfact/switch.py
@@ -105,3 +105,35 @@ def generate_interior_facet_switch():
     block.append("}")
 
     return ClassMember(signature + block)
+
+
+def get_facedir(restriction):
+    """Return the "facedir" global context value matching a restriction.
+
+    Negative restrictions and exterior facet integrals read "facedir_s",
+    positive restrictions read "facedir_n" and Restriction.NONE gives None.
+    """
+    from dune.perftool.pdelab.restriction import Restriction
+    if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet":
+        return get_global_context_value("facedir_s")
+    if restriction == Restriction.POSITIVE:
+        return get_global_context_value("facedir_n")
+    if restriction == Restriction.NONE:
+        return None
+    assert False
+
+
+def get_facemod(restriction):
+    """Return the "facemod" global context value matching a restriction.
+
+    Negative restrictions and exterior facet integrals read "facemod_s",
+    positive restrictions read "facemod_n" and Restriction.NONE gives None.
+    """
+    from dune.perftool.pdelab.restriction import Restriction
+    if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet":
+        return get_global_context_value("facemod_s")
+    if restriction == Restriction.POSITIVE:
+        return get_global_context_value("facemod_n")
+    if restriction == Restriction.NONE:
+        return None
+    assert False