""" Sum factorization vectorization """ import logging from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.vcl import get_vcl_type_size from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel from dune.perftool.generation import (backend, generator_factory, get_backend, get_counted_variable, get_global_context_value, ) from dune.perftool.pdelab.restriction import (Restriction, restricted_name, ) from dune.perftool.sumfact.tabulation import (BasisTabulationMatrixArray, quadrature_points_per_direction, set_quadrature_points, ) from dune.perftool.error import PerftoolError from dune.perftool.options import get_option from dune.perftool.tools import add_to_frozendict, round_to_multiple from pytools import product from frozendict import frozendict import itertools as it import loopy as lp import numpy as np import math @generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o) def _cache_vectorization_info(old, new): if new is None: raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!") return new _collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True) def attach_vectorization_info(sf): assert isinstance(sf, SumfactKernel) if get_global_context_value("dry_run"): return _collect_sumfact_nodes(sf) else: return _cache_vectorization_info(sf, None) def position_penalty_factor(sf): if isinstance(sf, SumfactKernel) or sf.vertical_width > 1: return 1 else: return 1 + sum(abs(sf.kernels[i].position_priority - i) if sf.kernels[i].position_priority is not None else 0 for i in range(sf.length)) @backend(interface="vectorization_strategy", name="model") def costmodel(sf): # Penalize vertical vectorization vertical_penalty = 1 + math.log(sf.vertical_width) # Penalize scalar sum factorization kernels scalar_penalty = 1 if isinstance(sf, SumfactKernel): scalar_penalty = get_vcl_type_size(dtype_floatingpoint()) # Return total operations return sf.operations * position_penalty_factor(sf) * vertical_penalty * scalar_penalty @backend(interface="vectorization_strategy", name="explicit") def explicit_costfunction(sf): # Read the explicitly set values for horizontal and vertical vectorization width = get_vcl_type_size(dtype_floatingpoint()) horizontal = get_option("vectorization_horizontal") if horizontal is None: horizontal = width vertical = get_option("vectorization_vertical") if vertical is None: vertical = 1 horizontal = int(horizontal) vertical = int(vertical) if sf.horizontal_width == horizontal and sf.vertical_width == vertical: # Penalize position mapping return position_penalty_factor(sf) else: return 1000000000000 def strategy_cost(strategy): func = get_backend(interface="vectorization_strategy", selector=lambda: get_option("vectorization_strategy")) keys = set(sf.cache_key for sf in strategy.values()) # Sum over all the sum factorization kernels in the realization score = 0.0 for sf in strategy.values(): if sf.cache_key in keys: score = score + float(func(sf)) keys.discard(sf.cache_key) return score def stringify_vectorization_strategy(strategy): result = [] qp, strategy = strategy result.append("Printing potential vectorization strategy:") result.append("Quadrature point tuple: {}".format(qp)) # Look for all realizations in the strategy and iterate over them cache_keys = frozenset(v.cache_key for v in strategy.values()) for ck in cache_keys: # Filter all the kernels that are realized by this and print for key in strategy: if strategy[key].cache_key == ck: result.append("{}:".format(key)) # Find one representative to print for val in strategy.values(): if val.cache_key == ck: result.append(" {}".format(val)) break return result def decide_vectorization_strategy(): """ Decide how to vectorize! Note that the vectorization of the quadrature loop is independent of this, as it is implemented through a post-processing (== loopy transformation) step. """ logger = logging.getLogger(__name__) # Retrieve all sum factorization kernels for stage 1 and 3 from dune.perftool.generation import retrieve_cache_items all_sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] # Stage 1 sumfactorizations that were actually used basis_sumfacts = [i for i in retrieve_cache_items('kernel_default and basis_sf_kernels')] # This means we can have sum factorizations that will not get used inactive_sumfacts = [i for i in all_sumfacts if i.stage == 1 and i not in basis_sumfacts] # All sum factorization kernels that get used active_sumfacts = [i for i in all_sumfacts if i.stage == 3 or i in basis_sumfacts] # If no vectorization is needed, abort now if get_option("vectorization_strategy") == "none": for sf in all_sumfacts: _cache_vectorization_info(sf, sf.copy(buffer=get_counted_variable("buffer"))) return logger.debug("decide_vectorization_strategy: Found {} active sum factorization nodes" .format(len(active_sumfacts))) # Find the best vectorization strategy by using a costmodel width = get_vcl_type_size(dtype_floatingpoint()) # # Optimize over all the possible quadrature point tuples # quad_points = [quadrature_points_per_direction()] if get_option("vectorization_allow_quadrature_changes"): sf = next(iter(active_sumfacts)) depth = 1 while depth <= width: i = 0 if sf.matrix_sequence[0].face is None else 1 quad = list(quadrature_points_per_direction()) quad[i] = round_to_multiple(quad[i], depth) quad_points.append(tuple(quad)) depth = depth * 2 quad_points = list(set(quad_points)) # Find the minimum cost strategy between all the quadrature point tuples optimal_strategies = {qp: fixed_quadrature_optimal_vectorization(active_sumfacts, width, qp) for qp in quad_points} qp = min(optimal_strategies, key=lambda qp: strategy_cost(optimal_strategies[qp])) sfdict = optimal_strategies[qp] set_quadrature_points(qp) logger.debug("decide_vectorization_strategy: Decided for the following strategy:" "\n".join(stringify_vectorization_strategy((qp, sfdict)))) # We map inactive sum factorization kernels to 0 sfdict = add_to_frozendict(sfdict, {sf: 0 for sf in inactive_sumfacts}) # Register the results for sf in all_sumfacts: _cache_vectorization_info(sf, sfdict[sf]) def fixed_quadrature_optimal_vectorization(sumfacts, width, qp): """ For a given quadrature point tuple, find the optimal strategy! In order to have this scale sufficiently, we cannot simply list all vectorization opportunities and score them individually, but we need to do a divide and conquer approach. """ set_quadrature_points(qp) # Find the sets of simultaneously realizable kernels (thats an equivalence relation) keys = frozenset(sf.input_key for sf in sumfacts) # Find minimums for each of these sets sfdict = frozendict() for key in keys: key_sumfacts = frozenset(sf for sf in sumfacts if sf.input_key == key) minimum = min(fixed_quad_vectorization_opportunity_generator(key_sumfacts, width, qp), key=strategy_cost) sfdict = add_to_frozendict(sfdict, minimum) return sfdict def fixed_quad_vectorization_opportunity_generator(sumfacts, width, qp, already=frozendict()): if len(sumfacts) == 0: # We have gone into recursion deep enough to have all sum factorization nodes # assigned their vectorized counterpart. We can yield the result now! yield already return # Otherwise we pick a random sum factorization kernel and construct all the vectorization # opportunities realizing this particular kernel and go into recursion. sf_to_decide = next(iter(sumfacts)) # Have "unvectorized" as an option, although it is not good for opp in fixed_quad_vectorization_opportunity_generator(sumfacts.difference({sf_to_decide}), width, qp, add_to_frozendict(already, {sf_to_decide: sf_to_decide.copy(buffer=get_counted_variable("buffer"))} ), ): yield opp horizontal = 1 while horizontal <= width: # Iterate over the possible combinations of sum factorization kernels # taking into account all the permutations of kernels. This also includes # combinations which use a padding of 1 - but only for pure horizontality. generators = [it.permutations(sumfacts, horizontal)] if horizontal >= 4: generators.append(it.permutations(sumfacts, horizontal - 1)) for combo in it.chain(*generators): # The chosen kernels must be part of the kernels for recursion # to work correctly if sf_to_decide not in combo: continue # Set up the vectorization dict for this combo vecdict = get_vectorization_dict(combo, width // horizontal, horizontal, qp) if vecdict is None: # This particular choice was rejected for some reason. # Possible reasons: # * the quadrature point tuple not being suitable # for this vectorization strategy continue # Go into recursion to also vectorize all kernels not in this combo for opp in fixed_quad_vectorization_opportunity_generator(sumfacts.difference(combo), width, qp, add_to_frozendict(already, vecdict), ): yield opp horizontal = horizontal * 2 def get_vectorization_dict(sumfacts, vertical, horizontal, qp): # Enhance the list of sumfact nodes by adding vertical splittings kernels = [] for sf in sumfacts: # No slicing needed in the pure horizontal case if vertical == 1: kernels.append(sf) continue # Determine the slicing direction slice_direction = 0 if sf.matrix_sequence[0].face is None else 1 if qp[slice_direction] % vertical != 0: return None # Split the basis tabulation matrices oldtab = sf.matrix_sequence[slice_direction] for i in range(vertical): seq = list(sf.matrix_sequence) seq[slice_direction] = oldtab.copy(slice_size=vertical, slice_index=i) kernels.append(sf.copy(matrix_sequence=tuple(seq))) # Join the new kernels into a sum factorization node buffer = get_counted_variable("joined_buffer") return {sf: VectorizedSumfactKernel(kernels=tuple(kernels), horizontal_width=horizontal, vertical_width=vertical, buffer=buffer, ) for sf in sumfacts}