""" A pymbolic node representing a sum factorization kernel """

from dune.perftool.options import get_option
from dune.perftool.generation import get_counted_variable
from dune.perftool.pdelab.geometry import local_dimension, world_dimension
from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray

from pytools import ImmutableRecord, product

from ufl import MixedElement

import pymbolic.primitives as prim
import loopy as lp
import frozendict
import inspect


class SumfactKernelInputBase(object):
    @property
    def direct_input(self):
        return None

    def realize(self, sf, i, dep):
        pass


class SumfactKernelBase(object):
    pass


class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
    def __init__(self,
                 matrix_sequence=None,
                 buffer=None,
                 stage=1,
                 position_priority=None,
                 restriction=None,
                 insn_dep=frozenset(),
                 input=None,
                 accumvar=None,
                 test_element=None,
                 test_element_index=None,
                 trial_element=None,
                 trial_element_index=None,
                 predicates=frozenset(),
                 ):
        """Create a sum factorization kernel

        Sum factorization can be written as

        Y = R_{d-1} (A_{d-1} * ... * R_0 (A_0 * X)...)

        with:
        - X: Input rank d tensor of dimension n_0 x ... x n_{d-1}
        - Y: Output rank d tensor of dimension m_0 x ... x m_{d-1}
        - A_l: Values of 1D basis evaluations at quadrature points in l
               direction, matrix of dimension m_l x n_l
        - R_l: Transformation operator that permutes the underlying data
               vector of the rank d tensor in such a way that the fastest
               direction gets the slowest direction

        In the l'th step we have the following setup:
        - A_l: Matrix of dimensions m_l x n_l
        - X_l: Rank d tensor of dimensions n_l x ... x n_{d-1} x m_0 x ... x m_{l-1}
        - R_l: Transformation operator

        Looking at the indices the following will happen:
        X --> [n_l,...,n_{d-1},m_0,...,m_{l-1}]
        A_l * X --> [m_l,n_l] * [n_l, ...] = [m_l,n_{l+1},...,n_{d-1},m_0,...,m_{l-1}]
        R_l (A_l*X) --> [n_{l+1},...,n_{d-1},m_0,...,m_{l-1},m_l]

        So the multiplication with A_l is a reduction over one index and
        the transformation brings the next reduction index in the fastest
        position.

        It can make sense to permute the order of directions. If you have
        a small m_l (e.g. stage 1 on faces) it is better to do direction l
        first. This can be done by:

        - Permuting the order of the A matrices.
        - Permuting the input tensor.
        - Permuting the output tensor (this ensures that the directions of
          the output tensor are again ordered from 0 to d-1).

        Note that you will typically *not* set all of the arguments below,
        but only some. The vectorization strategy may set others for you.
        The only argument really needed in all cases is matrix_sequence.
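
        A rough NumPy sketch of the index bookkeeping described above
        (illustrative only and assuming numpy; the generated loopy kernels
        operate on flat buffers instead):

            import numpy as np
            d, n, m = 3, 4, 5                      # dimension, basis size, quadrature size
            X = np.random.rand(*([n] * d))         # rank d input tensor X
            A = [np.random.rand(m, n) for _ in range(d)]
            for l in range(d):
                X = np.tensordot(A[l], X, axes=1)  # A_l * X: reduce over the leading index n_l
                X = np.moveaxis(X, 0, -1)          # R_l: rotate the new m_l index to the end
            # X now has shape (m,) * d, i.e. Y evaluated at all tensor product
            # quadrature points. Processing the directions in a permuted order
            # (with matching input/output permutations) yields the same result.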

        Arguments:
        ----------
        matrix_sequence: A tuple of BasisTabulationMatrixBase instances
            The tabulation matrices to be applied to the input.
            Order of application is from direction 0 upwards.
        buffer: A string identifying the flip-flop buffer in use
            for intermediate results. The memory is expected to be
            pre-initialized with the input, or direct_input has to be
            provided (FastDGGridOperator).
        stage: 1 or 3
        position_priority: Used in the dry run to order kernels
            when doing vectorization, e.g. (dx u, dy u, dz u, u).
        restriction: Restriction for face values.
        insn_dep: A frozenset of instruction IDs that the first issued
            instruction should depend upon. All following instructions
            will depend on each other.
        input: A SumfactKernelInputBase instance describing the input of the kernel
        accumvar: The accumulation variable to accumulate into
        trial_element: The leaf element of the trial function space.
            Used to correctly nest stage 3 in the Jacobian case.
        test_element: The leaf element of the test function space.
            Used to compute offsets in the FastDG case.
        test_element_index: The component of the test_element
        trial_element_index: The component of the trial_element
        """
        # Assert the inputs!
        assert isinstance(matrix_sequence, tuple)
        assert all(isinstance(m, BasisTabulationMatrixBase) for m in matrix_sequence)

        assert stage in (1, 3)

        if stage == 1:
            assert isinstance(input, SumfactKernelInputBase)

        if stage == 3:
            assert isinstance(restriction, tuple)

        assert isinstance(insn_dep, frozenset)

        # The following construction is a bit weird: dict comprehensions do not have
        # access to the locals of the calling scope, so the eval has to be done in an
        # explicit loop instead.
        init_kwargs = {}
        for a in SumfactKernel.init_arg_names:
            init_kwargs[a] = eval(a)

        # Call the base class constructors
        ImmutableRecord.__init__(self, **init_kwargs)
        prim.Variable.__init__(self, "SUMFACT")

    #
    # The methods/fields needed to get a well-formed pymbolic node
    #

    def __getinitargs__(self):
        return tuple(getattr(self, arg) for arg in SumfactKernel.init_arg_names)

    def stringifier(self):
        return lp.symbolic.StringifyMapper

    mapper_method = "map_sumfact_kernel"

    #
    # Some cache key definitions
    # See the documentation to learn which key is used under which circumstances
    #

    @property
    def cache_key(self):
        """ The cache key that can be used in generation magic
        Any two sum factorization kernels having the same cache_key
        are realized simultaneously!
        """
        return (self.matrix_sequence, self.restriction, self.stage, self.buffer, self.test_element_index)

    @property
    def input_key(self):
        """ A cache key for the input coefficients
        Any two sum factorization kernels having the same input_key
        work on the same input coefficient (and are suitable for simultaneous
        treatment because of that)
        """
        return (self.input, self.restriction, self.accumvar, self.trial_element_index)

    @property
    def group_name(self):
        return "sfgroup_{}_{}_{}_{}".format(self.input, self.restriction, self.accumvar, self.trial_element_index)

    #
    # Some convenience methods to extract information about the sum factorization kernel
    #

    @property
    def length(self):
        """ The number of matrices to apply """
        return len(self.matrix_sequence)

    @property
    def vectorized(self):
        return False

    @property
    def transposed(self):
        return self.matrix_sequence[0].transpose

    @property
    def within_inames(self):
        if self.trial_element is None:
            return ()
        else:
            from dune.perftool.sumfact.basis import lfs_inames
            element = self.trial_element
            if isinstance(element, MixedElement):
                element = element.extract_component(self.trial_element_index)[1]
            return lfs_inames(element, self.restriction)

    def vec_index(self, sf):
        """ Map an unvectorized sumfact kernel object to its position
        in the vectorized kernel
        """
        return 0

    @property
    def quadrature_shape(self):
        """ The shape of a temporary for the quadrature points

        Takes into account the lower dimensionality of faces and vectorization.
        """
        return tuple(mat.quadrature_size for mat in self.matrix_sequence)

    def quadrature_index(self, sf, visitor):
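        """ Build the index tuple into the quadrature temporary

        On volumes this is simply one quadrature iname per direction. On
        faces the matrix sequence still has one entry per world direction,
        and the direction normal to the face is indexed with 0.
        """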
        if visitor.current_info[1] is None:
            element = None
            element_index = 0
        else:
            element = visitor.current_info[1].element
            element_index = visitor.current_info[1].element_index
            if isinstance(element, MixedElement):
                element = element.extract_component(element_index)[1]

        quad_inames = quadrature_inames(element)
        if len(self.matrix_sequence) == local_dimension():
            return tuple(prim.Variable(i) for i in quad_inames)

        # Traverse all the quadrature inames and map them to their correct direction
        index = []
        i = 0
        for d in range(world_dimension()):
            if self.matrix_sequence[d].face is None:
                index.append(prim.Variable(quad_inames[i]))
                i = i + 1
            else:
                index.append(0)

        return tuple(index)

    @property
    def quadrature_dimtags(self):
        """ The dim_tags of a temporary for the quadrature points

        Takes into account the lower dimensionality of faces and vectorization.
        """
        tags = ["f"] * len(self.quadrature_shape)
        return ",".join(tags)

    @property
    def dof_shape(self):
        """ The shape of a temporary for the degrees of freedom

        Takes into account vectorization.
        """
        return tuple(mat.basis_size for mat in self.matrix_sequence)

    @property
    def dof_dimtags(self):
        """ The dim_tags of a temporary for the degrees of freedom

        Takes into account vectorization.
        """
        tags = ["f"] * len(self.dof_shape)
        return ",".join(tags)

    @property
    def output_shape(self):
        if self.stage == 1:
            return self.quadrature_shape
        else:
            return self.dof_shape

    @property
    def output_dimtags(self):
        if self.stage == 1:
            return self.quadrature_dimtags
        else:
            return self.dof_dimtags

    @property
    def tag(self):
        return "sumfac"

    #
    # Define properties for conformity with the interface of VectorizedSumfactKernel
    #

    @property
    def padded_indices(self):
        return set()

    @property
    def horizontal_width(self):
        return 1

    def horizontal_index(self, _):
        return 0

    @property
    def vertical_width(self):
        return 1

    @property
    def vector_width(self):
        return 1

# Extract the argument list and store it on the class. This needs to be done
# outside of the class because the SumfactKernel class object needs to be fully
# initialized in order to extract the information from __init__.
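# Note: inspect.getargspec is deprecated in Python 3 (and removed in 3.11);
# inspect.getfullargspec would be the drop-in replacement here.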
SumfactKernel.init_arg_names = tuple(inspect.getargspec(SumfactKernel.__init__)[0][1:])


class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
    def __init__(self,
                 kernels=None,
                 horizontal_width=1,
                 vertical_width=1,
                 buffer=None,
                 insn_dep=frozenset(),
                 ):
        # Assert the input data structure
        assert isinstance(kernels, tuple)
        assert all(isinstance(k, SumfactKernel) for k in kernels)

        # Assert all the properties that need to be the same across all subkernels
        assert len(set(k.stage for k in kernels)) == 1
        assert len(set(k.length for k in kernels)) == 1
        assert len(set(k.restriction for k in kernels)) == 1
        assert len(set(k.within_inames for k in kernels)) == 1
        assert len(set(k.predicates for k in kernels)) == 1

        # Assert properties of the matrix sequence of the underlying kernels
        for i in range(kernels[0].length):
            assert len(set(tuple(k.matrix_sequence[i].rows for k in kernels))) == 1
            assert len(set(tuple(k.matrix_sequence[i].cols for k in kernels))) == 1
            assert len(set(tuple(k.matrix_sequence[i].face for k in kernels))) == 1
            assert len(set(tuple(k.matrix_sequence[i].transpose for k in kernels))) == 1

        # Join the instruction dependencies of all subkernels
        insn_dep = insn_dep.union(*(k.insn_dep for k in kernels))

        # We currently assume that there are no padding slots (None entries) in the
        # kernels tuple, i.e. the subkernels occupy consecutive lanes starting at 0.
        assert None not in kernels

        ImmutableRecord.__init__(self,
                                 kernels=kernels,
                                 horizontal_width=horizontal_width,
                                 buffer=buffer,
                                 insn_dep=insn_dep,
                                 vertical_width=vertical_width,
                                 )

        prim.Variable.__init__(self, "VecSUMFAC")

    def __getinitargs__(self):
        return (self.kernels, self.horizontal_width, self.vertical_width, self.buffer, self.insn_dep)

    def stringifier(self):
        return lp.symbolic.StringifyMapper

    mapper_method = "map_vectorized_sumfact_kernel"

    init_arg_names = ("kernels", "horizontal_width", "vertical_width", "buffer", "insn_dep")

    #
    # Some cache key definitions
    # See the documentation to learn which key is used under which circumstances
    #

    @property
    def cache_key(self):
        """ The cache key that can be used in generation magic
        Any two sum factorization kernels having the same cache_key
        are realized simultaneously!
        """
        return (self.matrix_sequence, self.restriction, self.stage, self.buffer)

    #
    # Deduce all data fields of normal sum factorization kernels from the underlying kernels
    #

    @property
    def matrix_sequence(self):
        return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence[i] for k in self.kernels),
                                                width=self.vector_width,
                                                )
                     for i in range(self.length))

    @property
    def stage(self):
        return self.kernels[0].stage

    @property
    def restriction(self):
        return self.kernels[0].restriction

    @property
    def within_inames(self):
        return self.kernels[0].within_inames

    @property
    def test_element(self):
        return self.kernels[0].test_element

    @property
    def test_element_index(self):
        return self.kernels[0].test_element_index

    @property
    def trial_element(self):
        return self.kernels[0].trial_element

    @property
    def trial_element_index(self):
        return self.kernels[0].trial_element_index

    @property
    def predicates(self):
        return self.kernels[0].predicates

    @property
    def input(self):
        assert len(set(k.input for k in self.kernels)) == 1
        return self.kernels[0].input

    @property
    def accumvar(self):
        assert len(set(k.accumvar for k in self.kernels)) == 1
        return self.kernels[0].accumvar

    @property
    def transposed(self):
        return self.kernels[0].transposed

    #
    # Define some properties only needed for this one
    #

    @property
    def padded_indices(self):
        indices = set(range(self.vector_width)) - set(range(len(self.kernels)))
        return tuple(self.kernels[0].quadrature_index(None) + (i,) for i in indices)

    @property
    def vector_width(self):
        return self.horizontal_width * self.vertical_width

    #
    # Define the same properties the normal SumfactKernel defines
    #

    @property
    def cache_key(self):
        return (tuple(k.cache_key for k in self.kernels), self.buffer)

    @property
    def input_key(self):
        return tuple(k.input_key for k in self.kernels)

    @property
    def group_name(self):
        return "_".join(k.group_name for k in self.kernels)

    @property
    def length(self):
        return self.kernels[0].length

    @property
    def vectorized(self):
        return True

    def horizontal_index(self, sf):
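        """ Return the horizontal position of the unvectorized kernel sf within
        this vectorized kernel by matching the derivative pattern of its matrix
        sequence against the subkernels (0 if none matches).
        """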
        key = tuple(mat.derivative for mat in sf.matrix_sequence)
        for i, k in enumerate(self.kernels):
            if tuple(mat.derivative for mat in k.matrix_sequence) == key:
                return i
        return 0

    def _quadrature_index(self, sf, visitor):
        if visitor.current_info[1] is None:
            element = None
            element_index = 0
        else:
            element = visitor.current_info[1].element
            element_index = visitor.current_info[1].element_index
            if isinstance(element, MixedElement):
                element = element.extract_component(element_index)[1]

        quad_inames = quadrature_inames(element)
        index = []

        if len(self.matrix_sequence) == local_dimension():
            for d in range(local_dimension()):
                addindex = prim.Variable(quad_inames[d])

                if self.matrix_sequence[d].slice_size:
                    addindex = addindex // self.vertical_width

                index.append(addindex)
        else:
            # Traverse all the quadrature inames and map them to their correct direction
            i = 0
            for d in range(world_dimension()):
                if self.matrix_sequence[d].face is None:
                    addindex = prim.Variable(quad_inames[i])

                    if self.matrix_sequence[d].slice_size:
                        addindex = addindex // self.vertical_width

                    index.append(addindex)
                    i = i + 1
                else:
                    index.append(0)

        return tuple(index)

    def vec_index(self, sf, visitor):
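        """ Map an unvectorized sumfact kernel object to its lane in the
        vectorized kernel: the horizontal position plus the sliced
        quadrature index modulo the vertical width.
        """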
        if visitor.current_info[1] is None:
            element = None
            element_index = 0
        else:
            element = visitor.current_info[1].element
            element_index = visitor.current_info[1].element_index
            if isinstance(element, MixedElement):
                element = element.extract_component(element_index)[1]

        quad_inames = quadrature_inames(element)
        sliced = 0
        if len(sf.matrix_sequence) == local_dimension():
            for d in range(local_dimension()):
                if self.matrix_sequence[d].slice_size:
                    sliced = prim.Variable(quad_inames[d])
        else:
            i = 0
            for d in range(world_dimension()):
                if self.matrix_sequence[d].face is None:
                    if self.matrix_sequence[d].slice_size:
                        sliced = prim.Variable(quad_inames[i])
                    i = i + 1

        return self.horizontal_index(sf) + prim.Remainder(sliced, self.vertical_width)

    @property
    def quadrature_shape(self):
        return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,)

    def quadrature_index(self, sf, visitor, direct_index=None):
        quad = self._quadrature_index(sf, visitor)
        if direct_index is not None:
            assert isinstance(direct_index, tuple)
            return quad + direct_index
        else:
            return quad + (self.vec_index(sf, visitor),)

    @property
    def quadrature_dimtags(self):
        tags = ["f"] * len(self.quadrature_shape)
        tags[-1] = 'c'
        return ",".join(tags)

    @property
    def dof_shape(self):
        return tuple(mat.basis_size for mat in self.matrix_sequence) + (self.vector_width,)

    @property
    def dof_dimtags(self):
        tags = ["f"] * len(self.dof_shape)
        tags[-1] = 'vec'
        return ",".join(tags)

    @property
    def output_shape(self):
        if self.stage == 1:
            return self.quadrature_shape
        else:
            return self.dof_shape

    @property
    def output_dimtags(self):
        if self.stage == 1:
            return self.quadrature_dimtags
        else:
            return self.dof_dimtags

    @property
    def tag(self):
        return "vecsumfac_h{}_v{}".format(self.horizontal_width, self.vertical_width)