Pass the visitor object (instead of visitor.indices) into the sum
factorization backends and let them reset visitor.indices themselves.
Dropping @kernel_cached goes along with that, since the visitor is now
an argument of these functions.

NOTE(review): this patch text was re-flowed from a whitespace-collapsed
copy. Hunk line counts were re-checked, but indentation and the exact
position of blank context lines are reconstructed -- verify against the
target tree (or apply with --ignore-whitespace) before relying on it.

NOTE(review): both spatial_coordinate backends in geometry.py now return
the bare expression. The caller in __init__.py no longer unpacks a
(expr, indices) tuple, so the former trailing ", None" would have leaked
a tuple into the visitor.

diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py
index df246dee44aa60c12827f3ec272b70b6beea7e65..a4fce88372bcb5546f060cf9785b726d25966e6a 100644
--- a/python/dune/perftool/sumfact/__init__.py
+++ b/python/dune/perftool/sumfact/__init__.py
@@ -46,23 +46,19 @@ class SumFactInterface(PDELabInterface):
         return ret
 
     def pymbolic_trialfunction_gradient(self, element, restriction, index):
-        ret, indices = pymbolic_coefficient_gradient(element, restriction, index, name_coefficientcontainer, self.visitor.indices)
-        self.visitor.indices = indices
+        ret = pymbolic_coefficient_gradient(element, restriction, index, name_coefficientcontainer, self.visitor)
         return ret
 
     def pymbolic_trialfunction(self, element, restriction, index):
-        ret, indices = pymbolic_coefficient(element, restriction, index, name_coefficientcontainer, self.visitor.indices)
-        self.visitor.indices = indices
+        ret = pymbolic_coefficient(element, restriction, index, name_coefficientcontainer, self.visitor)
         return ret
 
     def pymbolic_apply_function_gradient(self, element, restriction, index):
-        ret, indices = pymbolic_coefficient_gradient(element, restriction, index, name_applycontainer, self.visitor.indices)
-        self.visitor.indices = indices
+        ret = pymbolic_coefficient_gradient(element, restriction, index, name_applycontainer, self.visitor)
         return ret
 
     def pymbolic_apply_function(self, element, restriction, index):
-        ret, indices = pymbolic_coefficient(element, restriction, index, name_applycontainer, self.visitor.indices)
-        self.visitor.indices = indices
+        ret = pymbolic_coefficient(element, restriction, index, name_applycontainer, self.visitor)
         return ret
 
     def quadrature_inames(self):
@@ -73,8 +69,7 @@ class SumFactInterface(PDELabInterface):
 
     def pymbolic_spatial_coordinate(self):
         import dune.perftool.sumfact.geometry
-        ret, indices = get_backend(interface="spatial_coordinate", selector=option_switch("diagonal_transformation_matrix"))(self.visitor.indices, self.visitor.do_predicates, self.visitor)
-        self.visitor.indices = indices
+        ret = get_backend(interface="spatial_coordinate", selector=option_switch("diagonal_transformation_matrix"))(self.visitor.do_predicates, self.visitor)
         return ret
 
     def pymbolic_unit_outer_normal(self):
diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index dac38218879442f9cc9c91d9a3c41cd2f3cbec5c..f8ef3940e0ffb348c59ea9ec1904b1537c91dd80 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -320,7 +320,7 @@ def generate_accumulation_instruction(expr, visitor):
     # Issue an instruction in the quadrature loop that fills the buffer
     # with the evaluation of the contribution at all quadrature points
     assignee = prim.Subscript(lp.TaggedVariable(temp, vsf.tag),
-                              vsf.quadrature_index(sf))
+                              vsf.quadrature_index(sf, visitor))
     contrib_dep = instruction(assignee=assignee,
                               expression=expr,
                               forced_iname_deps=frozenset(quadrature_inames(trial_leaf_element) + jacobian_inames),
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index d01c8ef8edead963d0bd56fd4e34c956484957d3..c1ca1a842601a790bc10bcefda5ec8d0c44ce409 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -111,10 +111,9 @@ def _basis_functions_per_direction(element):
     return basis_size
 
 
-@kernel_cached
-def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visitor_indices):
+def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visitor):
     sub_element = element
-    grad_index = visitor_indices[0]
+    grad_index = visitor.indices[0]
     if isinstance(element, MixedElement):
         sub_element = element.extract_component(index)[1]
 
@@ -154,11 +153,11 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     var, insn_dep = realize_sum_factorization_kernel(vsf)
 
-    return prim.Subscript(var, vsf.quadrature_index(sf)), None
+    visitor.indices = None
+    return prim.Subscript(var, vsf.quadrature_index(sf, visitor))
 
 
-@kernel_cached
-def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indices):
+def pymbolic_coefficient(element, restriction, index, coeff_func, visitor):
     sub_element = element
     if isinstance(element, MixedElement):
         sub_element = element.extract_component(index)[1]
@@ -197,7 +196,8 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indice
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     var, _ = realize_sum_factorization_kernel(vsf)
 
-    return prim.Subscript(var, vsf.quadrature_index(sf)), None
+    visitor.indices = None
+    return prim.Subscript(var, vsf.quadrature_index(sf, visitor))
 
 
 @iname
diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py
index e7d6c1271b5251626961df0bf3d320446a92908a..9f0fbcbd31a0e8aada1374296f8f8d4f714409a3 100644
--- a/python/dune/perftool/sumfact/geometry.py
+++ b/python/dune/perftool/sumfact/geometry.py
@@ -67,10 +67,9 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
         )
 
 
-@kernel_cached
 @backend(interface="spatial_coordinate", name="default")
-def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visitor):
-    assert len(visitor_indices) == 1
+def pymbolic_spatial_coordinate_multilinear(do_predicates, visitor):
+    assert len(visitor.indices) == 1
 
     # Construct the matrix sequence for the evaluation of the global coordinate.
     # We need to manually construct this one, because on facets, we want to use the
@@ -80,7 +79,7 @@ def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visi
     from dune.perftool.sumfact.tabulation import quadrature_points_per_direction, BasisTabulationMatrix
     quadrature_size = quadrature_points_per_direction()
     matrix_sequence = (BasisTabulationMatrix(quadrature_size=quadrature_size, basis_size=2),) * local_dimension()
-    inp = GeoCornersInput(visitor_indices[0])
+    inp = GeoCornersInput(visitor.indices[0])
 
     from dune.perftool.sumfact.symbolic import SumfactKernel
     sf = SumfactKernel(matrix_sequence=matrix_sequence,
@@ -94,7 +93,8 @@ def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visi
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     var, _ = realize_sum_factorization_kernel(vsf)
 
-    return prim.Subscript(var, vsf.quadrature_index(sf)), None
+    visitor.indices = None
+    return prim.Subscript(var, vsf.quadrature_index(sf, visitor))
 
 
 @preamble
@@ -126,11 +126,10 @@ def name_meshwidth():
     return name
 
 
-@kernel_cached
 @backend(interface="spatial_coordinate", name="diagonal_transformation_matrix")
-def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates, visitor):
-    assert len(visitor_indices) == 1
-    index, = visitor_indices
+def pymbolic_spatial_coordinate_axiparallel(do_predicates, visitor):
+    assert len(visitor.indices) == 1
+    index, = visitor.indices
 
     # Urgh: *SOMEHOW* construct a face direction
     from dune.perftool.pdelab.restriction import Restriction
@@ -159,6 +158,7 @@ def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates, visi
     from dune.perftool.sumfact.quadrature import pymbolic_quadrature_position
     x = pymbolic_quadrature_position(iindex, visitor)
 
-    return prim.Subscript(prim.Variable(lowcorner), (index,)) + x * prim.Subscript(prim.Variable(meshwidth), (index,)), None
+    visitor.indices = None
+    return prim.Subscript(prim.Variable(lowcorner), (index,)) + x * prim.Subscript(prim.Variable(meshwidth), (index,))
 
 
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index a5c4eb0630227769821be5382f156b2a03b2760d..884464175b62ccbe0955fb67992ccd08d6723e1f 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -216,10 +216,16 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         """
         return tuple(mat.quadrature_size for mat in self.matrix_sequence)
 
-    def quadrature_index(self, _):
-        element = self.trial_element
-        if element is not None and isinstance(element, MixedElement):
-            element = element.extract_component(self.trial_element_index)[1]
+    def quadrature_index(self, sf, visitor):
+        if visitor.current_info[1] is None:
+            element = None
+            element_index = 0
+        else:
+            element = visitor.current_info[1].element
+            element_index = visitor.current_info[1].element_index
+        if isinstance(element, MixedElement):
+            element = element.extract_component(element_index)[1]
+
         quad_inames = quadrature_inames(element)
         if len(self.matrix_sequence) == local_dimension():
             return tuple(prim.Variable(i) for i in quad_inames)
@@ -474,10 +480,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
                 return i
         return 0
 
-    def _quadrature_index(self, sf):
-        element = self.trial_element
-        if element is not None and isinstance(element, MixedElement):
-            element = element.extract_component(self.trial_element_index)[1]
+    def _quadrature_index(self, sf, visitor):
+        if visitor.current_info[1] is None:
+            element = None
+            element_index = 0
+        else:
+            element = visitor.current_info[1].element
+            element_index = visitor.current_info[1].element_index
+        if isinstance(element, MixedElement):
+            element = element.extract_component(element_index)[1]
+
         quad_inames = quadrature_inames(element)
 
         index = []
@@ -506,10 +518,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
         return tuple(index)
 
-    def vec_index(self, sf):
-        element = self.trial_element
-        if element is not None and isinstance(element, MixedElement):
-            element = element.extract_component(self.trial_element_index)[1]
+    def vec_index(self, sf, visitor):
+        if visitor.current_info[1] is None:
+            element = None
+            element_index = 0
+        else:
+            element = visitor.current_info[1].element
+            element_index = visitor.current_info[1].element_index
+        if isinstance(element, MixedElement):
+            element = element.extract_component(element_index)[1]
+
         quad_inames = quadrature_inames(element)
         sliced = 0
         if len(sf.matrix_sequence) == local_dimension():
@@ -530,13 +548,13 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     def quadrature_shape(self):
         return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,)
 
-    def quadrature_index(self, sf, direct_index=None):
-        quad = self._quadrature_index(sf)
+    def quadrature_index(self, sf, visitor, direct_index=None):
+        quad = self._quadrature_index(sf, visitor)
         if direct_index is not None:
             assert isinstance(direct_index, tuple)
             return quad + direct_index
         else:
-            return quad + (self.vec_index(sf),)
+            return quad + (self.vec_index(sf, visitor),)
 
     @property
     def quadrature_dimtags(self):