Newer
Older
""" Sum factorization vectorization """
from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel
from dune.perftool.generation import (generator_factory,
get_counted_variable,
from dune.perftool.pdelab.restriction import (Restriction,
restricted_name,
)
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixArray
from dune.perftool.error import PerftoolError
from dune.perftool.options import get_option
import loopy as lp
@generator_factory(item_tags=("vecinfo", "dryrundata"),
                   cache_key_generator=lambda o, n: o)
def _cache_vectorization_info(old, new):
    """ Store the vectorization info `new` for the kernel `old`

    The cache key generator only looks at the first argument, so the entry is
    presumably keyed on `old` alone (confirm against generator_factory).

    :param old: the sum factorization kernel the info belongs to
    :param new: the vectorization info to associate with `old`
    :returns: `new`, unchanged
    :raises PerftoolError: if `new` is None
    """
    if new is not None:
        return new

    raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!")
# Generator whose items carry the tags "sumfactnodes" and "dryrundata" in the
# "kernel" context. decide_vectorization_strategy reads these items back with
# the query "kernel_default and sumfactnodes".
# NOTE(review): no call site is visible in this chunk; presumably it is called
# on each kernel during a dry run - confirm.
_collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True)
# Entry point for a single sum factorization kernel `sf`.
# Asserts that `sf` is a SumfactKernel and then branches on the global context
# value "dry_run".
# NOTE(review): the body of the `if` (and any `else` branch) is not visible
# here, so what is returned in either case cannot be stated from this chunk.
# `get_global_context_value` is also not among the visible imports (the import
# list above looks truncated) - confirm it is imported.
def attach_vectorization_info(sf):
assert isinstance(sf, SumfactKernel)
if get_global_context_value("dry_run"):
def no_vectorization(sumfacts):
    """ Register every kernel without any vectorization

    Each kernel is associated with a copy of itself that carries a freshly
    counted "buffer" and "input" variable name.

    :param sumfacts: iterable of sum factorization kernels
    """
    # The kernel must be taken from the argument: the body previously referred
    # to `sf` without ever binding it and ignored `sumfacts` altogether.
    for sf in sumfacts:
        _cache_vectorization_info(sf, sf.copy(buffer=get_counted_variable("buffer"),
                                              input=get_counted_variable("input")))
# Try to combine a group of sum factorization kernels into one joint
# VectorizedSumfactKernel with vector_width=4.
#
# sumfacts: the kernels to combine. Only groups of 3 or 4 kernels enter the
# vectorizing branch; every other group size ends up in the `else` branch.
#
# NOTE(review): this block looks truncated and its indentation is lost. In the
# visible lines `position_mapping` is never initialised, `sumf` is never bound,
# there is no loop over `sumfacts` (so the `sf` read below is not derived from
# the parameter), the VectorizedSumfactKernel that gets constructed is
# discarded, and the `else` branch holds only a comment. Restore the missing
# statements from version control before relying on this function.
def horizontal_vectorization_strategy(sumfacts):
if len(sumfacts) in (3, 4):
# Map the sum factorization to their position in the joint kernel
# Positions 0..3 are up for grabs (presumably one per vector lane - confirm).
available = set(range(4))
if sf.preferred_position is not None:
# This asserts that no two kernels want to take the same position
# Later on, more complicated stuff might be necessary here.
assert sf.preferred_position in available
available.discard(sf.preferred_position)
position_mapping[sf] = sf.preferred_position
# Choose a position for those that have no preferred one!
if sumf.preferred_position is None:
position_mapping[sumf] = available.pop()
# Store the kernels as tuple according to their positions
# NOTE(review): `sorting` has len(position_mapping) entries while positions
# are drawn from range(4); with three kernels a position of 3 would index past
# the end of the list - confirm whether this should be a fixed length of 4.
sorting = [None] * len(position_mapping)
for sf, pos in position_mapping.items():
sorting[pos] = sf
kernels = tuple(sorting)
# Fresh counted names shared by all kernels of the joint kernel.
buffer = get_counted_variable("joined_buffer")
input = get_counted_variable("joined_input")
VectorizedSumfactKernel(kernels=kernels,
vector_width=4,
buffer=buffer,
input=input,
)
else:
# Disable vectorization strategy
def decide_vectorization_strategy():
    """ Decide how to vectorize!

    Retrieves every cached item matching "kernel_default and sumfactnodes" and
    hands it to a vectorization strategy:

    * If the option "vectorize_grads" is not set, all kernels are passed to
      no_vectorization in one go.
    * Otherwise the kernels are grouped by their `input_key` and each group is
      passed to horizontal_vectorization_strategy on its own.

    Note that the vectorization of the quadrature loop is independent of this,
    as it is implemented through a post-processing (== loopy transformation) step.
    """
    from collections import OrderedDict
    from dune.perftool.generation import retrieve_cache_items

    sumfacts = list(retrieve_cache_items("kernel_default and sumfactnodes"))

    if not get_option("vectorize_grads"):
        no_vectorization(sumfacts)
        return

    # Currently we base our idea here on the fact that we only group sum
    # factorization kernels with the same input.
    #
    # The groups are built in a single pass and visited in order of first
    # appearance. Iterating over a set of keys instead would visit the groups
    # in an arbitrary order, and the strategies draw counted variable names, so
    # the generated names would depend on that order.
    groups = OrderedDict()
    for sf in sumfacts:
        groups.setdefault(sf.input_key, []).append(sf)

    for group in groups.values():
        horizontal_vectorization_strategy(group)