""" Sum factorization vectorization """ import logging from dune.perftool.loopy.vcl import get_vcl_type_size from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel from dune.perftool.generation import (generator_factory, get_counted_variable, get_global_context_value, ) from dune.perftool.pdelab.restriction import (Restriction, restricted_name, ) from dune.perftool.sumfact.tabulation import (BasisTabulationMatrixArray, quadrature_points_per_direction, set_quadrature_points, ) from dune.perftool.error import PerftoolError from dune.perftool.options import get_option from dune.perftool.tools import add_to_frozendict,round_to_multiple from frozendict import frozendict import itertools as it import loopy as lp import numpy as np @generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o) def _cache_vectorization_info(old, new): if new is None: raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!") return new _collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True) def get_all_sumfact_nodes(): from dune.perftool.generation import retrieve_cache_items return [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] def attach_vectorization_info(sf): assert isinstance(sf, SumfactKernel) if get_global_context_value("dry_run"): return _collect_sumfact_nodes(sf) else: return _cache_vectorization_info(sf, None) def no_vec(sf): return sf.copy(buffer=get_counted_variable("buffer")) def no_vectorization(sumfacts): return {sf: no_vec(sf) for sf in sumfacts} def vertical_vectorization_strategy(sumfact, depth): # If depth is 1, there is nothing do if depth == 1: if isinstance(sumfact, SumfactKernel): return {sumfact: sumfact} else: return {k: sumfact for k in sumfact.kernels} # Assert that this is not already sliced assert all(mat.slice_size is None for mat in sumfact.matrix_sequence) result = {} # Determine which of the matrices in the kernel should be sliced def determine_slice_direction(sf): for i, mat in enumerate(sf.matrix_sequence): if mat.quadrature_size % depth == 0: return i elif get_option("vectorize_allow_quadrature_changes") and mat.quadrature_size != 1: quad = list(quadrature_points_per_direction()) quad[i] = round_to_multiple(quad[i], depth) set_quadrature_points(tuple(quad)) return i elif mat.quadrature_size != 1: raise PerftoolError("Vertical vectorization is not possible!") if isinstance(sumfact, SumfactKernel): kernels = [sumfact] else: assert isinstance(sumfact, VectorizedSumfactKernel) kernels = sumfact.kernels newkernels = [] for sf in kernels: sliced = determine_slice_direction(sf) oldtab = sf.matrix_sequence[sliced] for i in range(depth): seq = list(sf.matrix_sequence) seq[sliced] = oldtab.copy(slice_size=depth, slice_index=i) newkernels.append(sf.copy(matrix_sequence=tuple(seq))) if isinstance(sumfact, SumfactKernel): buffer = get_counted_variable("vertical_buffer") result[sumfact] = VectorizedSumfactKernel(kernels=tuple(newkernels), buffer=buffer, vertical_width=depth, ) else: assert isinstance(sumfact, VectorizedSumfactKernel) for sf in kernels: result[sf] = sumfact.copy(kernels=tuple(newkernels), vertical_width=depth, ) return result def horizontal_vectorization_strategy(sumfacts, width, allow_padding=1): result = {} todo = set(sumfacts) while todo: kernels = [] for sf in sorted(todo, key=lambda s: s.position_priority): if len(kernels) < width: kernels.append(sf) todo.discard(sf) buffer = 
get_counted_variable("joined_buffer") if len(kernels) in range(width - allow_padding, width + 1): for sf in kernels: result[sf] = VectorizedSumfactKernel(kernels=tuple(kernels), horizontal_width=width, buffer=buffer, ) return result def diagonal_vectorization_strategy(sumfacts, width): # Read explicitly set values horizontal = get_option("vectorize_horizontal") vertical = get_option("vectorize_vertical") padding = get_option("vectorize_padding") if width == 4: if horizontal is None: horizontal = 2 if vertical is None: vertical = 2 if padding is None: padding = 0 elif width == 8: if horizontal is None: horizontal = 4 if vertical is None: vertical = 2 if padding is None: padding = 1 else: raise NotImplementedError horizontal = int(horizontal) vertical = int(vertical) padding = int(padding) result = {} horizontal_kernels = horizontal_vectorization_strategy(sumfacts, horizontal, allow_padding=padding) for sf in horizontal_kernels: vert = vertical_vectorization_strategy(horizontal_kernels[sf], width // horizontal_kernels[sf].horizontal_width) for k in vert: result[k] = vert[k] return result def greedy_vectorization_strategy(sumfacts, width): sumfacts = set(sumfacts) horizontal = width vertical = 1 allowed_padding = 1 result = {} while horizontal > 0: if horizontal > 1: horizontal_kernels = horizontal_vectorization_strategy(sumfacts, horizontal, allow_padding=allowed_padding) else: horizontal_kernels = {sf: sf for sf in sumfacts} for sf in horizontal_kernels: if horizontal_kernels[sf].horizontal_width == horizontal: vert = vertical_vectorization_strategy(horizontal_kernels[sf], vertical) for k in vert: result[k] = vert[k] sumfacts.discard(sf) horizontal = horizontal // 2 vertical = vertical * 2 # We heuristically allow padding only on the full SIMD width allowed_padding = 0 return result def decide_vectorization_strategy(): """ Decide how to vectorize! Note that the vectorization of the quadrature loop is independent of this, as it is implemented through a post-processing (== loopy transformation) step. """ logger = logging.getLogger(__name__) # Retrieve all sum factorization kernels for stage 1 and 3 from dune.perftool.generation import retrieve_cache_items all_sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] # Stage 1 sumfactorizations that were actually used basis_sumfacts = [i for i in retrieve_cache_items('kernel_default and basis_sf_kernels')] # This means we can have sum factorizations that will not get used inactive_sumfacts = [i for i in all_sumfacts if i.stage == 1 and i not in basis_sumfacts] # All sum factorization kernels that get used active_sumfacts = [i for i in all_sumfacts if i.stage == 3 or i in basis_sumfacts] # We map inacitve sum factorizatino kernels to 0 sfdict = {} for sf in inactive_sumfacts: sfdict[sf] = 0 logger.debug("decide_vectorization_strategy: Found {} active sum factorization nodes" .format(len(active_sumfacts))) if get_option("vectorize_grads"): # Currently we base our idea here on the fact that we only group sum # factorization kernels with the same input. 

def decide_vectorization_strategy():
    """ Decide how to vectorize!

    Note that the vectorization of the quadrature loop is independent of this,
    as it is implemented through a post-processing (== loopy transformation) step.
    """
    logger = logging.getLogger(__name__)

    # Retrieve all sum factorization kernels for stage 1 and 3
    from dune.perftool.generation import retrieve_cache_items
    all_sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")]

    # Stage 1 sum factorizations that were actually used
    basis_sumfacts = [i for i in retrieve_cache_items('kernel_default and basis_sf_kernels')]

    # This means we can have sum factorizations that will not get used
    inactive_sumfacts = [i for i in all_sumfacts if i.stage == 1 and i not in basis_sumfacts]

    # All sum factorization kernels that get used
    active_sumfacts = [i for i in all_sumfacts if i.stage == 3 or i in basis_sumfacts]

    # We map inactive sum factorization kernels to 0
    sfdict = {}
    for sf in inactive_sumfacts:
        sfdict[sf] = 0

    logger.debug("decide_vectorization_strategy: Found {} active sum factorization nodes"
                 .format(len(active_sumfacts)))

    if get_option("vectorize_grads"):
        # Currently we base our idea here on the fact that we only group sum
        # factorization kernels with the same input.
        inputkeys = set(sf.input_key for sf in active_sumfacts)
        for inputkey in inputkeys:
            width = get_vcl_type_size(np.float64)
            sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
            for old, new in horizontal_vectorization_strategy(sumfact_filter, width).items():
                sfdict[old] = new
    elif get_option("vectorize_slice"):
        for sumfact in active_sumfacts:
            width = get_vcl_type_size(np.float64)
            for old, new in vertical_vectorization_strategy(sumfact, width).items():
                sfdict[old] = new
    elif get_option("vectorize_diagonal"):
        inputkeys = set(sf.input_key for sf in active_sumfacts)
        for inputkey in inputkeys:
            width = get_vcl_type_size(np.float64)
            sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
            for old, new in diagonal_vectorization_strategy(sumfact_filter, width).items():
                sfdict[old] = new
    elif get_option("vectorize_greedy"):
        inputkeys = set(sf.input_key for sf in active_sumfacts)
        for inputkey in inputkeys:
            width = get_vcl_type_size(np.float64)
            sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
            for old, new in greedy_vectorization_strategy(sumfact_filter, width).items():
                sfdict[old] = new
    else:
        for old, new in no_vectorization(active_sumfacts).items():
            sfdict[old] = new

    # Register the results
    for sf in all_sumfacts:
        _cache_vectorization_info(sf, sfdict.get(sf, no_vec(sf)))


def vectorization_opportunity_generator(sumfacts, width):
    """ Generator that yields all vectorization opportunities for the given
    sum factorization kernels as tuples of quadrature point tuple and
    vectorization dictionary
    """
    #
    # Find all the possible quadrature point tuples
    #
    quad_points = [quadrature_points_per_direction()]

    if True or get_option("vectorize_allow_quadrature_changes"):
        sf = next(iter(sumfacts))
        depth = 1
        while depth <= width:
            i = 0 if sf.matrix_sequence[0].face is None else 1
            quad = list(quadrature_points_per_direction())
            quad[i] = round_to_multiple(quad[i], depth)
            quad_points.append(tuple(quad))
            depth = depth * 2
        quad_points = list(set(quad_points))

    for qp in quad_points:
        #
        # Determine vectorization opportunities given a fixed quadrature point number
        #
        for opp in fixed_quad_vectorization_opportunity_generator(frozenset(sumfacts), width, qp):
            yield qp, opp
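
# Example of the quadrature point candidates generated above (illustrative numbers,
# assuming round_to_multiple rounds upwards): starting from a rule with (5, 5, 5) points
# per direction and width=4, the sliced direction is rounded to multiples of 1, 2 and 4,
# so for a volume kernel (face is None, direction 0 is sliced) the candidate set becomes
# {(5, 5, 5), (6, 5, 5), (8, 5, 5)}.
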

def fixed_quad_vectorization_opportunity_generator(sumfacts, width, qp, already=frozendict()):
    if len(sumfacts) == 0:
        # We have gone into recursion deep enough to have all sum factorization nodes
        # assigned their vectorized counterpart. We can yield the result now!
        yield already
        return

    # Otherwise we pick a random sum factorization kernel and construct all the vectorization
    # opportunities realizing this particular kernel and go into recursion.
    sf_to_decide = next(iter(sumfacts))

    # Have "unvectorized" as an option, although it is not good
    for opp in fixed_quad_vectorization_opportunity_generator(sumfacts.difference({sf_to_decide}),
                                                              width,
                                                              qp,
                                                              add_to_frozendict(already, {sf_to_decide: sf_to_decide}),
                                                              ):
        yield opp

    # Find all the sum factorization kernels that the chosen kernel can be parallelized with
    #
    # TODO: Right now we check for same input, which is not actually needed in order
    #       to be a suitable candidate! We should relax this concept at some point!
    candidates = [sf for sf in sumfacts if sf.input_key == sf_to_decide.input_key]

    horizontal = 1
    while horizontal <= width:
        # Iterate over the possible combinations of sum factorization kernels,
        # taking into account all the permutations of kernels. This also includes
        # combinations which use a padding of 1.
        for combo in it.chain(it.permutations(candidates, horizontal),
                              it.permutations(candidates, horizontal - 1)):
            # The chosen kernel must be part of the combination for the recursion
            # to work correctly
            if sf_to_decide not in combo:
                continue

            # Set up the vectorization dict for this combo
            vecdict = get_vectorization_dict(combo, width // horizontal, horizontal, qp)
            if vecdict is None:
                # This particular choice was rejected for some reason.
                # Possible reasons:
                #  * the quadrature point tuple not being suitable
                #    for this vectorization strategy
                continue

            # Go into recursion to also vectorize all kernels not in this combo
            for opp in fixed_quad_vectorization_opportunity_generator(sumfacts.difference(combo),
                                                                      width,
                                                                      qp,
                                                                      add_to_frozendict(already, vecdict),
                                                                      ):
                yield opp

        horizontal = horizontal * 2


def get_vectorization_dict(sumfacts, vertical, horizontal, qp):
    # Enhance the list of sum factorization nodes by adding vertical splittings
    kernels = []
    for sf in sumfacts:
        # No slicing needed in the pure horizontal case
        if vertical == 1:
            kernels.append(sf)
            continue

        # Determine the slicing direction
        slice_direction = 0 if sf.matrix_sequence[0].face is None else 1
        if qp[slice_direction] % vertical != 0:
            return None

        # Split the basis tabulation matrices
        oldtab = sf.matrix_sequence[slice_direction]
        for i in range(vertical):
            seq = list(sf.matrix_sequence)
            seq[slice_direction] = oldtab.copy(slice_size=vertical, slice_index=i)
            kernels.append(sf.copy(matrix_sequence=tuple(seq)))

    # Join the new kernels into a sum factorization node
    buffer = get_counted_variable("joined_buffer")
    return {sf: VectorizedSumfactKernel(kernels=tuple(kernels),
                                        horizontal_width=horizontal,
                                        vertical_width=vertical,
                                        buffer=buffer,
                                        )
            for sf in sumfacts}
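
# The block below is a minimal standalone sketch of the grouping heuristic used in
# horizontal_vectorization_strategy (plain integers instead of SumfactKernel objects,
# helper name _sketch_group_by_priority chosen for this illustration only): kernels are
# filled into groups of `width` in priority order, and a group is kept only if it needs
# at most `allow_padding` padding lanes. It is not part of the code generation path.
if __name__ == "__main__":
    def _sketch_group_by_priority(priorities, width, allow_padding=1):
        todo = sorted(priorities)
        groups = []
        while todo:
            group, todo = todo[:width], todo[width:]
            if len(group) >= width - allow_padding:
                groups.append(group)
        return groups

    # Seven kernels on a 4-wide SIMD unit: one full group plus one group padded by one lane.
    assert _sketch_group_by_priority(range(7), 4) == [[0, 1, 2, 3], [4, 5, 6]]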