diff --git a/python/dune/codegen/pdelab/localoperator.py b/python/dune/codegen/pdelab/localoperator.py index 152eadd5968ea9adc9673da118f6a25c03761906..171f9af707fb744bddf97d9faae9cdf52875043a 100644 --- a/python/dune/codegen/pdelab/localoperator.py +++ b/python/dune/codegen/pdelab/localoperator.py @@ -497,7 +497,8 @@ def visit_integral(integral): # Start the visiting process! visitor = get_visitor(measure, subdomain_id) - visitor.accumulate(integrand) + with global_context(visitor=visitor): + visitor.accumulate(integrand) run_hook(name="after_visit", args=(visitor,), diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py index 90fd01fd5b90b722c0a3ee194e1da82ed44f6eab..25be6f7eb80b365d8eead39c3cbb93c0e36c52bc 100644 --- a/python/dune/codegen/sumfact/accumulation.py +++ b/python/dune/codegen/sumfact/accumulation.py @@ -39,9 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward, from dune.codegen.sumfact.tabulation import (basis_functions_per_direction, construct_basis_matrix_sequence, ) -from dune.codegen.sumfact.switch import (get_facedir, - get_facemod, - ) from dune.codegen.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase from dune.codegen.ufl.modified_terminals import extract_modified_arguments from dune.codegen.tools import get_pymbolic_basename, get_leaf, ImmutableCuttingRecord @@ -361,6 +358,22 @@ class SumfactAccumulationMixin(AccumulationMixinBase): def generate_accumulation_instruction(self, expr): return generate_accumulation_instruction(expr, self) + def get_facedir(self, restriction): + from dune.codegen.pdelab.restriction import Restriction + if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet": + return get_global_context_value("facedir_s") + if restriction == Restriction.NEGATIVE: + return get_global_context_value("facedir_n") + return None + + def get_facemod(self, restriction): + from dune.codegen.pdelab.restriction import Restriction + if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet": + return get_global_context_value("facemod_s") + if restriction == Restriction.NEGATIVE: + return get_global_context_value("facemod_n") + return None + def additional_matrix_sequence(self): return None @@ -375,8 +388,8 @@ class SumfactPointDiagonalAccumulationMixin(SumfactAccumulationMixin): info = self.current_info[1] return construct_basis_matrix_sequence(transpose=True, derivative=info.grad_index, - facedir=get_facedir(info.restriction), - facemod=get_facemod(info.restriction), + facedir=self.get_facedir(info.restriction), + facemod=self.get_facemod(info.restriction), basis_size=get_basis_size(info), ) @@ -582,8 +595,8 @@ def generate_accumulation_instruction(expr, visitor): matrix_sequence = construct_basis_matrix_sequence( transpose=True, derivative=test_info.grad_index, - facedir=get_facedir(test_info.restriction), - facemod=get_facemod(test_info.restriction), + facedir=visitor.get_facedir(test_info.restriction), + facemod=visitor.get_facemod(test_info.restriction), basis_size=basis_size, additional_sequence=visitor.additional_matrix_sequence()) diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py index c49e052a90faac34ad6f64b539f93b3bb1d560cb..83e5206b5289719f3c4e85f4ca523976cb6feb79 100644 --- a/python/dune/codegen/sumfact/basis.py +++ b/python/dune/codegen/sumfact/basis.py @@ -31,9 +31,6 @@ from dune.codegen.sumfact.permutation import (permute_backward, sumfact_cost_permutation_strategy, sumfact_quadrature_permutation_strategy, ) -from dune.codegen.sumfact.switch import (get_facedir, - get_facemod, - ) from dune.codegen.pdelab.argument import name_coefficientcontainer, name_applycontainer from dune.codegen.pdelab.basis import GenericBasisMixin from dune.codegen.pdelab.geometry import (local_dimension, @@ -86,7 +83,7 @@ class SumfactBasisMixin(GenericBasisMixin): temporary_variable(name, shape=()) quad_inames = self.quadrature_inames() inames = self.lfs_inames(element, restriction) - facedir = get_facedir(restriction) + facedir = self.get_facedir(restriction) # Collect the pairs of lfs/quad inames that are in use # On facets, the normal direction of the facet is excluded @@ -106,7 +103,7 @@ class SumfactBasisMixin(GenericBasisMixin): # Add the missing direction on facedirs by evaluating at either 0 or 1 if facedir is not None: - facemod = get_facemod(restriction) + facemod = self.get_facemod(restriction) prod = prod + (prim.Call(PolynomialLookup(name_polynomials(element.degree()), False), (prim.Variable(inames[facedir]), facemod)),) @@ -141,7 +138,7 @@ class SumfactBasisMixin(GenericBasisMixin): temporary_variable(name, shape=()) quad_inames = self.quadrature_inames() inames = self.lfs_inames(element, restriction) - facedir = get_facedir(restriction) + facedir = self.get_facedir(restriction) # Map the direction to a quadrature iname quadinamemapping = {} @@ -161,7 +158,7 @@ class SumfactBasisMixin(GenericBasisMixin): prod.append(tab.pymbolic((prim.Variable(quadinamemapping[i]), prim.Variable(inames[i])))) if facedir is not None: - facemod = get_facemod(restriction) + facemod = self.get_facemod(restriction) prod.append(prim.Call(PolynomialLookup(name_polynomials(element.degree()), index == facedir), (prim.Variable(inames[facedir]), facemod)),) @@ -197,8 +194,8 @@ class SumfactBasisMixin(GenericBasisMixin): # Construct the matrix sequence for this sum factorization matrix_sequence = construct_basis_matrix_sequence(derivative=derivative, - facedir=get_facedir(restriction), - facemod=get_facemod(restriction), + facedir=self.get_facedir(restriction), + facemod=self.get_facemod(restriction), basis_size=basis_size) inp = LFSSumfactKernelInput(matrix_sequence=matrix_sequence, diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py index 573d6ff932f068321701ea764cd43c3e87ba48eb..811a201e9e11238bcdc4965ce2c7f75426fc3487 100644 --- a/python/dune/codegen/sumfact/geometry.py +++ b/python/dune/codegen/sumfact/geometry.py @@ -39,7 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward, sumfact_quadrature_permutation_strategy, ) from dune.codegen.sumfact.quadrature import additional_inames -from dune.codegen.sumfact.switch import get_facedir, get_facemod from dune.codegen.sumfact.symbolic import SumfactKernelInterfaceBase, SumfactKernel from dune.codegen.tools import get_pymbolic_basename, ImmutableCuttingRecord from dune.codegen.options import get_form_option, option_switch @@ -162,8 +161,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin): def outer_normal(self): """ This is the *unnormalized* outer normal """ name = "outer_normal" - facedir_s = get_facedir(Restriction.POSITIVE) - facemod_s = get_facemod(Restriction.POSITIVE) + facedir_s = self.get_facedir(Restriction.POSITIVE) + facemod_s = self.get_facemod(Restriction.POSITIVE) temporary_variable(name, shape=(world_dimension(),)) for i in range(world_dimension()): @@ -210,8 +209,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin): restriction = enforce_boundary_restriction(self) # Generate sum factorization kernel and add vectorization info - matrix_sequence = construct_basis_matrix_sequence(facedir=get_facedir(restriction), - facemod=get_facemod(restriction), + matrix_sequence = construct_basis_matrix_sequence(facedir=self.get_facedir(restriction), + facemod=self.get_facemod(restriction), basis_size=(2,) * world_dimension()) inp = GeoCornersInput(matrix_sequence=matrix_sequence, direction=self.indices[0], @@ -253,8 +252,8 @@ class SumfactAxiParallelGeometryMixin(AxiparallelGeometryMixin): assert isinstance(i, int) # Use facemod_s and facedir_s - if i == get_facedir(Restriction.POSITIVE): - if get_facemod(Restriction.POSITIVE): + if i == self.get_facedir(Restriction.POSITIVE): + if self.get_facemod(Restriction.POSITIVE): return 1 else: return -1 @@ -270,7 +269,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle def facet_jacobian_determinant(self, o): name = "fdetjac" self.define_facet_jacobian_determinant(name) - facedir = get_facedir(Restriction.POSITIVE) + facedir = self.get_facedir(Restriction.POSITIVE) globalarg(name, shape=(world_dimension(),)) return prim.Subscript(prim.Variable(name), (facedir,)) @@ -304,8 +303,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle restriction = Restriction.NONE if self.measure == "interior_facet": restriction = Restriction.POSITIVE - from dune.codegen.sumfact.switch import get_facedir - face = get_facedir(restriction) + face = self.get_facedir(restriction) lowcorner = name_lowerleft_corner() meshwidth = name_meshwidth() @@ -527,8 +525,8 @@ def _name_jacobian(i, j, restriction, visitor): """ # Create matrix sequence with derivative in j direction matrix_sequence = construct_basis_matrix_sequence(derivative=j, - facedir=get_facedir(restriction), - facemod=get_facemod(restriction), + facedir=visitor.get_facedir(restriction), + facemod=visitor.get_facemod(restriction), basis_size=(2,) * world_dimension()) # Sum factorization input for the i'th component of the geometry mapping diff --git a/python/dune/codegen/sumfact/permutation.py b/python/dune/codegen/sumfact/permutation.py index 7f37dfeae795031ec81b292b96e65649f8fc78cc..6e9fdaaad5dc9021d0ec076df7cc743806eab512 100644 --- a/python/dune/codegen/sumfact/permutation.py +++ b/python/dune/codegen/sumfact/permutation.py @@ -3,7 +3,6 @@ import itertools from dune.codegen.options import get_option -from dune.codegen.sumfact.switch import get_facedir, get_facemod from dune.codegen.ufl.modified_terminals import Restriction @@ -125,8 +124,9 @@ def sumfact_quadrature_permutation_strategy(dim, restriction): # all others can be derived by rotating the cube and matching edge # directions. def _order_on_self(restriction): - facedir = get_facedir(restriction) - facemod = get_facemod(restriction) + from dune.codegen.sumfact.accumulation import SumfactAccumulationMixin + facedir = SumfactAccumulationMixin.get_facedir(None, restriction) + facemod = SumfactAccumulationMixin.get_facemod(None, restriction) quadrature_order = { (0, 0): (0, 1, 2), diff --git a/python/dune/codegen/sumfact/quadrature.py b/python/dune/codegen/sumfact/quadrature.py index 91e99c4c3c0cd333d6e64353ed06ac0c129a7e05..ae387426ebda79a8435811c548e91761f2de1b3c 100644 --- a/python/dune/codegen/sumfact/quadrature.py +++ b/python/dune/codegen/sumfact/quadrature.py @@ -10,7 +10,6 @@ from dune.codegen.generation import (domain, quadrature_mixin, temporary_variable, ) -from dune.codegen.sumfact.switch import get_facedir from dune.codegen.sumfact.tabulation import (quadrature_points_per_direction, local_quadrature_points_per_direction, name_oned_quadrature_points, @@ -22,7 +21,6 @@ from dune.codegen.pdelab.geometry import (local_dimension, ) from dune.codegen.pdelab.quadrature import GenericQuadratureMixin from dune.codegen.options import get_form_option -from dune.codegen.sumfact.switch import get_facedir from dune.codegen.loopy.target import dtype_floatingpoint from loopy import CallMangleInfo diff --git a/python/dune/codegen/sumfact/switch.py b/python/dune/codegen/sumfact/switch.py index a031790dd5321841eb8c100195cde675651fe1d3..65d35c17e4a8d2753e680c8c440fee6e9f4c6adf 100644 --- a/python/dune/codegen/sumfact/switch.py +++ b/python/dune/codegen/sumfact/switch.py @@ -169,25 +169,3 @@ def generate_interior_facet_switch(): block.append("}") return ClassMember(signature + block) - - -def get_facedir(restriction): - from dune.codegen.pdelab.restriction import Restriction - if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet": - return get_global_context_value("facedir_s") - if restriction == Restriction.NEGATIVE: - return get_global_context_value("facedir_n") - if restriction == Restriction.NONE: - return None - assert False - - -def get_facemod(restriction): - from dune.codegen.pdelab.restriction import Restriction - if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet": - return get_global_context_value("facemod_s") - if restriction == Restriction.NEGATIVE: - return get_global_context_value("facemod_n") - if restriction == Restriction.NONE: - return None - assert False