""" Sum factorization vectorization """ import logging from dune.perftool.loopy.vcl import get_vcl_type_size from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel from dune.perftool.generation import (generator_factory, get_counted_variable, get_global_context_value, ) from dune.perftool.pdelab.restriction import (Restriction, restricted_name, ) from dune.perftool.sumfact.tabulation import BasisTabulationMatrixArray from dune.perftool.error import PerftoolError from dune.perftool.options import get_option import loopy as lp import numpy as np @generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o) def _cache_vectorization_info(old, new): if new is None: raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!") return new _collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True) def get_all_sumfact_nodes(): from dune.perftool.generation import retrieve_cache_items return [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] def attach_vectorization_info(sf): assert isinstance(sf, SumfactKernel) if get_global_context_value("dry_run"): return _collect_sumfact_nodes(sf) else: return _cache_vectorization_info(sf, None) def no_vec(sf): return sf.copy(buffer=get_counted_variable("buffer")) def no_vectorization(sumfacts): return {sf: no_vec(sf) for sf in sumfacts} def vertical_vectorization_strategy(sumfact, depth): # Assert that this is not already sliced assert all(mat.slice_size is None for mat in sumfact.matrix_sequence) result = {} # Determine which of the matrices in the kernel should be sliced def determine_slice_direction(sf): for i, mat in enumerate(sf.matrix_sequence): if mat.quadrature_size % depth == 0: return i elif mat.quadrature_size != 1: raise PerftoolError("Vertical vectorization is not possible!") if isinstance(sumfact, SumfactKernel): kernels = [sumfact] else: assert isinstance(sumfact, VectorizedSumfactKernel) kernels = sumfact.kernels newkernels = [] for sf in kernels: sliced = determine_slice_direction(sf) oldtab = sf.matrix_sequence[sliced] for i in range(depth): seq = list(sf.matrix_sequence) seq[sliced] = oldtab.copy(slice_size=depth, slice_index=i) newkernels.append(sf.copy(matrix_sequence=tuple(seq))) if isinstance(sumfact, SumfactKernel): buffer = get_counted_variable("vertical_buffer") result[sumfact] = VectorizedSumfactKernel(kernels=tuple(newkernels), buffer=buffer, vertical_width=depth, ) else: assert isinstance(sumfact, VectorizedSumfactKernel) for sf in kernels: result[sf] = sumfact.copy(kernels=tuple(newkernels), vertical_width=depth, ) return result def horizontal_vectorization_strategy(sumfacts, width, allow_padding=1): result = {} todo = set(sumfacts) while todo: kernels = [] for sf in sorted(todo, key=lambda s: s.position_priority): if len(kernels) < width: kernels.append(sf) todo.discard(sf) buffer = get_counted_variable("joined_buffer") if len(kernels) in range(width - allow_padding, width + 1): for sf in kernels: result[sf] = VectorizedSumfactKernel(kernels=tuple(kernels), horizontal_width=width, buffer=buffer, ) return result def diagonal_vectorization_strategy(sumfacts, width): if width == 4: horizontal, vertical = 2, 2 padding = 0 elif width == 8: horizontal, vertical = 4, 2 padding = 1 else: raise NotImplementedError result = {} horizontal_kernels = horizontal_vectorization_strategy(sumfacts, horizontal, allow_padding=padding) for sf in horizontal_kernels: if 
horizontal_kernels[sf].horizontal_width > 1: vert = vertical_vectorization_strategy(horizontal_kernels[sf], width // horizontal_kernels[sf].horizontal_width) for k in vert: result[k] = vert[k] return result def decide_vectorization_strategy(): """ Decide how to vectorize! Note that the vectorization of the quadrature loop is independent of this, as it is implemented through a post-processing (== loopy transformation) step. """ logger = logging.getLogger(__name__) # Retrieve all sum factorization kernels for stage 1 and 3 from dune.perftool.generation import retrieve_cache_items all_sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] # Stage 1 sumfactorizations that were actually used basis_sumfacts = [i for i in retrieve_cache_items('kernel_default and basis_sf_kernels')] # This means we can have sum factorizations that will not get used inactive_sumfacts = [i for i in all_sumfacts if i.stage == 1 and i not in basis_sumfacts] # All sum factorization kernels that get used active_sumfacts = [i for i in all_sumfacts if i.stage == 3 or i in basis_sumfacts] # We map inacitve sum factorizatino kernels to 0 sfdict = {} for sf in inactive_sumfacts: sfdict[sf] = 0 logger.debug("decide_vectorization_strategy: Found {} active sum factorization nodes" .format(len(active_sumfacts))) if get_option("vectorize_grads"): # Currently we base our idea here on the fact that we only group sum # factorization kernels with the same input. inputkeys = set(sf.input_key for sf in active_sumfacts) for inputkey in inputkeys: width = get_vcl_type_size(np.float64) sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey] for old, new in horizontal_vectorization_strategy(sumfact_filter, width).items(): sfdict[old] = new elif get_option("vectorize_slice"): for sumfact in active_sumfacts: width = get_vcl_type_size(np.float64) for old, new in vertical_vectorization_strategy(sumfact, width).items(): sfdict[old] = new elif get_option("vectorize_diagonal"): inputkeys = set(sf.input_key for sf in active_sumfacts) for inputkey in inputkeys: width = get_vcl_type_size(np.float64) sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey] for old, new in diagonal_vectorization_strategy(sumfact_filter, width).items(): sfdict[old] = new else: for old, new in no_vectorization(active_sumfacts).items(): sfdict[old] = new # Register the results for sf in all_sumfacts: _cache_vectorization_info(sf, sfdict.get(sf, no_vec(sf)))
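

# ---------------------------------------------------------------------------
# Illustrative sketch only, not used by the code generator: a minimal,
# self-contained model of the lane packing performed by
# horizontal_vectorization_strategy. The integer "kernels" and the helper
# name `_demo_pack_lanes` are hypothetical stand-ins for SumfactKernel
# objects and their position_priority; the padding rule mirrors the
# `allow_padding` check above.
if __name__ == "__main__":
    def _demo_pack_lanes(priorities, width, allow_padding=1):
        """Group sorted items into chunks of `width`, keeping a short final
        chunk only if it misses at most `allow_padding` lanes."""
        todo = sorted(priorities)
        groups = []
        while todo:
            group, todo = todo[:width], todo[width:]
            if len(group) in range(width - allow_padding, width + 1):
                groups.append(group)
        return groups

    # With SIMD width 4 and one lane of padding allowed, five kernels yield
    # one full group; the single leftover would need three padded lanes and
    # therefore stays scalar (it would later fall back to no_vec).
    assert _demo_pack_lanes(range(5), width=4, allow_padding=1) == [[0, 1, 2, 3]]
    assert _demo_pack_lanes(range(8), width=4) == [[0, 1, 2, 3], [4, 5, 6, 7]]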