Newer
Older
""" Sum factorization vectorization """
from dune.perftool.generation import (generator_factory,
get_counted_variable,
from dune.perftool.pdelab.restriction import (Restriction,
restricted_name,
)
from dune.perftool.error import PerftoolError
from dune.perftool.options import get_option
import loopy as lp
@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o)
def _cache_vectorization_info(old, new):
    """Record the vectorized replacement ``new`` for the kernel ``old``.

    The cache key is ``old`` alone (``cache_key_generator`` ignores ``new``).

    :param old: the original sum factorization kernel, used as cache key
    :param new: the kernel copy carrying vectorization information
    :returns: ``new`` unchanged
    :raises PerftoolError: if ``new`` is None

    NOTE(review): presumably generator_factory short-circuits this body when
    ``old`` is already cached, so passing ``None`` acts as a lookup that only
    fails for kernels never registered — confirm against generator_factory.
    """
    if new is not None:
        return new
    raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!")
# Collector for sum factorization nodes, tagged as dry-run data. It is built
# by calling generator_factory directly (no_deco=True) instead of decorating.
_collect_sumfact_nodes = generator_factory(
    no_deco=True,
    item_tags=("sumfactnodes", "dryrundata"),
)
def attach_vectorization_info(sf):
    """Return the vectorization information for the sum factorization kernel ``sf``.

    During the dry run the kernel is passed to ``_collect_sumfact_nodes``
    and that result is returned. Otherwise ``_cache_vectorization_info`` is
    called with ``new=None``, which raises ``PerftoolError`` unless
    (presumably) a decision for ``sf`` was cached earlier, e.g. by
    ``no_vectorization`` or ``decide_stage_vectorization_strategy``.

    :param sf: a ``SumfactKernel``
    """
    assert isinstance(sf, SumfactKernel)
    if get_global_context_value("dry_run"):
        # NOTE(review): this branch had no body in the file; recording the
        # node is inferred from _collect_sumfact_nodes being tagged as
        # dry-run data and not being used anywhere else in view — confirm.
        return _collect_sumfact_nodes(sf)
    # ``None`` carries no new information: intended as a cache lookup.
    return _cache_vectorization_info(sf, None)
def no_vectorization(sumfacts):
    """Register every kernel in ``sumfacts`` as not vectorized.

    Each kernel is cached with a copy of itself whose ``buffer`` and
    ``input`` names come from separate ``get_counted_variable`` calls, so
    no joint input/buffer is shared with other kernels.

    :param sumfacts: iterable of sum factorization kernels
    """
    # The previous body referenced an undefined ``sf``; every kernel of the
    # collection has to be registered.
    for sf in sumfacts:
        _cache_vectorization_info(sf, sf.copy(buffer=get_counted_variable("buffer"),
                                              input=get_counted_variable("input")))
def decide_stage_vectorization_strategy(sumfacts, stage, restriction):
    """Decide how the kernels of one stage and restriction are vectorized.

    The kernels in ``sumfacts`` whose ``stage`` and ``restriction`` match
    are joined into one kernel with 4 positions if there are 3 or 4 of them.
    Otherwise each of them is registered via ``no_vectorization``.

    In the joined case every kernel is cached with a copy that:
    shares the joined ``input``/``buffer`` variables, uses one
    ``LargeAMatrix`` per entry of ``a_matrices``, stores its position in
    ``index``, lists the unused positions in ``padding`` and has
    ``insn_dep`` set to the union of all joined kernels' ``insn_dep``.

    :param sumfacts: collection of sum factorization kernels
    :param stage: stage whose kernels are considered
    :param restriction: restriction value the kernels must compare equal to
    """
    stage_sumfacts = frozenset(sf for sf in sumfacts
                               if sf.stage == stage and sf.restriction == restriction)
    if len(stage_sumfacts) not in (3, 4):
        # Disable vectorization strategy
        no_vectorization(stage_sumfacts)
        return

    # Map the sum factorizations to their position in the joint kernel.
    # (This dict was previously used without being initialised.)
    position_mapping = {}
    available = set(range(4))
    for sf in stage_sumfacts:
        if sf.preferred_position is not None:
            # This asserts that no two kernels want to take the same position.
            # Later on, more complicated stuff might be necessary here.
            assert sf.preferred_position in available
            available.discard(sf.preferred_position)
            position_mapping[sf] = sf.preferred_position
    # Choose a position for those that have no preferred one.
    for sf in stage_sumfacts:
        if sf.preferred_position is None:
            position_mapping[sf] = available.pop()

    # Enable vectorization strategy
    inp = get_counted_variable("joined_input")
    buf = get_counted_variable("joined_buffer")

    # Collect the large matrices, one per entry of ``a_matrices``.
    from dune.perftool.sumfact.amatrix import LargeAMatrix
    first = next(iter(stage_sumfacts))
    large_a_matrices = []
    for i, mat in enumerate(first.a_matrices):
        # All joined kernels must have matrices of the same size here.
        assert len({sf.a_matrices[i].rows for sf in stage_sumfacts}) == 1
        assert len({sf.a_matrices[i].cols for sf in stage_sumfacts}) == 1

        # Collect the derivative information per position.
        derivative = [False] * 4
        for sf in stage_sumfacts:
            derivative[position_mapping[sf]] = sf.a_matrices[i].derivative
        # NOTE(review): ``derivative`` is not passed to LargeAMatrix, so the
        # per-position derivative flags are discarded — confirm whether
        # LargeAMatrix should receive it.
        large_a_matrices.append(LargeAMatrix(rows=mat.rows,
                                             cols=mat.cols,
                                             transpose=mat.transpose,
                                             face=mat.face,
                                             ))

    large_a_matrices = tuple(large_a_matrices)
    padding = frozenset(available)
    # Unpack so the *elements* of each insn_dep are merged; passing the
    # generator itself would build a frozenset of frozensets.
    insn_dep = frozenset().union(*(sf.insn_dep for sf in stage_sumfacts))
    for sf in stage_sumfacts:
        _cache_vectorization_info(sf,
                                  sf.copy(a_matrices=large_a_matrices,
                                          buffer=buf,
                                          input=inp,
                                          index=position_mapping[sf],
                                          padding=padding,
                                          insn_dep=insn_dep,
                                          ))
def decide_vectorization_strategy():
    """Choose a vectorization strategy for all sum factorization kernels.

    The kernels are gathered, via ``find_sumfact``, from the expressions of
    all assignment and call instructions in the
    "kernel_default and instruction" cache items. Behaviour depends on the
    ``vectorize_grads`` option:

    - if it is off, none of the kernels is vectorized;
    - otherwise stage 1 kernels are decided per restriction, and stage 3
      kernels per pair of restrictions.

    Note that the vectorization of the quadrature loop is independent of this,
    as it is implemented through a post-processing (== loopy transformation) step.
    """
    from dune.perftool.generation import retrieve_cache_items
    instructions = list(retrieve_cache_items("kernel_default and instruction"))

    # Find all sum factorization kernels
    sumfacts = frozenset().union(*(find_sumfact(insn.expression)
                                   for insn in instructions
                                   if isinstance(insn, (lp.Assignment, lp.CallInstruction))))

    if not get_option("vectorize_grads"):
        no_vectorization(sumfacts)
        return

    import itertools
    restrictions = (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE)
    # Stage 1 kernels: one decision per single restriction
    for single in restrictions:
        decide_stage_vectorization_strategy(sumfacts, 1, single)
    # Stage 3 kernels: one decision per pair of restrictions
    for pair in itertools.product(restrictions, restrictions):
        decide_stage_vectorization_strategy(sumfacts, 3, pair)
class HasSumfactMapper(lp.symbolic.CombineMapper):
    """Collect the sum factorization kernel nodes of an expression.

    Every mapping method returns a frozenset:

    - a sum factorization kernel contributes itself;
    - constants, algebraic leaves, loopy function identifiers and tagged
      variables contribute nothing;
    - ``combine`` merges subexpression results by set union.
    """

    def combine(self, *args):
        # ``tuple(*args)`` admits at most one positional argument: the
        # iterable of per-child frozensets to merge.
        merged = frozenset()
        for part in tuple(*args):
            merged = merged.union(part)
        return merged

    def _no_sumfacts(self, expr):
        # Shared implementation for node types without sum factorization kernels.
        return frozenset()

    map_constant = _no_sumfacts
    map_algebraic_leaf = _no_sumfacts
    map_loopy_function_identifier = _no_sumfacts
    map_tagged_variable = _no_sumfacts

    def map_sumfact_kernel(self, expr):
        return frozenset([expr])
def find_sumfact(expr):