""" A pymbolic node representing a sum factorization kernel """ from dune.perftool.options import get_option from dune.perftool.generation import get_counted_variable from dune.perftool.pdelab.geometry import local_dimension, world_dimension from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray from pytools import ImmutableRecord, product from ufl import MixedElement import pymbolic.primitives as prim import loopy as lp import frozendict import inspect class SumfactKernelInputBase(object): @property def direct_input(self): return None def realize(self, sf, i, dep): pass class SumfactKernelBase(object): pass class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def __init__(self, matrix_sequence=None, buffer=None, stage=1, position_priority=None, restriction=None, insn_dep=frozenset(), input=None, accumvar=None, test_element=None, test_element_index=None, trial_element=None, trial_element_index=None, predicates=frozenset(), ): """Create a sum factorization kernel Sum factorization can be written as Y = R_{d-1} (A_{d-1} * ... * R_0 (A_0 * X)...) with: - X: Input rank d tensor of dimension n_0 x ... x n_{d-1} - Y: Output rank d tensor of dimension m_0 x ... x m_{d-1} - A_l: Values of 1D basis evaluations at quadrature points in l direction, matrix of dimension m_l x n_l - R_l: Transformation operator that permutes the underlying data vector of the rank d tensor in such a way that the fastest direction gets the slowest direction In the l'th step we have the following setup: - A_l: Matrix of dimensions m_l x n_l - X_l: Rank d tensor of dimensions n_l x ... x n_{d-1} x m_0 x ... x m_{l-1} - R_l: Transformation operator Looking at the indizes the following will happen: X --> [n_l,...,n_{d-1},m_0,...,m_{l-1}] A_l * X --> [m_l,n_l] * [n_l, ...] = [m_l,n_{l+1},...,n_{d-1},m_0,...,m_{l-1}] R_l (A_l*X) --> [n_{l+1},...,n_{d-1},m_0,...,m_{l-1}] So the multiplication with A_l is a reduction over one index and the transformation brings the next reduction index in the fastest position. It can make sense to permute the order of directions. If you have a small m_l (e.g. stage 1 on faces) it is better to do direction l first. This can be done by: - Permuting the order of the A matrices. - Permuting the input tensor. - Permuting the output tensor (this assures that the directions of the output tensor are again ordered from 0 to d-1). Note, that you will typically *not* set all of the below arguments, but only some. The vectorization strategy may set others for you. The only argument really needed in all cases is matrix_sequence. Arguments: ---------- matrix_sequence: A tuple of BasisTabulationMatrixBase instances The list of tensors to be applied to the input. Order of application is from 0 up. buffer: A string identifying the flip flop buffer in use for intermediate results. The memory is expected to be pre-initialized with the input or you have to provide direct_input (FastDGGridOperator). stage: 1 or 3 position_priority: Will be used in the dry run to order kernels when doing vectorization e.g. (dx u,dy u,dz u, u). restriction: Restriction for faces values. insn_dep: An instruction ID that the first issued instruction should depend upon. All following ones will depend on each other. input: An SumfactKernelInputBase instance describing the input of the kernel accumvar: The accumulation variable to accumulate into trial_element: The leaf element of the trial function space. Used to correctly nest stage 3 in the jacobian case. test_element: The leaf element of the test function space Used to compute offsets in the fastdg case. test_element_index: the component of the test_element trial_element_index: the component of the trial_element """ # Assert the inputs! assert isinstance(matrix_sequence, tuple) assert all(isinstance(m, BasisTabulationMatrixBase) for m in matrix_sequence) assert stage in (1, 3) if stage == 1: assert isinstance(input, SumfactKernelInputBase) if stage == 3: assert isinstance(restriction, tuple) assert isinstance(insn_dep, frozenset) # The following construction is a bit weird: Dict comprehensions do not have # access to the locals of the calling scope: So we need to do the eval beforehand defaultdict = {} for a in SumfactKernel.init_arg_names: defaultdict[a] = eval(a) # Call the base class constructors ImmutableRecord.__init__(self, **defaultdict) prim.Variable.__init__(self, "SUMFACT") # # The methods/fields needed to get a well-formed pymbolic node # def __getinitargs__(self): return tuple(getattr(self, arg) for arg in SumfactKernel.init_arg_names) def stringifier(self): return lp.symbolic.StringifyMapper mapper_method = "map_sumfact_kernel" # # Some cache key definitions # Watch out for the documentation to see which key is used unter what circumstances # @property def cache_key(self): """ The cache key that can be used in generation magic Any two sum factorization kernels having the same cache_key are realized simulatenously! """ return (self.matrix_sequence, self.restriction, self.stage, self.buffer, self.test_element_index) @property def input_key(self): """ A cache key for the input coefficients Any two sum factorization kernels having the same input_key work on the same input coefficient (and are suitable for simultaneous treatment because of that) """ return (self.input, self.restriction, self.accumvar, self.trial_element_index) @property def group_name(self): return "sfgroup_{}_{}_{}_{}".format(self.input, self.restriction, self.accumvar, self.trial_element_index) # # Some convenience methods to extract information about the sum factorization kernel # @property def length(self): """ The number of matrices to apply """ return len(self.matrix_sequence) @property def vectorized(self): return False @property def transposed(self): return self.matrix_sequence[0].transpose @property def within_inames(self): if self.trial_element is None: return () else: from dune.perftool.sumfact.basis import lfs_inames element = self.trial_element if isinstance(element, MixedElement): element = element.extract_component(self.trial_element_index)[1] return lfs_inames(element, self.restriction) def vec_index(self, sf): """ Map an unvectorized sumfact kernel object to its position in the vectorized kernel """ return 0 @property def quadrature_shape(self): """ The shape of a temporary for the quadrature points Takes into account the lower dimensionality of faces and vectorization. """ return tuple(mat.quadrature_size for mat in self.matrix_sequence) def quadrature_index(self, sf, visitor): if visitor.current_info[1] is None: element = None element_index = 0 else: element = visitor.current_info[1].element element_index = visitor.current_info[1].element_index if isinstance(element, MixedElement): element = element.extract_component(element_index)[1] quad_inames = quadrature_inames(element) if len(self.matrix_sequence) == local_dimension(): return tuple(prim.Variable(i) for i in quad_inames) # Traverse all the quadrature inames and map them to their correct direction index = [] i = 0 for d in range(world_dimension()): if self.matrix_sequence[d].face is None: index.append(prim.Variable(quad_inames[i])) i = i + 1 else: index.append(0) return tuple(index) @property def quadrature_dimtags(self): """ The dim_tags of a temporary for the quadrature points Takes into account the lower dimensionality of faces and vectorization. """ tags = ["f"] * len(self.quadrature_shape) return ",".join(tags) @property def dof_shape(self): """ The shape of a temporary for the degrees of freedom Takes into account vectorization. """ return tuple(mat.basis_size for mat in self.matrix_sequence) @property def dof_dimtags(self): """ The dim_tags of a temporary for the degrees of freedom Takes into account vectorization. """ tags = ["f"] * len(self.dof_shape) return ",".join(tags) @property def output_shape(self): if self.stage == 1: return self.quadrature_shape else: return self.dof_shape @property def output_dimtags(self): if self.stage == 1: return self.quadrature_dimtags else: return self.dof_dimtags @property def tag(self): return "sumfac" # # Define properties for conformity with the interface of VectorizedSumfactKernel # @property def padded_indices(self): return set() @property def horizontal_width(self): return 1 def horizontal_index(self, _): return 0 @property def vertical_width(self): return 1 @property def vector_width(self): return 1 # Extract the argument list and store it on the class. This needs to be done # outside of the class because the SumfactKernel class object needs to be fully # initialized in order to extract the information from __init__. SumfactKernel.init_arg_names = tuple(inspect.getargspec(SumfactKernel.__init__)[0][1:]) class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def __init__(self, kernels=None, horizontal_width=1, vertical_width=1, buffer=None, insn_dep=frozenset(), ): # Assert the input data structure assert isinstance(kernels, tuple) assert all(isinstance(k, SumfactKernel) for k in kernels) # Assert all the properties that need to be the same across all subkernels assert len(set(k.stage for k in kernels)) == 1 assert len(set(k.length for k in kernels)) == 1 assert len(set(k.restriction for k in kernels)) == 1 assert len(set(k.within_inames for k in kernels)) == 1 assert len(set(k.predicates for k in kernels)) == 1 # Assert properties of the matrix sequence of the underlying kernels for i in range(kernels[0].length): assert len(set(tuple(k.matrix_sequence[i].rows for k in kernels))) == 1 assert len(set(tuple(k.matrix_sequence[i].cols for k in kernels))) == 1 assert len(set(tuple(k.matrix_sequence[i].face for k in kernels))) == 1 assert len(set(tuple(k.matrix_sequence[i].transpose for k in kernels))) == 1 # Join the instruction dependencies of all subkernels insn_dep = insn_dep.union(k.insn_dep for k in kernels) # We currently assume that all subkernels are consecutive, 0-based within the vector assert None not in kernels ImmutableRecord.__init__(self, kernels=kernels, horizontal_width=horizontal_width, buffer=buffer, insn_dep=insn_dep, vertical_width=vertical_width, ) prim.Variable.__init__(self, "VecSUMFAC") def __getinitargs__(self): return (self.kernels, self.horizontal_width, self.vertical_width, self.buffer, self.insn_dep) def stringifier(self): return lp.symbolic.StringifyMapper mapper_method = "map_vectorized_sumfact_kernel" init_arg_names = ("kernels", "horizontal_width", "vertical_width", "buffer", "insn_dep") # # Some cache key definitions # Watch out for the documentation to see which key is used unter what circumstances # @property def cache_key(self): """ The cache key that can be used in generation magic Any two sum factorization kernels having the same cache_key are realized simulatenously! """ return (self.matrix_sequence, self.restriction, self.stage, self.buffer) # # Deduce all data fields of normal sum factorization kernels from the underlying kernels # @property def matrix_sequence(self): return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence[i] for k in self.kernels), width=self.vector_width, ) for i in range(self.length)) @property def stage(self): return self.kernels[0].stage @property def restriction(self): return self.kernels[0].restriction @property def within_inames(self): return self.kernels[0].within_inames @property def test_element(self): return self.kernels[0].test_element @property def test_element_index(self): return self.kernels[0].test_element_index @property def trial_element(self): return self.kernels[0].trial_element @property def trial_element_index(self): return self.kernels[0].trial_element_index @property def predicates(self): return self.kernels[0].predicates @property def input(self): assert len(set(k.input for k in self.kernels)) == 1 return self.kernels[0].input @property def accumvar(self): assert len(set(k.accumvar for k in self.kernels)) == 1 return self.kernels[0].accumvar @property def transposed(self): return self.kernels[0].transposed # # Define some properties only needed for this one # @property def padded_indices(self): indices = set(range(self.vector_width)) - set(range(len(self.kernels))) return tuple(self.kernels[0].quadrature_index(None) + (i,) for i in indices) @property def vector_width(self): return self.horizontal_width * self.vertical_width # # Define the same properties the normal SumfactKernel defines # @property def cache_key(self): return (tuple(k.cache_key for k in self.kernels), self.buffer) @property def input_key(self): return tuple(k.input_key for k in self.kernels) @property def group_name(self): return "_".join(k.group_name for k in self.kernels) @property def length(self): return self.kernels[0].length @property def vectorized(self): return True def horizontal_index(self, sf): key = tuple(mat.derivative for mat in sf.matrix_sequence) for i, k in enumerate(self.kernels): if tuple(mat.derivative for mat in k.matrix_sequence) == key: return i return 0 def _quadrature_index(self, sf, visitor): if visitor.current_info[1] is None: element = None element_index = 0 else: element = visitor.current_info[1].element element_index = visitor.current_info[1].element_index if isinstance(element, MixedElement): element = element.extract_component(element_index)[1] quad_inames = quadrature_inames(element) index = [] if len(self.matrix_sequence) == local_dimension(): for d in range(local_dimension()): addindex = prim.Variable(quad_inames[d]) if self.matrix_sequence[d].slice_size: addindex = addindex // self.vertical_width index.append(addindex) else: # Traverse all the quadrature inames and map them to their correct direction i = 0 for d in range(world_dimension()): if self.matrix_sequence[d].face is None: addindex = prim.Variable(quad_inames[i]) if self.matrix_sequence[d].slice_size: addindex = addindex // self.vertical_width index.append(addindex) i = i + 1 else: index.append(0) return tuple(index) def vec_index(self, sf, visitor): if visitor.current_info[1] is None: element = None element_index = 0 else: element = visitor.current_info[1].element element_index = visitor.current_info[1].element_index if isinstance(element, MixedElement): element = element.extract_component(element_index)[1] quad_inames = quadrature_inames(element) sliced = 0 if len(sf.matrix_sequence) == local_dimension(): for d in range(local_dimension()): if self.matrix_sequence[d].slice_size: sliced = prim.Variable(quad_inames[d]) else: i = 0 for d in range(world_dimension()): if self.matrix_sequence[d].face is None: if self.matrix_sequence[d].slice_size: sliced = prim.Variable(quad_inames[i]) i = i + 1 return self.horizontal_index(sf) + prim.Remainder(sliced, self.vertical_width) @property def quadrature_shape(self): return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,) def quadrature_index(self, sf, visitor, direct_index=None): quad = self._quadrature_index(sf, visitor) if direct_index is not None: assert isinstance(direct_index, tuple) return quad + direct_index else: return quad + (self.vec_index(sf, visitor),) @property def quadrature_dimtags(self): tags = ["f"] * len(self.quadrature_shape) tags[-1] = 'c' return ",".join(tags) @property def dof_shape(self): return tuple(mat.basis_size for mat in self.matrix_sequence) + (self.vector_width,) @property def dof_dimtags(self): tags = ["f"] * len(self.dof_shape) tags[-1] = 'vec' return ",".join(tags) @property def output_shape(self): if self.stage == 1: return self.quadrature_shape else: return self.dof_shape @property def output_dimtags(self): if self.stage == 1: return self.quadrature_dimtags else: return self.dof_dimtags @property def tag(self): return "vecsumfac_h{}_v{}".format(self.horizontal_width, self.vertical_width)