""" Sum factorization vectorization """

from dune.perftool.loopy.vcl import get_vcl_type_size
from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel
from dune.perftool.generation import (generator_factory,
                                      get_counted_variable,
                                      get_global_context_value,
                                      )
from dune.perftool.pdelab.restriction import (Restriction,
                                              restricted_name,
                                              )
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixArray
from dune.perftool.error import PerftoolError
from dune.perftool.options import get_option

import loopy as lp
import numpy as np


@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o)
def _cache_vectorization_info(old, new):
    """ Register (or look up) the vectorized replacement of a kernel.

    The cache key is derived from ``old`` only (see ``cache_key_generator``).
    NOTE(review): presumably a lookup with ``new=None`` returns the value
    stored by an earlier call with the same ``old`` without executing this
    body - confirm against ``generator_factory``.

    :param old: the original sum factorization kernel
    :param new: its replacement, or ``None`` when only looking up
    :returns: ``new``
    :raises PerftoolError: if this body is executed with ``new is None``
    """
    if new is None:
        raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!")

    return new


# Collector for the sum factorization kernels seen during a dry run.
_collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"),
                                           context_tags="kernel",
                                           no_deco=True,
                                           )


def attach_vectorization_info(sf):
    """ Collect ``sf`` in a dry run, otherwise return its registered replacement.

    :param sf: a ``SumfactKernel`` instance
    :returns: the result of ``_collect_sumfact_nodes(sf)`` when the global
        context value ``"dry_run"`` is truthy, else the result of
        ``_cache_vectorization_info(sf, None)``
    """
    assert isinstance(sf, SumfactKernel)
    if get_global_context_value("dry_run"):
        return _collect_sumfact_nodes(sf)
    return _cache_vectorization_info(sf, None)


def no_vec(sf):
    """ Return a copy of ``sf`` with freshly counted buffer and input names. """
    return sf.copy(buffer=get_counted_variable("buffer"),
                   input=get_counted_variable("input"))


def no_vectorization(sumfacts):
    """ Map every kernel in ``sumfacts`` to its unvectorized copy (see ``no_vec``). """
    return {sf: no_vec(sf) for sf in sumfacts}


def vertical_vectorization_strategy(sumfact, depth):
    """ Slice one matrix of ``sumfact`` into ``depth`` parts and join the results.

    The matrix to slice is the first one in ``sumfact.matrix_sequence`` whose
    ``quadrature_size`` is divisible by ``depth``; matrices of quadrature
    size 1 in front of it are skipped.

    :param sumfact: the kernel to vectorize; none of its matrices may
        already carry a ``slice_size``
    :param depth: the number of slices
    :returns: a dict with the single entry ``{sumfact: VectorizedSumfactKernel}``
    :raises PerftoolError: if no matrix can be sliced
    """
    # Assert that this is not already sliced
    assert all(mat.slice_size is None for mat in sumfact.matrix_sequence)

    # Determine which of the matrices in the kernel should be sliced
    def determine_slice_direction():
        for i, mat in enumerate(sumfact.matrix_sequence):
            if mat.quadrature_size % depth == 0:
                return i
            elif mat.quadrature_size != 1:
                raise PerftoolError("Vertical vectorization is not possible!")
        # Every matrix had quadrature size 1 (or the sequence was empty):
        # fail with a proper error instead of implicitly returning None,
        # which would surface as an obscure indexing TypeError below.
        raise PerftoolError("Vertical vectorization is not possible!")

    sliced = determine_slice_direction()

    kernels = []
    oldtab = sumfact.matrix_sequence[sliced]
    for i in range(depth):
        seq = list(sumfact.matrix_sequence)
        seq[sliced] = oldtab.copy(slice_size=depth, slice_index=i)
        kernels.append(sumfact.copy(matrix_sequence=tuple(seq)))

    buffer_name = get_counted_variable("vertical_buffer")
    input_name = get_counted_variable("vertical_input")

    vsf = VectorizedSumfactKernel(kernels=tuple(kernels),
                                  buffer=buffer_name,
                                  input=input_name,
                                  vertical_width=depth,
                                  )
    return {sumfact: vsf}


def horizontal_vectorization_strategy(sumfacts, width):
    """ Group kernels into batches of up to ``width`` and join each batch.

    Within a batch, a kernel with a ``preferred_position`` gets that
    position if it is still free; the remaining free positions are filled
    with arbitrary kernels that are left over. A batch is joined into a
    ``VectorizedSumfactKernel`` if it holds ``width`` or ``width - 1``
    kernels occupying the positions ``0 .. len - 1``. Kernels of any other
    batch are left unvectorized (see ``no_vec``).

    :param sumfacts: iterable of kernels to distribute
    :param width: the number of positions per batch
    :returns: a dict mapping every input kernel to its replacement
    """
    result = {}
    todo = set(sumfacts)

    while todo:
        position_mapping = {}
        available = set(range(width))

        # First serve the kernels that ask for a specific position
        for sf in todo:
            preferred = sf.preferred_position
            if preferred is not None and preferred in available:
                available.discard(preferred)
                position_mapping[preferred] = sf

        for sf in position_mapping.values():
            todo.discard(sf)

        # Fill the free positions, lowest first, with whatever is left
        for pos in sorted(available):
            if not todo:
                break
            position_mapping[pos] = todo.pop()

        positions = sorted(position_mapping)
        kernels = tuple(position_mapping[pos] for pos in positions)

        # A partially filled batch may have gaps (e.g. a single kernel that
        # prefers position 2). Such a batch cannot be written as a dense
        # tuple that honours the preferred positions, so it is not joined.
        dense = positions == list(range(len(positions)))
        vectorize = dense and len(kernels) in (width, width - 1)

        buffer_name = get_counted_variable("joined_buffer")
        input_name = get_counted_variable("joined_input")

        for sumf in kernels:
            if vectorize:
                result[sumf] = VectorizedSumfactKernel(kernels=kernels,
                                                       horizontal_width=width,
                                                       buffer=buffer_name,
                                                       input=input_name,
                                                       )
            else:
                result[sumf] = no_vec(sumf)

    return result


def diagonal_vectorization_strategy(sumfacts, width):
    """ Currently the same as ``horizontal_vectorization_strategy``. """
    return horizontal_vectorization_strategy(sumfacts, width)


# Backward-compatible alias for the historical (misspelled) name
diagonal_vectorization_stratget = diagonal_vectorization_strategy


def decide_vectorization_strategy():
    """ Decide how to vectorize!

    Note that the vectorization of the quadrature loop is independent of this,
    as it is implemented through a post-processing (== loopy transformation) step.

    The strategy is chosen by the first option that is set among
    ``vectorize_grads``, ``vectorize_slice`` and ``vectorize_diagonal``;
    without any of them no kernel is vectorized. The resulting replacement
    of every collected kernel is registered via ``_cache_vectorization_info``.
    """
    from dune.perftool.generation import retrieve_cache_items
    sumfacts = list(retrieve_cache_items("kernel_default and sumfactnodes"))

    # Maps each collected kernel to its replacement. The kernels are the
    # keys, so the partial results are merged with update(mapping):
    # update(**mapping) requires string keys.
    sfdict = {}

    if get_option("vectorize_grads"):
        # Currently we base our idea here on the fact that we only group sum
        # factorization kernels with the same input.
        width = get_vcl_type_size(np.float64)
        inputkeys = set(sf.input_key for sf in sumfacts)
        for inputkey in inputkeys:
            sumfact_filter = [sf for sf in sumfacts if sf.input_key == inputkey]
            sfdict.update(horizontal_vectorization_strategy(sumfact_filter, width))
    elif get_option("vectorize_slice"):
        if sumfacts:
            width = get_vcl_type_size(np.float64)
        for sumfact in sumfacts:
            sfdict.update(vertical_vectorization_strategy(sumfact, width))
    elif get_option("vectorize_diagonal"):
        width = get_vcl_type_size(np.float64)
        sfdict.update(diagonal_vectorization_strategy(sumfacts, width))
    else:
        sfdict.update(no_vectorization(sumfacts))

    # Register the results
    for old, new in sfdict.items():
        _cache_vectorization_info(old, new)