Skip to content
Snippets Groups Projects
Commit 85e2dd2d authored by Dominic Kempf's avatar Dominic Kempf
Browse files

[bugfix] Correct quadrature inames at sumfact kernel output

When doing jacobians of nonlinear systems, the inames used for
a given sum factorization kernel output depend on the accumulation
context.
parent 560099ae
No related branches found
No related tags found
No related merge requests found
......@@ -46,23 +46,19 @@ class SumFactInterface(PDELabInterface):
return ret
def pymbolic_trialfunction_gradient(self, element, restriction, index):
ret, indices = pymbolic_coefficient_gradient(element, restriction, index, name_coefficientcontainer, self.visitor.indices)
self.visitor.indices = indices
ret = pymbolic_coefficient_gradient(element, restriction, index, name_coefficientcontainer, self.visitor)
return ret
def pymbolic_trialfunction(self, element, restriction, index):
ret, indices = pymbolic_coefficient(element, restriction, index, name_coefficientcontainer, self.visitor.indices)
self.visitor.indices = indices
ret = pymbolic_coefficient(element, restriction, index, name_coefficientcontainer, self.visitor)
return ret
def pymbolic_apply_function_gradient(self, element, restriction, index):
ret, indices = pymbolic_coefficient_gradient(element, restriction, index, name_applycontainer, self.visitor.indices)
self.visitor.indices = indices
ret = pymbolic_coefficient_gradient(element, restriction, index, name_applycontainer, self.visitor)
return ret
def pymbolic_apply_function(self, element, restriction, index):
ret, indices = pymbolic_coefficient(element, restriction, index, name_applycontainer, self.visitor.indices)
self.visitor.indices = indices
ret = pymbolic_coefficient(element, restriction, index, name_applycontainer, self.visitor)
return ret
def quadrature_inames(self):
......@@ -73,8 +69,7 @@ class SumFactInterface(PDELabInterface):
def pymbolic_spatial_coordinate(self):
import dune.perftool.sumfact.geometry
ret, indices = get_backend(interface="spatial_coordinate", selector=option_switch("diagonal_transformation_matrix"))(self.visitor.indices, self.visitor.do_predicates, self.visitor)
self.visitor.indices = indices
ret = get_backend(interface="spatial_coordinate", selector=option_switch("diagonal_transformation_matrix"))(self.visitor.do_predicates, self.visitor)
return ret
def pymbolic_unit_outer_normal(self):
......
......@@ -320,7 +320,7 @@ def generate_accumulation_instruction(expr, visitor):
# Issue an instruction in the quadrature loop that fills the buffer
# with the evaluation of the contribution at all quadrature points
assignee = prim.Subscript(lp.TaggedVariable(temp, vsf.tag),
vsf.quadrature_index(sf))
vsf.quadrature_index(sf, visitor))
contrib_dep = instruction(assignee=assignee,
expression=expr,
forced_iname_deps=frozenset(quadrature_inames(trial_leaf_element) + jacobian_inames),
......
......@@ -111,10 +111,9 @@ def _basis_functions_per_direction(element):
return basis_size
@kernel_cached
def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visitor_indices):
def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visitor):
sub_element = element
grad_index = visitor_indices[0]
grad_index = visitor.indices[0]
if isinstance(element, MixedElement):
sub_element = element.extract_component(index)[1]
......@@ -154,11 +153,11 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
var, insn_dep = realize_sum_factorization_kernel(vsf)
return prim.Subscript(var, vsf.quadrature_index(sf)), None
visitor.indices = None
return prim.Subscript(var, vsf.quadrature_index(sf, visitor))
@kernel_cached
def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indices):
def pymbolic_coefficient(element, restriction, index, coeff_func, visitor):
sub_element = element
if isinstance(element, MixedElement):
sub_element = element.extract_component(index)[1]
......@@ -197,7 +196,8 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indice
from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
var, _ = realize_sum_factorization_kernel(vsf)
return prim.Subscript(var, vsf.quadrature_index(sf)), None
visitor.indices = None
return prim.Subscript(var, vsf.quadrature_index(sf, visitor))
@iname
......
......@@ -67,10 +67,9 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
)
@kernel_cached
@backend(interface="spatial_coordinate", name="default")
def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visitor):
assert len(visitor_indices) == 1
def pymbolic_spatial_coordinate_multilinear(do_predicates, visitor):
assert len(visitor.indices) == 1
# Construct the matrix sequence for the evaluation of the global coordinate.
# We need to manually construct this one, because on facets, we want to use the
......@@ -80,7 +79,7 @@ def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visi
from dune.perftool.sumfact.tabulation import quadrature_points_per_direction, BasisTabulationMatrix
quadrature_size = quadrature_points_per_direction()
matrix_sequence = (BasisTabulationMatrix(quadrature_size=quadrature_size, basis_size=2),) * local_dimension()
inp = GeoCornersInput(visitor_indices[0])
inp = GeoCornersInput(visitor.indices[0])
from dune.perftool.sumfact.symbolic import SumfactKernel
sf = SumfactKernel(matrix_sequence=matrix_sequence,
......@@ -94,7 +93,8 @@ def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visi
from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
var, _ = realize_sum_factorization_kernel(vsf)
return prim.Subscript(var, vsf.quadrature_index(sf)), None
visitor.indices = None
return prim.Subscript(var, vsf.quadrature_index(sf, visitor)), None
@preamble
......@@ -126,11 +126,10 @@ def name_meshwidth():
return name
@kernel_cached
@backend(interface="spatial_coordinate", name="diagonal_transformation_matrix")
def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates, visitor):
assert len(visitor_indices) == 1
index, = visitor_indices
def pymbolic_spatial_coordinate_axiparallel(do_predicates, visitor):
assert len(visitor.indices) == 1
index, = visitor.indices
# Urgh: *SOMEHOW* construct a face direction
from dune.perftool.pdelab.restriction import Restriction
......@@ -159,6 +158,7 @@ def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates, visi
from dune.perftool.sumfact.quadrature import pymbolic_quadrature_position
x = pymbolic_quadrature_position(iindex, visitor)
visitor.indices = None
return prim.Subscript(prim.Variable(lowcorner), (index,)) + x * prim.Subscript(prim.Variable(meshwidth), (index,)), None
......
......@@ -216,10 +216,16 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
"""
return tuple(mat.quadrature_size for mat in self.matrix_sequence)
def quadrature_index(self, _):
element = self.trial_element
if element is not None and isinstance(element, MixedElement):
element = element.extract_component(self.trial_element_index)[1]
def quadrature_index(self, sf, visitor):
if visitor.current_info[1] is None:
element = None
element_index = 0
else:
element = visitor.current_info[1].element
element_index = visitor.current_info[1].element_index
if isinstance(element, MixedElement):
element = element.extract_component(element_index)[1]
quad_inames = quadrature_inames(element)
if len(self.matrix_sequence) == local_dimension():
return tuple(prim.Variable(i) for i in quad_inames)
......@@ -474,10 +480,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
return i
return 0
def _quadrature_index(self, sf):
element = self.trial_element
if element is not None and isinstance(element, MixedElement):
element = element.extract_component(self.trial_element_index)[1]
def _quadrature_index(self, sf, visitor):
if visitor.current_info[1] is None:
element = None
element_index = 0
else:
element = visitor.current_info[1].element
element_index = visitor.current_info[1].element_index
if isinstance(element, MixedElement):
element = element.extract_component(element_index)[1]
quad_inames = quadrature_inames(element)
index = []
......@@ -506,10 +518,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
return tuple(index)
def vec_index(self, sf):
element = self.trial_element
if element is not None and isinstance(element, MixedElement):
element = element.extract_component(self.trial_element_index)[1]
def vec_index(self, sf, visitor):
if visitor.current_info[1] is None:
element = None
element_index = 0
else:
element = visitor.current_info[1].element
element_index = visitor.current_info[1].element_index
if isinstance(element, MixedElement):
element = element.extract_component(element_index)[1]
quad_inames = quadrature_inames(element)
sliced = 0
if len(sf.matrix_sequence) == local_dimension():
......@@ -530,13 +548,13 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def quadrature_shape(self):
return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,)
def quadrature_index(self, sf, direct_index=None):
quad = self._quadrature_index(sf)
def quadrature_index(self, sf, visitor, direct_index=None):
quad = self._quadrature_index(sf, visitor)
if direct_index is not None:
assert isinstance(direct_index, tuple)
return quad + direct_index
else:
return quad + (self.vec_index(sf),)
return quad + (self.vec_index(sf, visitor),)
@property
def quadrature_dimtags(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment