""" Sum factorization vectorization """ from dune.perftool.loopy.symbolic import SumfactKernel from dune.perftool.generation import (generator_factory, get_counted_variable, get_global_context_value, ) from dune.perftool.pdelab.restriction import (Restriction, restricted_name, ) from dune.perftool.error import PerftoolError from dune.perftool.options import get_option import loopy as lp @generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o) def _cache_vectorization_info(old, new): if new is None: raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!") return new _collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), no_deco=True) def attach_vectorization_info(sf): assert isinstance(sf, SumfactKernel) if get_global_context_value("dry_run"): return _collect_sumfact_nodes(sf) else: return _cache_vectorization_info(sf, None) def no_vectorization(sumfacts): for sf in sumfacts: _cache_vectorization_info(sf, sf.copy(buffer=get_counted_variable("buffer"), input=get_counted_variable("input"))) def decide_stage_vectorization_strategy(sumfacts, stage, restriction): stage_sumfacts = frozenset([sf for sf in sumfacts if sf.stage == stage and sf.restriction == restriction]) if len(stage_sumfacts) in (3, 4): # Map the sum factorization to their position in the joint kernel position_mapping = {} available = set(range(4)) for sf in stage_sumfacts: if sf.preferred_position is not None: # This asserts that no two kernels want to take the same position # Later on, more complicated stuff might be necessary here. assert sf.preferred_position in available available.discard(sf.preferred_position) position_mapping[sf] = sf.preferred_position # Choose a position for those that have no preferred one! for sumf in stage_sumfacts: if sumf.preferred_position is None: position_mapping[sumf] = available.pop() # Enable vectorization strategy: inp = get_counted_variable("joined_input") buf = get_counted_variable("joined_buffer") # Collect the large matrices! large_a_matrices = [] for i in range(len(next(iter(stage_sumfacts)).a_matrices)): # Assert that the matrices of all sum factorizations have the same size assert len(set(tuple(sf.a_matrices[i].rows for sf in stage_sumfacts))) == 1 assert len(set(tuple(sf.a_matrices[i].cols for sf in stage_sumfacts))) == 1 # Collect the derivative information derivative = [False] * 4 for sf in stage_sumfacts: derivative[position_mapping[sf]] = sf.a_matrices[i].derivative from dune.perftool.sumfact.amatrix import LargeAMatrix large = LargeAMatrix(rows=next(iter(stage_sumfacts)).a_matrices[i].rows, cols=next(iter(stage_sumfacts)).a_matrices[i].cols, transpose=next(iter(stage_sumfacts)).a_matrices[i].transpose, derivative=tuple(derivative), face=next(iter(stage_sumfacts)).a_matrices[i].face, ) large_a_matrices.append(large) for sumf in stage_sumfacts: _cache_vectorization_info(sumf, sumf.copy(a_matrices=tuple(large_a_matrices), buffer=buf, input=inp, index=position_mapping[sumf], padding=frozenset(available), insn_dep=frozenset().union(sf.insn_dep for sf in stage_sumfacts), ) ) else: # Disable vectorization strategy no_vectorization(stage_sumfacts) def decide_vectorization_strategy(): """ Decide how to vectorize! Note that the vectorization of the quadrature loop is independent of this, as it is implemented through a post-processing (== loopy transformation) step. """ from dune.perftool.generation import retrieve_cache_items insns = [i for i in retrieve_cache_items("kernel_default and instruction")] # Find all sum factorization kernels sumfacts = frozenset() for insn in insns: if isinstance(insn, (lp.Assignment, lp.CallInstruction)): sumfacts = sumfacts.union(find_sumfact(insn.expression)) if not get_option("vectorize_grads"): no_vectorization(sumfacts) else: res = (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE) # Stage 1 kernels for restriction in res: decide_stage_vectorization_strategy(sumfacts, 1, restriction) # Stage 3 kernels import itertools as it for restriction in it.product(res, res): decide_stage_vectorization_strategy(sumfacts, 3, restriction) class HasSumfactMapper(lp.symbolic.CombineMapper): def combine(self, *args): return frozenset().union(*tuple(*args)) def map_constant(self, expr): return frozenset() def map_algebraic_leaf(self, expr): return frozenset() def map_loopy_function_identifier(self, expr): return frozenset() def map_sumfact_kernel(self, expr): return frozenset({expr}) def map_tagged_variable(self, expr): return frozenset() def find_sumfact(expr): return HasSumfactMapper()(expr)