# Last change (Dominic Kempf): When doing jacobians of nonlinear systems, the inames
# used for a given sum factorization kernel output depend on the accumulation context.
"""Pymbolic nodes representing (vectorized) sum factorization kernels."""
from dune.perftool.options import get_option
from dune.perftool.generation import get_counted_variable
from dune.perftool.pdelab.geometry import local_dimension, world_dimension
from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
from pytools import ImmutableRecord, product
from ufl import MixedElement
import pymbolic.primitives as prim
import loopy as lp
import frozendict
import inspect
class SumfactKernelInputBase(object):
    """Base class for objects describing the input of a sum factorization kernel.

    Stage 1 kernels require an instance of (a subclass of) this class as their
    ``input`` argument. The defaults here describe an input that is not read
    directly and whose realization issues no code.
    """

    @property
    def direct_input(self):
        """The directly accessed input (FastDG case), or None by default."""
        return None

    def realize(self, sf, i, dep):
        """Hook for issuing the code that provides the input; a no-op here."""
        return None
class SumfactKernelBase(object):
    """Common marker base of SumfactKernel and VectorizedSumfactKernel."""
class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
    def __init__(self,
                 matrix_sequence=None,
                 buffer=None,
                 stage=1,
                 position_priority=None,
                 restriction=None,
                 insn_dep=frozenset(),
                 input=None,
                 accumvar=None,
                 test_element=None,
                 test_element_index=None,
                 trial_element=None,
                 trial_element_index=None,
                 predicates=frozenset(),
                 ):
        """Create a sum factorization kernel

        Sum factorization can be written as

            Y = R_{d-1} (A_{d-1} * ... * R_0 (A_0 * X)...)

        with:
        - X: Input rank d tensor of dimension n_0 x ... x n_{d-1}
        - Y: Output rank d tensor of dimension m_0 x ... x m_{d-1}
        - A_l: Values of 1D basis evaluations at quadrature points in l
          direction, matrix of dimension m_l x n_l
        - R_l: Transformation operator that permutes the underlying data
          vector of the rank d tensor in such a way that the fastest
          direction becomes the slowest direction

        In the l'th step we have the following setup:
        - A_l: Matrix of dimensions m_l x n_l
        - X_l: Rank d tensor of dimensions n_l x ... x n_{d-1} x m_0 x ... x m_{l-1}
        - R_l: Transformation operator

        Looking at the indices the following will happen:
            X --> [n_l,...,n_{d-1},m_0,...,m_{l-1}]
            A_l * X --> [m_l,n_l] * [n_l, ...] = [m_l,n_{l+1},...,n_{d-1},m_0,...,m_{l-1}]
            R_l (A_l*X) --> [n_{l+1},...,n_{d-1},m_0,...,m_l]

        So the multiplication with A_l is a reduction over one index and
        the transformation brings the next reduction index into the fastest
        position.

        It can make sense to permute the order of directions. If you have
        a small m_l (e.g. stage 1 on faces) it is better to do direction l
        first. This can be done by:

        - Permuting the order of the A matrices.
        - Permuting the input tensor.
        - Permuting the output tensor (this assures that the directions of
          the output tensor are again ordered from 0 to d-1).

        Note that you will typically *not* set all of the below arguments,
        but only some. The vectorization strategy may set others for you.
        The only argument really needed in all cases is matrix_sequence
        (stage 1 additionally requires input, stage 3 a tuple restriction).

        Arguments:
        ----------
        matrix_sequence: A tuple of BasisTabulationMatrixBase instances
            The list of tensors to be applied to the input.
            Order of application is from 0 up.
        buffer: A string identifying the flip flop buffer in use
            for intermediate results. The memory is expected to be
            pre-initialized with the input or you have to provide
            direct_input (FastDGGridOperator).
        stage: 1 or 3
        position_priority: Will be used in the dry run to order kernels
            when doing vectorization e.g. (dx u,dy u,dz u, u).
        restriction: Restriction for face values. Must be a tuple for stage 3.
        insn_dep: A frozenset of instruction dependencies that the first
            issued instruction should depend upon. All following ones will
            depend on each other.
        input: A SumfactKernelInputBase instance describing the input of the
            kernel. Required for stage 1.
        accumvar: The accumulation variable to accumulate into
        trial_element: The leaf element of the trial function space.
            Used to correctly nest stage 3 in the jacobian case.
        test_element: The leaf element of the test function space
            Used to compute offsets in the fastdg case.
        test_element_index: the component of the test_element
        trial_element_index: the component of the trial_element
        predicates: A frozenset of predicates attached to this kernel.
            NOTE(review): presumably loopy instruction predicates - confirm.
        """
        # Assert the inputs!
        assert isinstance(matrix_sequence, tuple)
        assert all(isinstance(m, BasisTabulationMatrixBase) for m in matrix_sequence)
        assert stage in (1, 3)
        if stage == 1:
            assert isinstance(input, SumfactKernelInputBase)
        if stage == 3:
            assert isinstance(restriction, tuple)
        assert isinstance(insn_dep, frozenset)

        # Snapshot the constructor arguments. No local other than self and the
        # named arguments has been bound yet, and only the names listed in
        # init_arg_names are read from the snapshot.
        arguments = locals()
        fields = {name: arguments[name] for name in SumfactKernel.init_arg_names}

        # Call the base class constructors
        ImmutableRecord.__init__(self, **fields)
        prim.Variable.__init__(self, "SUMFACT")

    #
    # The methods/fields needed to get a well-formed pymbolic node
    #
    def __getinitargs__(self):
        return tuple(getattr(self, arg) for arg in SumfactKernel.init_arg_names)

    def stringifier(self):
        return lp.symbolic.StringifyMapper

    mapper_method = "map_sumfact_kernel"

    #
    # Some cache key definitions
    # Check the documentation to see which key is used under what circumstances
    #
    @property
    def cache_key(self):
        """ The cache key that can be used in generation magic

        Any two sum factorization kernels having the same cache_key
        are realized simultaneously!
        """
        return (self.matrix_sequence, self.restriction, self.stage, self.buffer, self.test_element_index)

    @property
    def input_key(self):
        """ A cache key for the input coefficients

        Any two sum factorization kernels having the same input_key
        work on the same input coefficient (and are suitable for simultaneous
        treatment because of that)
        """
        return (self.input, self.restriction, self.accumvar, self.trial_element_index)

    @property
    def group_name(self):
        return "sfgroup_{}_{}_{}_{}".format(self.input, self.restriction, self.accumvar, self.trial_element_index)

    #
    # Some convenience methods to extract information about the sum factorization kernel
    #
    @property
    def length(self):
        """ The number of matrices to apply """
        return len(self.matrix_sequence)

    @property
    def vectorized(self):
        return False

    @property
    def transposed(self):
        return self.matrix_sequence[0].transpose

    @property
    def within_inames(self):
        """The lfs inames of the trial element, or () without a trial element."""
        if self.trial_element is None:
            return ()
        # Function-level import kept from the original; presumably avoids a
        # circular import with dune.perftool.sumfact.basis - TODO confirm.
        from dune.perftool.sumfact.basis import lfs_inames
        element = self.trial_element
        if isinstance(element, MixedElement):
            element = element.extract_component(self.trial_element_index)[1]
        return lfs_inames(element, self.restriction)

    def vec_index(self, sf, visitor=None):
        """ Map an unvectorized sumfact kernel object to its position
        in the vectorized kernel

        A plain kernel is never vectorized, so this is always 0. The optional
        visitor argument mirrors VectorizedSumfactKernel.vec_index and is unused.
        """
        return 0

    @property
    def quadrature_shape(self):
        """ The shape of a temporary for the quadrature points

        Takes into account the lower dimensionality of faces and vectorization.
        """
        return tuple(mat.quadrature_size for mat in self.matrix_sequence)

    @staticmethod
    def _current_element(visitor):
        """Return the leaf element the visitor is currently working on.

        Returns None when no visitor is given or the visitor carries no element
        information (``visitor.current_info[1] is None``). Components of a
        MixedElement are resolved via ``element_index``.
        """
        if visitor is None or visitor.current_info[1] is None:
            return None
        info = visitor.current_info[1]
        element = info.element
        if isinstance(element, MixedElement):
            element = element.extract_component(info.element_index)[1]
        return element

    def quadrature_index(self, sf, visitor=None):
        """Return the index tuple into a quadrature point temporary.

        The quadrature inames are those of the element the visitor currently
        works on. Without a visitor (e.g. when computing padded indices) the
        element defaults to None. Directions whose matrix has a face set get
        the constant index 0.
        """
        quad_inames = quadrature_inames(self._current_element(visitor))
        if len(self.matrix_sequence) == local_dimension():
            return tuple(prim.Variable(i) for i in quad_inames)

        # Traverse all the quadrature inames and map them to their correct direction
        index = []
        i = 0
        for d in range(world_dimension()):
            if self.matrix_sequence[d].face is None:
                index.append(prim.Variable(quad_inames[i]))
                i = i + 1
            else:
                index.append(0)
        return tuple(index)

    @property
    def quadrature_dimtags(self):
        """ The dim_tags of a temporary for the quadrature points

        Takes into account the lower dimensionality of faces and vectorization.
        """
        tags = ["f"] * len(self.quadrature_shape)
        return ",".join(tags)

    @property
    def dof_shape(self):
        """ The shape of a temporary for the degrees of freedom

        Takes into account vectorization.
        """
        return tuple(mat.basis_size for mat in self.matrix_sequence)

    @property
    def dof_dimtags(self):
        """ The dim_tags of a temporary for the degrees of freedom

        Takes into account vectorization.
        """
        tags = ["f"] * len(self.dof_shape)
        return ",".join(tags)

    @property
    def output_shape(self):
        if self.stage == 1:
            return self.quadrature_shape
        else:
            return self.dof_shape

    @property
    def output_dimtags(self):
        if self.stage == 1:
            return self.quadrature_dimtags
        else:
            return self.dof_dimtags

    @property
    def tag(self):
        return "sumfac"

    #
    # Define properties for conformity with the interface of VectorizedSumfactKernel
    #
    @property
    def padded_indices(self):
        return set()

    @property
    def horizontal_width(self):
        return 1

    def horizontal_index(self, _):
        return 0

    @property
    def vertical_width(self):
        return 1

    @property
    def vector_width(self):
        return 1
# Extract the argument list and store it on the class. This needs to be done
# outside of the class because the SumfactKernel class object needs to be fully
# initialized in order to extract the information from __init__.
# inspect.getargspec was removed in Python 3.11, so prefer getfullargspec and
# only fall back to getargspec on interpreters that lack it. The ``or`` keeps
# inspect.getargspec from being looked up when getfullargspec exists.
_getargspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec
# Drop the leading ``self`` argument.
SumfactKernel.init_arg_names = tuple(_getargspec(SumfactKernel.__init__).args[1:])
class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
    def __init__(self,
                 kernels=None,
                 horizontal_width=1,
                 vertical_width=1,
                 buffer=None,
                 insn_dep=frozenset(),
                 ):
        """Combine several SumfactKernel instances into one vectorized kernel.

        Arguments:
        ----------
        kernels: A tuple of SumfactKernel instances. They must agree in stage,
            length, restriction, within_inames, predicates and in rows, cols,
            face and transpose of each matrix of their matrix sequences.
        horizontal_width: Horizontal vectorization width.
        vertical_width: Vertical vectorization width. The total vector width
            is horizontal_width * vertical_width.
        buffer: A string identifying the flip flop buffer in use.
        insn_dep: A frozenset of instruction dependencies. The dependencies of
            all subkernels are merged into it.
        """
        # Assert the input data structure
        assert isinstance(kernels, tuple)
        assert all(isinstance(k, SumfactKernel) for k in kernels)

        # Assert all the properties that need to be the same across all subkernels
        assert len(set(k.stage for k in kernels)) == 1
        assert len(set(k.length for k in kernels)) == 1
        assert len(set(k.restriction for k in kernels)) == 1
        assert len(set(k.within_inames for k in kernels)) == 1
        assert len(set(k.predicates for k in kernels)) == 1

        # Assert properties of the matrix sequence of the underlying kernels
        for i in range(kernels[0].length):
            assert len(set(k.matrix_sequence[i].rows for k in kernels)) == 1
            assert len(set(k.matrix_sequence[i].cols for k in kernels)) == 1
            assert len(set(k.matrix_sequence[i].face for k in kernels)) == 1
            assert len(set(k.matrix_sequence[i].transpose for k in kernels)) == 1

        # Join the instruction dependencies of all subkernels. The generator has
        # to be unpacked: passing it as a single argument to union() would add
        # the subkernels' frozensets as elements instead of merging them.
        insn_dep = insn_dep.union(*(k.insn_dep for k in kernels))

        # We currently assume that all subkernels are consecutive, 0-based within the vector
        assert None not in kernels

        ImmutableRecord.__init__(self,
                                 kernels=kernels,
                                 horizontal_width=horizontal_width,
                                 buffer=buffer,
                                 insn_dep=insn_dep,
                                 vertical_width=vertical_width,
                                 )
        prim.Variable.__init__(self, "VecSUMFAC")

    def __getinitargs__(self):
        return (self.kernels, self.horizontal_width, self.vertical_width, self.buffer, self.insn_dep)

    def stringifier(self):
        return lp.symbolic.StringifyMapper

    mapper_method = "map_vectorized_sumfact_kernel"

    init_arg_names = ("kernels", "horizontal_width", "vertical_width", "buffer", "insn_dep")

    #
    # Deduce all data fields of normal sum factorization kernels from the underlying kernels
    #
    @property
    def matrix_sequence(self):
        """A tuple of BasisTabulationMatrixArray, one per direction.

        Each array combines the subkernels' matrices of that direction. The
        tuple is rebuilt on every access, so hoist it out of loops.
        """
        return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence[i] for k in self.kernels),
                                                width=self.vector_width,
                                                )
                     for i in range(self.length))

    @property
    def stage(self):
        return self.kernels[0].stage

    @property
    def restriction(self):
        return self.kernels[0].restriction

    @property
    def within_inames(self):
        return self.kernels[0].within_inames

    @property
    def test_element(self):
        return self.kernels[0].test_element

    @property
    def test_element_index(self):
        return self.kernels[0].test_element_index

    @property
    def trial_element(self):
        return self.kernels[0].trial_element

    @property
    def trial_element_index(self):
        return self.kernels[0].trial_element_index

    @property
    def predicates(self):
        return self.kernels[0].predicates

    @property
    def input(self):
        assert len(set(k.input for k in self.kernels)) == 1
        return self.kernels[0].input

    @property
    def accumvar(self):
        assert len(set(k.accumvar for k in self.kernels)) == 1
        return self.kernels[0].accumvar

    @property
    def transposed(self):
        return self.kernels[0].transposed

    #
    # Define some properties only needed for this one
    #
    @property
    def padded_indices(self):
        """Index tuples of the vector lanes not occupied by any subkernel."""
        indices = set(range(self.vector_width)) - set(range(len(self.kernels)))
        # NOTE(review): this calls the subkernel's quadrature_index without a
        # visitor, so that method must accept being called that way.
        return tuple(self.kernels[0].quadrature_index(None) + (i,) for i in indices)

    @property
    def vector_width(self):
        return self.horizontal_width * self.vertical_width

    #
    # Define the same properties the normal SumfactKernel defines
    #
    @property
    def cache_key(self):
        """ The cache key that can be used in generation magic

        Any two sum factorization kernels having the same cache_key
        are realized simultaneously!
        """
        return (tuple(k.cache_key for k in self.kernels), self.buffer)

    @property
    def input_key(self):
        return tuple(k.input_key for k in self.kernels)

    @property
    def group_name(self):
        return "_".join(k.group_name for k in self.kernels)

    @property
    def length(self):
        return self.kernels[0].length

    @property
    def vectorized(self):
        return True

    def horizontal_index(self, sf):
        """Position of the first subkernel whose derivative pattern matches sf's.

        Falls back to 0 if no subkernel matches.
        """
        key = tuple(mat.derivative for mat in sf.matrix_sequence)
        for i, k in enumerate(self.kernels):
            if tuple(mat.derivative for mat in k.matrix_sequence) == key:
                return i
        return 0

    @staticmethod
    def _current_element(visitor):
        """Return the leaf element the visitor is currently working on.

        Returns None when no visitor is given or the visitor carries no element
        information (``visitor.current_info[1] is None``). Components of a
        MixedElement are resolved via ``element_index``.
        """
        if visitor is None or visitor.current_info[1] is None:
            return None
        info = visitor.current_info[1]
        element = info.element
        if isinstance(element, MixedElement):
            element = element.extract_component(info.element_index)[1]
        return element

    def _quadrature_index(self, sf, visitor):
        """Quadrature part of the index into a vectorized quadrature temporary.

        Inames of sliced directions are divided by vertical_width. Directions
        whose matrix has a face set get the constant index 0.
        """
        quad_inames = quadrature_inames(self._current_element(visitor))
        # matrix_sequence is rebuilt on each property access; evaluate it once.
        matrices = self.matrix_sequence

        index = []
        if len(matrices) == local_dimension():
            for d in range(local_dimension()):
                addindex = prim.Variable(quad_inames[d])
                if matrices[d].slice_size:
                    addindex = addindex // self.vertical_width
                index.append(addindex)
        else:
            # Traverse all the quadrature inames and map them to their correct direction
            i = 0
            for d in range(world_dimension()):
                if matrices[d].face is None:
                    addindex = prim.Variable(quad_inames[i])
                    if matrices[d].slice_size:
                        addindex = addindex // self.vertical_width
                    index.append(addindex)
                    i = i + 1
                else:
                    index.append(0)
        return tuple(index)

    def vec_index(self, sf, visitor):
        """Position of sf within the vector.

        This is the horizontal index of sf plus, for vertical vectorization, the
        (last) sliced quadrature iname modulo vertical_width.
        """
        quad_inames = quadrature_inames(self._current_element(visitor))
        # matrix_sequence is rebuilt on each property access; evaluate it once.
        matrices = self.matrix_sequence

        sliced = 0
        if len(sf.matrix_sequence) == local_dimension():
            for d in range(local_dimension()):
                if matrices[d].slice_size:
                    sliced = prim.Variable(quad_inames[d])
        else:
            i = 0
            for d in range(world_dimension()):
                if matrices[d].face is None:
                    if matrices[d].slice_size:
                        sliced = prim.Variable(quad_inames[i])
                    i = i + 1

        return self.horizontal_index(sf) + prim.Remainder(sliced, self.vertical_width)

    @property
    def quadrature_shape(self):
        return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,)

    def quadrature_index(self, sf, visitor, direct_index=None):
        """Full index into a vectorized quadrature temporary.

        The vector index is direct_index (a tuple) if given, otherwise it is
        computed via vec_index.
        """
        quad = self._quadrature_index(sf, visitor)
        if direct_index is not None:
            assert isinstance(direct_index, tuple)
            return quad + direct_index
        else:
            return quad + (self.vec_index(sf, visitor),)

    @property
    def quadrature_dimtags(self):
        tags = ["f"] * len(self.quadrature_shape)
        tags[-1] = 'c'
        return ",".join(tags)

    @property
    def dof_shape(self):
        return tuple(mat.basis_size for mat in self.matrix_sequence) + (self.vector_width,)

    @property
    def dof_dimtags(self):
        tags = ["f"] * len(self.dof_shape)
        tags[-1] = 'vec'
        return ",".join(tags)

    @property
    def output_shape(self):
        if self.stage == 1:
            return self.quadrature_shape
        else:
            return self.dof_shape

    @property
    def output_dimtags(self):
        if self.stage == 1:
            return self.quadrature_dimtags
        else:
            return self.dof_dimtags

    @property
    def tag(self):
        return "vecsumfac_h{}_v{}".format(self.horizontal_width, self.vertical_width)