""" Sum factorization vectorization """ from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel from dune.perftool.generation import (generator_factory, get_counted_variable, get_global_context_value, ) from dune.perftool.pdelab.restriction import (Restriction, restricted_name, ) from dune.perftool.sumfact.tabulation import BasisTabulationMatrixArray from dune.perftool.error import PerftoolError from dune.perftool.options import get_option import loopy as lp @generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o) def _cache_vectorization_info(old, new): if new is None: raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!") return new _collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True) def attach_vectorization_info(sf): assert isinstance(sf, SumfactKernel) if get_global_context_value("dry_run"): return _collect_sumfact_nodes(sf) else: return _cache_vectorization_info(sf, None) def no_vectorization(sumfacts): for sf in sumfacts: _cache_vectorization_info(sf, sf.copy(buffer=get_counted_variable("buffer"), input=get_counted_variable("input"))) def horizontal_vectorization_strategy(sumfacts): if len(sumfacts) in (3, 4): # Map the sum factorization to their position in the joint kernel position_mapping = {} available = set(range(4)) for sf in sumfacts: if sf.preferred_position is not None: # This asserts that no two kernels want to take the same position # Later on, more complicated stuff might be necessary here. assert sf.preferred_position in available available.discard(sf.preferred_position) position_mapping[sf] = sf.preferred_position # Choose a position for those that have no preferred one! for sumf in sumfacts: if sumf.preferred_position is None: position_mapping[sumf] = available.pop() # Store the kernels as tuple according to their positions sorting = [None] * len(position_mapping) for sf, pos in position_mapping.items(): sorting[pos] = sf kernels = tuple(sorting) buffer = get_counted_variable("joined_buffer") input = get_counted_variable("joined_input") for sumf in sumfacts: _cache_vectorization_info(sumf, VectorizedSumfactKernel(kernels=kernels, vector_width=4, buffer=buffer, input=input, ) ) else: # Disable vectorization strategy no_vectorization(sumfacts) def decide_vectorization_strategy(): """ Decide how to vectorize! Note that the vectorization of the quadrature loop is independent of this, as it is implemented through a post-processing (== loopy transformation) step. """ from dune.perftool.generation import retrieve_cache_items sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] if not get_option("vectorize_grads"): no_vectorization(sumfacts) else: # Currently we base our idea here on the fact that we only group sum # factorization kernels with the same input. inputkeys = set(sf.input_key for sf in sumfacts) for inputkey in inputkeys: sumfact_filter = [sf for sf in sumfacts if sf.input_key == inputkey] horizontal_vectorization_strategy(sumfact_filter)