Skip to content
Snippets Groups Projects
geometry.py 4.3 KiB
Newer Older
""" Sum factorized geometry evaluations """

from dune.perftool.generation import (domain,
                                      get_backend,
                                      get_counted_variable,
                                      iname,
                                      instruction,
                                      kernel_cached,
                                      temporary_variable,
                                      )
from dune.perftool.loopy.buffer import get_buffer_temporary
from dune.perftool.pdelab.geometry import (local_dimension,
                                           world_dimension,
                                           )
from dune.perftool.sumfact.symbolic import SumfactKernelInputBase
from dune.perftool.sumfact.vectorization import attach_vectorization_info

from pytools import ImmutableRecord

import pymbolic.primitives as prim


@iname
def corner_iname():
    """Return a loop index name that iterates over the geometry corners.

    A counted variable name with prefix ``corneriname`` is requested and a
    loop domain with ``2 ** local_dimension()`` points is registered for it.
    2**d is the corner count of a d-dimensional tensor-product (hypercube)
    cell, which this code presumably assumes -- NOTE(review): confirm.
    """
    fresh_name = get_counted_variable("corneriname")
    corner_count = 2 ** local_dimension()
    domain(fresh_name, corner_count)
    return fresh_name


class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
    """Sum factorization input that loads one world coordinate of every corner.

    Realizing it fills the kernel's input buffer with
    ``geometry.corner(c)[dir]`` for each corner index ``c``. The function
    ``pymbolic_spatial_coordinate`` uses one such input per world direction
    to evaluate the global coordinate at the quadrature points.

    Attributes:
        dir: index of the world coordinate component that is loaded.
    """

    def __init__(self, dir):
        # The record field is named ``dir`` (shadowing the builtin) because
        # that name is part of the constructor/attribute interface.
        ImmutableRecord.__init__(self, dir=dir)

    def realize(self, sf, index, insn_dep):
        """Emit the instruction that copies corner coordinates into the input buffer.

        Args:
            sf: the sum factorization kernel being realized. Its ``buffer``
                and ``vector_width`` define the input temporary; its
                ``stage`` is used for the instruction tag.
            index: not used by this input.
                NOTE(review): if several of these inputs share one vectorized
                kernel, each lane writes ``buffer[corner]`` regardless of
                ``index`` -- confirm this is only used with vector_width == 1
                or that the buffer layout makes this correct.
            insn_dep: not used; the emitted instruction declares no
                explicit dependencies.

        Returns:
            None.
        """
        buffer_shape = (2 ** local_dimension(), sf.vector_width)
        buffer_name = get_buffer_temporary(sf.buffer,
                                           shape=buffer_shape,
                                           name="input_{}".format(sf.buffer),
                                           )

        corner = corner_iname()

        from dune.perftool.pdelab.geometry import name_geometry
        geometry = name_geometry()

        # The corner() method returns a non-scalar, which does not fit loopy's
        # current function-call model. The load is therefore emitted as a raw
        # C instruction until issue #11 is resolved.
        code = "{}[{}] = {}.corner({})[{}];".format(
            buffer_name, corner, geometry, corner, self.dir
        )
        stage_tag = "sumfact_stage{}".format(sf.stage)

        instruction(code=code,
                    within_inames=frozenset([corner]),
                    assignees=(buffer_name,),
                    tags=frozenset([stage_tag]),
                    )


@kernel_cached
def pymbolic_spatial_coordinate():
    """Evaluate the global coordinate at the quadrature points via sum factorization.

    For every world direction, a sum factorization kernel fed by
    ``GeoCornersInput(direction)`` is realized. Its result, indexed at the
    current quadrature point, is then stored into the temporary
    ``pos_global`` (shape ``(world_dimension(),)``). The stores happen inside
    the quadrature inames and depend on the last realized kernel's
    instruction dependencies.

    Returns:
        pymbolic.primitives.Variable: reference to ``pos_global``; subscript
        it with a direction to get that coordinate component.
    """
    from dune.perftool.sumfact.tabulation import quadrature_points_per_direction, BasisTabulationMatrix

    # The matrix sequence is built by hand, using two basis functions per
    # direction. On facets, the facet's own embedding into the global space
    # should be used directly, not the neighboring cell geometries. That
    # means the sequence has only local_dimension() (= dim-1) matrices there.
    tabulation = BasisTabulationMatrix(quadrature_size=quadrature_points_per_direction(),
                                       basis_size=2)
    matrix_sequence = (tabulation,) * local_dimension()

    from dune.perftool.sumfact.symbolic import SumfactKernel
    from dune.perftool.sumfact.realization import realize_sum_factorization_kernel

    dependencies = frozenset()
    components = []
    for direction in range(world_dimension()):
        kernel = SumfactKernel(matrix_sequence=matrix_sequence,
                               input=GeoCornersInput(direction),
                               )
        vkernel = attach_vectorization_info(kernel)

        # Realize the stage 1 kernel (basis evaluation at quadrature points).
        # Each kernel is chained onto the dependencies of the previous one.
        chained = vkernel.copy(insn_dep=vkernel.insn_dep.union(dependencies))
        result, dependencies = realize_sum_factorization_kernel(chained)

        components.append(prim.Subscript(result, vkernel.quadrature_index(kernel)))

    # Collect the components into an indexable temporary.
    name = "pos_global"
    temporary_variable(name, shape=(world_dimension(),))
    target = prim.Variable(name)
    for direction, component in enumerate(components):
        instruction(assignee=prim.Subscript(target, (direction,)),
                    expression=component,
                    within_inames=frozenset(get_backend("quad_inames")()),
                    within_inames_is_final=True,
                    depends_on=dependencies,
                    )

    return target