Skip to content
Snippets Groups Projects
vectorization.py 4.04 KiB
Newer Older
""" Sum factorization vectorization """

from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel
from dune.perftool.generation import (generator_factory,
                                      get_counted_variable,
Dominic Kempf's avatar
Dominic Kempf committed
                                      get_global_context_value,
from dune.perftool.pdelab.restriction import (Restriction,
                                              restricted_name,
                                              )
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixArray
from dune.perftool.error import PerftoolError
from dune.perftool.options import get_option

import loopy as lp


@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o)
def _cache_vectorization_info(old, new):
    """ Associate a sum factorization kernel with its (vectorized) replacement.

    The cache key is ``old`` alone (``cache_key_generator`` drops ``new``).
    Callers store a replacement by passing it as ``new`` and look it up later
    by passing ``new=None``.

    NOTE(review): this relies on generator_factory memoizing on the key, so
    that a lookup only reaches this body when nothing was stored for ``old``.
    Confirm against generator_factory.

    :param old: the original sum factorization kernel (cache key)
    :param new: the replacement kernel, or None to look up a stored one
    :returns: ``new``
    :raises PerftoolError: if ``new`` is None when this body runs, i.e. no
        vectorization info was gathered for ``old``
    """
    if new is None:
        raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!")
    return new


# Collector for sum factorization kernels. attach_vectorization_info calls it
# with a kernel during dry runs; decide_vectorization_strategy later retrieves
# the collected items through the "sumfactnodes" tag.
_collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True)

def attach_vectorization_info(sf):
    """ Register or look up the vectorization info of a sum factorization kernel.

    During a dry run, ``sf`` is only handed to ``_collect_sumfact_nodes``.
    This lets a vectorization strategy be decided later for all kernels at
    once. Outside a dry run, the replacement that the chosen strategy stored
    for ``sf`` is looked up via ``_cache_vectorization_info``.

    :param sf: a SumfactKernel
    :returns: during a dry run, whatever ``_collect_sumfact_nodes`` returns;
        otherwise the stored replacement kernel for ``sf``
    :raises PerftoolError: outside a dry run, presumably when no replacement
        was stored for ``sf`` (see ``_cache_vectorization_info``)
    """
    assert isinstance(sf, SumfactKernel)
    if get_global_context_value("dry_run"):
        return _collect_sumfact_nodes(sf)
    else:
        # Passing None means "look up"; the stored value is keyed on sf only.
        return _cache_vectorization_info(sf, None)


def no_vectorization(sumfacts):
    """ Keep every sum factorization kernel as a separate, non-vectorized kernel.

    For each kernel, a copy of itself is stored as its replacement. The copy
    gets its own ``buffer`` and ``input`` variable names, obtained from
    ``get_counted_variable``.

    :param sumfacts: iterable of SumfactKernel objects
    """
    for sf in sumfacts:
        _cache_vectorization_info(sf, sf.copy(buffer=get_counted_variable("buffer"),
                                              input=get_counted_variable("input")))


def horizontal_vectorization_strategy(sumfacts):
    """ Join a group of sum factorization kernels into one vectorized kernel.

    A group of 3 or 4 kernels is combined into a VectorizedSumfactKernel of
    vector width 4, and all members share one joined buffer and one joined
    input variable. Each kernel is assigned a position in the joint kernel.
    Kernels with a ``preferred_position`` get exactly that position; the
    remaining kernels get the lowest positions still free. Groups of any
    other size are not vectorized (``no_vectorization`` is used instead).

    :param sumfacts: list of sum factorization kernels; decide_vectorization_strategy
        passes groups of kernels that share the same ``input_key``
    :raises PerftoolError: if two kernels request the same position, a
        requested position is outside 0..3, or the resulting positions do not
        form the contiguous range 0..n-1 by which the kernel tuple is indexed
    """
    vector_width = 4

    if len(sumfacts) not in (3, 4):
        # Disable vectorization strategy
        no_vectorization(sumfacts)
        return

    # Map the sum factorization to their position in the joint kernel
    position_mapping = {}
    available = set(range(vector_width))
    for sf in sumfacts:
        if sf.preferred_position is not None:
            # No two kernels may take the same position. Raise instead of
            # assert so the check survives `python -O`. Later on, more
            # complicated stuff might be necessary here.
            if sf.preferred_position not in available:
                raise PerftoolError("Cannot place sum factorization kernel at position {}: "
                                    "it is out of range or already taken".format(sf.preferred_position))
            available.discard(sf.preferred_position)
            position_mapping[sf] = sf.preferred_position

    # Choose a position for those that have no preferred one! Always take the
    # lowest free slot so the layout does not depend on set iteration order.
    for sf in sumfacts:
        if sf.preferred_position is None:
            position = min(available)
            available.discard(position)
            position_mapping[sf] = position

    # Store the kernels as tuple according to their positions. The tuple index
    # is the position, so the positions must be exactly 0, 1, ..., n-1.
    positions = sorted(position_mapping.values())
    if positions != list(range(len(positions))):
        raise PerftoolError("Sum factorization kernel positions {} do not form "
                            "a contiguous range starting at 0".format(positions))
    kernels = tuple(sorted(position_mapping, key=position_mapping.get))

    joined_buffer = get_counted_variable("joined_buffer")
    joined_input = get_counted_variable("joined_input")
    for sf in sumfacts:
        _cache_vectorization_info(sf,
                                  VectorizedSumfactKernel(kernels=kernels,
                                                          vector_width=vector_width,
                                                          buffer=joined_buffer,
                                                          input=joined_input,
                                                          )
                                  )


def decide_vectorization_strategy():
    """ Decide how to vectorize!

    Note that the vectorization of the quadrature loop is independent of this,
    as it is implemented through a post-processing (== loopy transformation) step.

    All sum factorization kernels collected in the cache under the tags
    ``kernel_default`` and ``sumfactnodes`` get a strategy assigned here. If
    the ``vectorize_grads`` option is off, none of them is vectorized.
    Otherwise the kernels are grouped by ``input_key``, and each group is
    passed to ``horizontal_vectorization_strategy``.
    """
    from collections import OrderedDict
    from dune.perftool.generation import retrieve_cache_items
    sumfacts = list(retrieve_cache_items("kernel_default and sumfactnodes"))

    if not get_option("vectorize_grads"):
        no_vectorization(sumfacts)
        return

    # Currently we base our idea here on the fact that we only group sum
    # factorization kernels with the same input. Groups are processed in order
    # of first appearance. Iterating a set of keys would make the order of the
    # strategy calls (and presumably the counted variable names they allocate)
    # depend on hashing, and the per-key filter was O(n * keys).
    groups = OrderedDict()
    for sf in sumfacts:
        groups.setdefault(sf.input_key, []).append(sf)
    for group in groups.values():
        horizontal_vectorization_strategy(group)