From 4b137861b5e7562180cde04fa3ce3fb321563ef2 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Thu, 8 Dec 2016 18:20:03 +0100 Subject: [PATCH] Implement the evaluation of basis for jacobians --- python/dune/perftool/sumfact/basis.py | 30 ++++++++++++++++++------- python/dune/perftool/sumfact/sumfact.py | 12 +--------- python/dune/perftool/sumfact/switch.py | 22 ++++++++++++++++++ 3 files changed, 45 insertions(+), 19 deletions(-) diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 7274e58a..0965bdce 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -26,6 +26,9 @@ from dune.perftool.sumfact.sumfact import (get_facedir, sum_factorization_kernel, ) from dune.perftool.sumfact.quadrature import quadrature_inames +from dune.perftool.sumfact.switch import (get_facedir, + get_facemod, + ) from dune.perftool.pdelab.geometry import world_dimension from dune.perftool.loopy.buffer import initialize_buffer from dune.perftool.pdelab.driver import FEM_name_mangling @@ -188,14 +191,25 @@ def evaluate_basis(element, name, restriction): theta = name_theta() quad_inames = quadrature_inames() inames = lfs_inames(element, restriction) - assert(len(quad_inames) == len(inames)) - - instruction(expression=prim.Product(tuple(prim.Subscript(prim.Variable(theta), - (prim.Variable(i), prim.Variable(j)) - ) - for (i, j) in zip(quad_inames, inames) - ) - ), + facedir = get_facedir(restriction) + + # Collect the pairs of lfs/quad inames that are in use + # On facets, the normal direction of the facet is excluded + prod = tuple(prim.Subscript(prim.Variable(theta), + (prim.Variable(i), prim.Variable(j)) + ) + for (i, j) in zip(quad_inames, tuple(iname for i, iname in enumerate(inames) if i != facedir)) + ) + + # Add the missing direction on facedirs by evaluating at either 0 or 1 + if facedir: + facemod = get_facemod(restriction) + from dune.perftool.sumfact.amatrix import PolynomialLookup, name_polynomials + prod = prod + (prim.Call(PolynomialLookup(name_polynomials(), False), + (prim.Variable(inames[facedir]), facemod)),) + + # Issue the product + instruction(expression=prim.Product(prod), assignee=prim.Variable(name), forced_iname_deps=frozenset(quad_inames + inames), forced_iname_deps_is_final=True, diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index a4377817..ed219647 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -43,6 +43,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix, basis_functions_per_direction, construct_amatrix_sequence, ) +from dune.perftool.sumfact.switch import get_facedir from dune.perftool.loopy.symbolic import SumfactKernel from dune.perftool.tools import get_pymbolic_basename from dune.perftool.error import PerftoolError @@ -62,17 +63,6 @@ import pymbolic.primitives as prim from pytools import product -def get_facedir(restriction): - from dune.perftool.pdelab.restriction import Restriction - if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet": - return get_global_context_value("facedir_s") - if restriction == Restriction.POSITIVE: - return get_global_context_value("facedir_n") - if restriction == Restriction.NONE: - return None - assert False - - @iname def _sumfact_iname(bound, _type, count): name = "sf_{}_{}".format(_type, str(count)) diff --git a/python/dune/perftool/sumfact/switch.py b/python/dune/perftool/sumfact/switch.py index 78b31590..d46d5d75 100644 --- a/python/dune/perftool/sumfact/switch.py +++ b/python/dune/perftool/sumfact/switch.py @@ -105,3 +105,25 @@ def generate_interior_facet_switch(): block.append("}") return ClassMember(signature + block) + + +def get_facedir(restriction): + from dune.perftool.pdelab.restriction import Restriction + if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet": + return get_global_context_value("facedir_s") + if restriction == Restriction.POSITIVE: + return get_global_context_value("facedir_n") + if restriction == Restriction.NONE: + return None + assert False + + +def get_facemod(restriction): + from dune.perftool.pdelab.restriction import Restriction + if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet": + return get_global_context_value("facemod_s") + if restriction == Restriction.POSITIVE: + return get_global_context_value("facemod_n") + if restriction == Restriction.NONE: + return None + assert False -- GitLab