Skip to content
Snippets Groups Projects
vectorization.py 5.95 KiB
Newer Older
""" Sum factorization vectorization """

Dominic Kempf's avatar
Dominic Kempf committed
from dune.perftool.loopy.symbolic import SumfactKernel
from dune.perftool.generation import (generator_factory,
                                      get_counted_variable,
Dominic Kempf's avatar
Dominic Kempf committed
                                      get_global_context_value,
from dune.perftool.pdelab.restriction import (Restriction,
                                              restricted_name,
                                              )
from dune.perftool.error import PerftoolError
from dune.perftool.options import get_option

import loopy as lp


@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o)
def _cache_vectorization_info(old, new):
    """Store (or look up) the vectorized replacement of a sum factorization kernel.

    The cache entry is keyed on ``old`` only (see ``cache_key_generator``).
    Strategy-deciding code passes the replacement kernel as ``new``;
    ``attach_vectorization_info`` passes ``None`` and presumably relies on
    ``generator_factory`` returning the value cached earlier for ``old``
    without re-running this body (confirm in dune.perftool.generation).

    :param old: the original sum factorization kernel (cache key)
    :param new: the kernel to record for ``old``, or ``None`` for a lookup
    :returns: ``new``
    :raises PerftoolError: if this body runs with ``new`` being ``None``,
        presumably meaning no decision was recorded for ``old``.
    """
    if new is not None:
        return new
    raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!")


# Called by attach_vectorization_info with every sum factorization kernel met
# during a dry run; its items are tagged "sumfactnodes" and "dryrundata".
# NOTE(review): no_deco=True presumably makes generator_factory return a plain
# callable rather than a decorator -- confirm in dune.perftool.generation.
_collect_sumfact_nodes = generator_factory(
    no_deco=True,
    item_tags=("sumfactnodes", "dryrundata"),
)


def attach_vectorization_info(sf):
    """Return the vectorization-aware version of sum factorization kernel *sf*.

    Outside a dry run (global context value ``dry_run`` falsy), the kernel
    recorded for *sf* is fetched through ``_cache_vectorization_info``. That
    call raises ``PerftoolError`` if it has to compute a fresh value, i.e.
    presumably when no strategy decision was recorded for *sf*.

    During a dry run, *sf* is handed to ``_collect_sumfact_nodes`` and
    whatever that returns is passed through.

    :param sf: a ``SumfactKernel`` instance
    """
    assert isinstance(sf, SumfactKernel)
    if not get_global_context_value("dry_run"):
        return _cache_vectorization_info(sf, None)
    return _collect_sumfact_nodes(sf)


def no_vectorization(sumfacts):
    """Disable vectorization for all given sum factorization kernels.

    Each kernel is recorded (via ``_cache_vectorization_info``, keyed on the
    kernel itself) as a copy of itself that uses freshly counted ``buffer``
    and ``input`` variable names, i.e. it keeps its own buffers instead of
    sharing joint ones.

    :param sumfacts: iterable of ``SumfactKernel`` objects
    """
    for sf in sumfacts:
        _cache_vectorization_info(sf, sf.copy(buffer=get_counted_variable("buffer"),
                                              input=get_counted_variable("input")))


def decide_stage_vectorization_strategy(sumfacts, stage, restriction):
    """Decide how to vectorize the sum factorization kernels of one stage.

    The kernels in *sumfacts* whose ``stage`` and ``restriction`` match the
    given values form a group. If the group has exactly 3 or 4 members, they
    are joined into one 4-wide kernel:

    * each member is assigned a position 0..3 (its ``preferred_position`` if
      set, otherwise a free one),
    * the per-direction A matrices are combined into ``LargeAMatrix``
      objects carrying every member's derivative flag at its position,
    * every member gets a copy recorded via ``_cache_vectorization_info``
      with the joint matrices, joint buffer/input names, its position, the
      unused positions as padding and the union of all members' ``insn_dep``.

    Any other group size disables vectorization for the group.

    :param sumfacts: collection of ``SumfactKernel`` objects
    :param stage: the stage to select (callers in this file pass 1 or 3)
    :param restriction: value compared against each kernel's ``restriction``
    """
    stage_sumfacts = frozenset([sf for sf in sumfacts if sf.stage == stage and sf.restriction == restriction])
    if len(stage_sumfacts) not in (3, 4):
        # Disable vectorization strategy
        no_vectorization(stage_sumfacts)
        return

    # Map the sum factorizations to their position in the joint kernel
    position_mapping = {}
    available = set(range(4))
    for sf in stage_sumfacts:
        if sf.preferred_position is not None:
            # This asserts that no two kernels want to take the same position.
            # Later on, more complicated stuff might be necessary here.
            assert sf.preferred_position in available
            available.discard(sf.preferred_position)
            position_mapping[sf] = sf.preferred_position

    # Choose a position for those that have no preferred one!
    for sf in stage_sumfacts:
        if sf.preferred_position is None:
            position_mapping[sf] = available.pop()

    inp = get_counted_variable("joined_input")
    buf = get_counted_variable("joined_buffer")

    # Collect the large matrices! All members must agree on the matrix
    # shapes (asserted below), so any member serves as the template.
    from dune.perftool.sumfact.amatrix import LargeAMatrix
    template = next(iter(stage_sumfacts))
    large_a_matrices = []
    for i, mat in enumerate(template.a_matrices):
        # Assert that the matrices of all sum factorizations have the same size
        assert len({sf.a_matrices[i].rows for sf in stage_sumfacts}) == 1
        assert len({sf.a_matrices[i].cols for sf in stage_sumfacts}) == 1

        # Collect the derivative information
        derivative = [False] * 4
        for sf in stage_sumfacts:
            derivative[position_mapping[sf]] = sf.a_matrices[i].derivative

        large_a_matrices.append(LargeAMatrix(rows=mat.rows,
                                             cols=mat.cols,
                                             transpose=mat.transpose,
                                             derivative=tuple(derivative),
                                             face=mat.face,
                                             ))

    # The joint kernel depends on everything any member depended on; unpack
    # so the individual dependency sets are merged rather than nested.
    insn_dep = frozenset().union(*(sf.insn_dep for sf in stage_sumfacts))
    a_matrices = tuple(large_a_matrices)
    padding = frozenset(available)

    # Record the joint kernel for *every* member, and only once all large
    # matrices have been assembled.
    for sf in stage_sumfacts:
        _cache_vectorization_info(sf,
                                  sf.copy(a_matrices=a_matrices,
                                          buffer=buf,
                                          input=inp,
                                          index=position_mapping[sf],
                                          padding=padding,
                                          insn_dep=insn_dep,
                                          )
                                  )


def decide_vectorization_strategy():
    """Choose a vectorization strategy for every sum factorization kernel.

    The kernels referenced by the expressions of the default kernel's
    assignment and call instructions are gathered first. If the
    ``vectorize_grads`` option is not set, vectorization is disabled for all
    of them. Otherwise they are grouped and decided per group: stage 1
    kernels per single restriction, stage 3 kernels per pair of restrictions.

    Vectorization of the quadrature loop is independent of this; it is
    implemented through a post-processing (== loopy transformation) step.
    """
    from dune.perftool.generation import retrieve_cache_items
    instructions = list(retrieve_cache_items("kernel_default and instruction"))

    # Gather every sum factorization kernel referenced by an instruction
    sumfacts = frozenset()
    for instruction in instructions:
        if not isinstance(instruction, (lp.Assignment, lp.CallInstruction)):
            continue
        sumfacts = sumfacts.union(find_sumfact(instruction.expression))

    if get_option("vectorize_grads"):
        restrictions = (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE)

        # Stage 1 kernels: one group per restriction
        for single in restrictions:
            decide_stage_vectorization_strategy(sumfacts, 1, single)

        # Stage 3 kernels: one group per (ordered) pair of restrictions
        import itertools
        for pair in itertools.product(restrictions, restrictions):
            decide_stage_vectorization_strategy(sumfacts, 3, pair)
    else:
        no_vectorization(sumfacts)


class HasSumfactMapper(lp.symbolic.CombineMapper):
    """Expression mapper collecting the sum factorization kernels of an expression.

    Every handler yields a frozenset: the node itself for a sum factorization
    kernel, the empty set for the leaf kinds listed below, and the union of
    the children's results (via ``combine``) for everything else.
    """

    def combine(self, *args):
        # The child results arrive as one iterable argument; merge them all.
        children = tuple(*args)
        return frozenset().union(*children)

    def map_sumfact_kernel(self, expr):
        return frozenset([expr])

    def _no_sumfacts(self, expr):
        # Leaves cannot contain a sum factorization kernel.
        return frozenset()

    map_constant = _no_sumfacts
    map_algebraic_leaf = _no_sumfacts
    map_loopy_function_identifier = _no_sumfacts
    map_tagged_variable = _no_sumfacts


def find_sumfact(expr):
    """Return the frozenset of all sum factorization kernel nodes in *expr*.

    :param expr: a loopy/pymbolic expression, e.g. an instruction's expression
    """
    return HasSumfactMapper()(expr)