diff --git a/python/dune/codegen/pdelab/localoperator.py b/python/dune/codegen/pdelab/localoperator.py index eb7e1d3a50e7f8ba4b7b0f476c54565925c15d2a..11474dbf3489fae9fc85e944c9289322d8b3d3d7 100644 --- a/python/dune/codegen/pdelab/localoperator.py +++ b/python/dune/codegen/pdelab/localoperator.py @@ -495,7 +495,8 @@ def visit_integral(integral): # Start the visiting process! visitor = get_visitor(measure, subdomain_id) - visitor.accumulate(integrand) + with global_context(visitor=visitor): + visitor.accumulate(integrand) run_hook(name="after_visit", args=(visitor,), diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py index ae4a763063f0dc25303fc517356b58b02551eebc..9394f746ec9cd7bafa569ccfbcb83d9a8e44663f 100644 --- a/python/dune/codegen/sumfact/accumulation.py +++ b/python/dune/codegen/sumfact/accumulation.py @@ -39,9 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward, from dune.codegen.sumfact.tabulation import (basis_functions_per_direction, construct_basis_matrix_sequence, ) -from dune.codegen.sumfact.switch import (get_facedir, - get_facemod, - ) from dune.codegen.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase from dune.codegen.ufl.modified_terminals import extract_modified_arguments from dune.codegen.tools import get_pymbolic_basename, get_leaf, ImmutableCuttingRecord @@ -355,6 +352,22 @@ class SumfactAccumulationMixin(AccumulationMixinBase): def generate_accumulation_instruction(self, expr): return generate_accumulation_instruction(expr, self) + def get_facedir(self, restriction): + from dune.codegen.pdelab.restriction import Restriction + if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet": + return get_global_context_value("facedir_s") + if restriction == Restriction.NEGATIVE: + return get_global_context_value("facedir_n") + return None + + def get_facemod(self, restriction): + from dune.codegen.pdelab.restriction import Restriction + if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet": + return get_global_context_value("facemod_s") + if restriction == Restriction.NEGATIVE: + return get_global_context_value("facemod_n") + return None + class SumfactAccumulationInfo(ImmutableRecord): def __init__(self, @@ -520,8 +533,8 @@ def generate_accumulation_instruction(expr, visitor): matrix_sequence = construct_basis_matrix_sequence( transpose=True, derivative=test_info.grad_index, - facedir=get_facedir(test_info.restriction), - facemod=get_facemod(test_info.restriction), + facedir=visitor.get_facedir(test_info.restriction), + facemod=visitor.get_facemod(test_info.restriction), basis_size=basis_size) jacobian_inames = trial_info.inames diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py index c9b75eb445af01e9420e850650b12054dd0115f3..8be5fc9e4a5157fee7f597b3000282238d85dd56 100644 --- a/python/dune/codegen/sumfact/basis.py +++ b/python/dune/codegen/sumfact/basis.py @@ -31,9 +31,6 @@ from dune.codegen.sumfact.permutation import (permute_backward, sumfact_cost_permutation_strategy, sumfact_quadrature_permutation_strategy, ) -from dune.codegen.sumfact.switch import (get_facedir, - get_facemod, - ) from dune.codegen.pdelab.argument import name_coefficientcontainer, name_applycontainer from dune.codegen.pdelab.basis import GenericBasisMixin from dune.codegen.pdelab.geometry import (local_dimension, @@ -86,7 +83,7 @@ class SumfactBasisMixin(GenericBasisMixin): temporary_variable(name, shape=()) quad_inames = self.quadrature_inames() inames = self.lfs_inames(element, restriction) - facedir = get_facedir(restriction) + facedir = self.get_facedir(restriction) # Collect the pairs of lfs/quad inames that are in use # On facets, the normal direction of the facet is excluded @@ -106,7 +103,7 @@ class SumfactBasisMixin(GenericBasisMixin): # Add the missing direction on facedirs by evaluating at either 0 or 1 if facedir is not None: - facemod = get_facemod(restriction) + facemod = self.get_facemod(restriction) prod = prod + (prim.Call(PolynomialLookup(name_polynomials(element.degree()), False), (prim.Variable(inames[facedir]), facemod)),) @@ -141,7 +138,7 @@ class SumfactBasisMixin(GenericBasisMixin): temporary_variable(name, shape=()) quad_inames = self.quadrature_inames() inames = self.lfs_inames(element, restriction) - facedir = get_facedir(restriction) + facedir = self.get_facedir(restriction) # Map the direction to a quadrature iname quadinamemapping = {} @@ -161,7 +158,7 @@ class SumfactBasisMixin(GenericBasisMixin): prod.append(tab.pymbolic((prim.Variable(quadinamemapping[i]), prim.Variable(inames[i])))) if facedir is not None: - facemod = get_facemod(restriction) + facemod = self.get_facemod(restriction) prod.append(prim.Call(PolynomialLookup(name_polynomials(element.degree()), index == facedir), (prim.Variable(inames[facedir]), facemod)),) @@ -197,8 +194,8 @@ class SumfactBasisMixin(GenericBasisMixin): # Construct the matrix sequence for this sum factorization matrix_sequence = construct_basis_matrix_sequence(derivative=derivative, - facedir=get_facedir(restriction), - facemod=get_facemod(restriction), + facedir=self.get_facedir(restriction), + facemod=self.get_facemod(restriction), basis_size=basis_size) inp = LFSSumfactKernelInput(matrix_sequence=matrix_sequence, diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py index d37558f5c40f42355b5fa99ed9a4ea25ab7ade28..14a095696de75f6e169d5772949d1ed4707c2c7b 100644 --- a/python/dune/codegen/sumfact/geometry.py +++ b/python/dune/codegen/sumfact/geometry.py @@ -39,7 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward, sumfact_quadrature_permutation_strategy, ) from dune.codegen.sumfact.quadrature import additional_inames -from dune.codegen.sumfact.switch import get_facedir, get_facemod from dune.codegen.sumfact.symbolic import SumfactKernelInterfaceBase, SumfactKernel from dune.codegen.tools import get_pymbolic_basename, ImmutableCuttingRecord from dune.codegen.options import get_form_option, option_switch @@ -162,8 +161,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin): def outer_normal(self): """ This is the *unnormalized* outer normal """ name = "outer_normal" - facedir_s = get_facedir(Restriction.POSITIVE) - facemod_s = get_facemod(Restriction.POSITIVE) + facedir_s = self.get_facedir(Restriction.POSITIVE) + facemod_s = self.get_facemod(Restriction.POSITIVE) temporary_variable(name, shape=(world_dimension(),)) for i in range(world_dimension()): @@ -210,8 +209,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin): restriction = enforce_boundary_restriction(self) # Generate sum factorization kernel and add vectorization info - matrix_sequence = construct_basis_matrix_sequence(facedir=get_facedir(restriction), - facemod=get_facemod(restriction), + matrix_sequence = construct_basis_matrix_sequence(facedir=self.get_facedir(restriction), + facemod=self.get_facemod(restriction), basis_size=(2,) * world_dimension()) inp = GeoCornersInput(matrix_sequence=matrix_sequence, direction=self.indices[0], @@ -253,8 +252,8 @@ class SumfactAxiParallelGeometryMixin(AxiparallelGeometryMixin): assert isinstance(i, int) # Use facemod_s and facedir_s - if i == get_facedir(Restriction.POSITIVE): - if get_facemod(Restriction.POSITIVE): + if i == self.get_facedir(Restriction.POSITIVE): + if self.get_facemod(Restriction.POSITIVE): return 1 else: return -1 @@ -270,7 +269,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle def facet_jacobian_determinant(self, o): name = "fdetjac" self.define_facet_jacobian_determinant(name) - facedir = get_facedir(Restriction.POSITIVE) + facedir = self.get_facedir(Restriction.POSITIVE) globalarg(name, shape=(world_dimension(),)) return prim.Subscript(prim.Variable(name), (facedir,)) @@ -304,8 +303,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle restriction = Restriction.NONE if self.measure == "interior_facet": restriction = Restriction.POSITIVE - from dune.codegen.sumfact.switch import get_facedir - face = get_facedir(restriction) + face = self.get_facedir(restriction) lowcorner = name_lowerleft_corner() meshwidth = name_meshwidth() @@ -527,8 +525,8 @@ def _name_jacobian(i, j, restriction, visitor): """ # Create matrix sequence with derivative in j direction matrix_sequence = construct_basis_matrix_sequence(derivative=j, - facedir=get_facedir(restriction), - facemod=get_facemod(restriction), + facedir=visitor.get_facedir(restriction), + facemod=visitor.get_facemod(restriction), basis_size=(2,) * world_dimension()) # Sum factorization input for the i'th component of the geometry mapping diff --git a/python/dune/codegen/sumfact/permutation.py b/python/dune/codegen/sumfact/permutation.py index 7f37dfeae795031ec81b292b96e65649f8fc78cc..a7b82eac88ce54b21700dc08514779fc08ba6272 100644 --- a/python/dune/codegen/sumfact/permutation.py +++ b/python/dune/codegen/sumfact/permutation.py @@ -3,7 +3,6 @@ import itertools from dune.codegen.options import get_option -from dune.codegen.sumfact.switch import get_facedir, get_facemod from dune.codegen.ufl.modified_terminals import Restriction @@ -125,8 +124,9 @@ def sumfact_quadrature_permutation_strategy(dim, restriction): # all others can be derived by rotating the cube and matching edge # directions. def _order_on_self(restriction): - facedir = get_facedir(restriction) - facemod = get_facemod(restriction) + visitor = get_global_context_value("visitor") + facedir = visitor.get_facedir(restriction) + facemod = visitor.get_facemod(restriction) quadrature_order = { (0, 0): (0, 1, 2), diff --git a/python/dune/codegen/sumfact/quadrature.py b/python/dune/codegen/sumfact/quadrature.py index 91e99c4c3c0cd333d6e64353ed06ac0c129a7e05..ae387426ebda79a8435811c548e91761f2de1b3c 100644 --- a/python/dune/codegen/sumfact/quadrature.py +++ b/python/dune/codegen/sumfact/quadrature.py @@ -10,7 +10,6 @@ from dune.codegen.generation import (domain, quadrature_mixin, temporary_variable, ) -from dune.codegen.sumfact.switch import get_facedir from dune.codegen.sumfact.tabulation import (quadrature_points_per_direction, local_quadrature_points_per_direction, name_oned_quadrature_points, @@ -22,7 +21,6 @@ from dune.codegen.pdelab.geometry import (local_dimension, ) from dune.codegen.pdelab.quadrature import GenericQuadratureMixin from dune.codegen.options import get_form_option -from dune.codegen.sumfact.switch import get_facedir from dune.codegen.loopy.target import dtype_floatingpoint from loopy import CallMangleInfo diff --git a/python/dune/codegen/sumfact/switch.py b/python/dune/codegen/sumfact/switch.py index a031790dd5321841eb8c100195cde675651fe1d3..65d35c17e4a8d2753e680c8c440fee6e9f4c6adf 100644 --- a/python/dune/codegen/sumfact/switch.py +++ b/python/dune/codegen/sumfact/switch.py @@ -169,25 +169,3 @@ def generate_interior_facet_switch(): block.append("}") return ClassMember(signature + block) - - -def get_facedir(restriction): - from dune.codegen.pdelab.restriction import Restriction - if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet": - return get_global_context_value("facedir_s") - if restriction == Restriction.NEGATIVE: - return get_global_context_value("facedir_n") - if restriction == Restriction.NONE: - return None - assert False - - -def get_facemod(restriction): - from dune.codegen.pdelab.restriction import Restriction - if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet": - return get_global_context_value("facemod_s") - if restriction == Restriction.NEGATIVE: - return get_global_context_value("facemod_n") - if restriction == Restriction.NONE: - return None - assert False