Newer
Older
""" Sum factorization vectorization """
import loopy as lp
import numpy as np

from dune.perftool.loopy.vcl import get_vcl_type_size
from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel
from dune.perftool.generation import (generator_factory,
                                      get_counted_variable,
                                      get_global_context_value,
                                      )
from dune.perftool.pdelab.restriction import (Restriction,
                                              restricted_name,
                                              )
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixArray
from dune.perftool.error import PerftoolError
from dune.perftool.options import get_option
@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o)
def _cache_vectorization_info(old, new):
    """Store (or look up) the vectorized replacement for the kernel ``old``.

    The cache key is derived from ``old`` alone (see ``cache_key_generator``),
    so presumably a later call with the same ``old`` returns the value stored
    earlier without re-running this body -- TODO confirm against
    ``generator_factory``.

    :param old: the original sum factorization kernel
    :param new: its replacement, or ``None`` when only looking it up
    :returns: ``new``
    :raises PerftoolError: if the body runs with ``new`` being ``None``,
        i.e. no replacement was registered for ``old`` beforehand
    """
    if new is not None:
        return new
    raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!")
# Generator recording sum factorization kernels, tagged "sumfactnodes" and
# "dryrundata" within the "kernel" context. The recorded items are read back
# by decide_vectorization_strategy through
# retrieve_cache_items("kernel_default and sumfactnodes").
_collect_sumfact_nodes = generator_factory(
    item_tags=("sumfactnodes", "dryrundata"),
    context_tags="kernel",
    no_deco=True,
)
def attach_vectorization_info(sf):
    """Record ``sf`` during a dry run, otherwise return its registered replacement.

    In a dry run (global context value ``"dry_run"``) the kernel is handed to
    ``_collect_sumfact_nodes`` so that ``decide_vectorization_strategy`` can
    inspect all kernels. Outside a dry run, the replacement registered for
    ``sf`` is fetched via ``_cache_vectorization_info(sf, None)``. That call
    raises ``PerftoolError`` if nothing was registered, presumably because the
    cached value is returned on a key hit.

    NOTE(review): the body of the dry-run branch was missing in this file and
    has been reconstructed from the two generators it pairs with -- confirm.

    :param sf: a ``SumfactKernel``
    """
    assert isinstance(sf, SumfactKernel)
    if get_global_context_value("dry_run"):
        return _collect_sumfact_nodes(sf)
    return _cache_vectorization_info(sf, None)
def no_vec(sf):
    """Return a copy of ``sf`` that is not vectorized.

    The copy receives freshly counted ``buffer`` and ``input`` names, so it
    does not share storage with any other kernel.
    """
    fresh_buffer = get_counted_variable("buffer")
    fresh_input = get_counted_variable("input")
    return sf.copy(buffer=fresh_buffer, input=fresh_input)
def no_vectorization(sumfacts):
    """Map each kernel in ``sumfacts`` to its non-vectorized copy (``no_vec``)."""
    mapping = {}
    for kernel in sumfacts:
        mapping[kernel] = no_vec(kernel)
    return mapping
def vertical_vectorization_strategy(sumfact, depth):
    """Vectorize a single kernel by slicing one of its matrices into ``depth`` parts.

    The first matrix in ``sumfact.matrix_sequence`` whose ``quadrature_size``
    is divisible by ``depth`` is sliced. ``depth`` copies of ``sumfact`` are
    built; copy ``i`` uses slice ``i`` of that matrix. They are joined into one
    ``VectorizedSumfactKernel`` with ``vertical_width=depth``.

    :param sumfact: the (not yet sliced) sum factorization kernel
    :param depth: number of slices / vertical SIMD width
    :returns: ``{sumfact: vectorized_kernel}``
    :raises PerftoolError: if no matrix can be sliced by ``depth``
    """
    # Assert that this is not already sliced
    assert all(mat.slice_size is None for mat in sumfact.matrix_sequence)

    def determine_slice_direction():
        """Return the index of the matrix to slice, or raise PerftoolError."""
        for i, mat in enumerate(sumfact.matrix_sequence):
            if mat.quadrature_size % depth == 0:
                return i
            elif mat.quadrature_size != 1:
                # Only size-1 matrices may precede the sliced one
                raise PerftoolError("Vertical vectorization is not possible!")
        # All matrices had quadrature size 1 (or there were none): previously
        # this fell through to None and crashed on matrix_sequence[None].
        raise PerftoolError("Vertical vectorization is not possible!")

    sliced = determine_slice_direction()
    oldtab = sumfact.matrix_sequence[sliced]

    kernels = []
    for i in range(depth):
        seq = list(sumfact.matrix_sequence)
        seq[sliced] = oldtab.copy(slice_size=depth, slice_index=i)
        kernels.append(sumfact.copy(matrix_sequence=tuple(seq)))

    buffer = get_counted_variable("vertical_buffer")
    input_name = get_counted_variable("vertical_input")

    vsf = VectorizedSumfactKernel(kernels=tuple(kernels),
                                  buffer=buffer,
                                  input=input_name,
                                  vertical_width=depth,
                                  )

    return {sumfact: vsf}
def horizontal_vectorization_strategy(sumfacts, width):
    """Bundle sum factorization kernels side by side into SIMD lanes.

    Kernels are processed in bundles of at most ``width``. Within a bundle, a
    kernel whose ``preferred_position`` is a still-free lane is placed in that
    lane. The remaining free lanes are filled, lowest first, with leftover
    kernels. A bundle becomes a ``VectorizedSumfactKernel`` with
    ``horizontal_width=width`` when two conditions hold:

    - it holds ``width`` or ``width - 1`` kernels;
    - those kernels occupy lanes ``0 .. len - 1``.

    Otherwise each of its kernels falls back to ``no_vec``.

    :param sumfacts: iterable of sum factorization kernels
    :param width: SIMD width (number of lanes), at least 1
    :returns: dict mapping every kernel in ``sumfacts`` to its replacement
    :raises PerftoolError: if ``width`` is smaller than 1
    """
    if width < 1:
        # With no lanes the bundling loop below could never make progress
        raise PerftoolError("Horizontal vectorization needs a SIMD width of at least 1!")

    result = {}
    todo = set(sumfacts)
    while todo:
        # Lane -> kernel assignment for the current bundle
        position_mapping = {}
        available = set(range(width))

        # First honor preferred lanes (first come, first served on conflicts)
        for sf in todo:
            pos = sf.preferred_position
            if pos is not None and pos in available:
                available.discard(pos)
                position_mapping[pos] = sf
        for sf in position_mapping.values():
            todo.discard(sf)

        # Then fill the remaining lanes with arbitrary leftover kernels
        for pos in sorted(available):
            if not todo:
                break
            position_mapping[pos] = todo.pop()

        # Order kernels by lane. The former [None] * len(...) list raised
        # IndexError whenever a lane index exceeded the bundle size.
        lanes = sorted(position_mapping)
        kernels = tuple(position_mapping[pos] for pos in lanes)

        # Only a trailing lane is treated as padding (len == width - 1). A
        # bundle with an unused lane in the middle is not vectorized, rather
        # than silently shifting kernels out of their preferred lane.
        contiguous = lanes == list(range(len(lanes)))
        vectorize = contiguous and len(kernels) in (width, width - 1)

        buffer = get_counted_variable("joined_buffer")
        input_name = get_counted_variable("joined_input")
        for sf in kernels:
            if vectorize:
                result[sf] = VectorizedSumfactKernel(kernels=kernels,
                                                     horizontal_width=width,
                                                     buffer=buffer,
                                                     input=input_name,
                                                     )
            else:
                result[sf] = no_vec(sf)

    return result
def diagonal_vectorization_strategy(sumfacts, width):
    """Vectorize ``sumfacts`` with the diagonal strategy.

    Currently this delegates unchanged to ``horizontal_vectorization_strategy``
    and returns its kernel -> replacement mapping.
    """
    return horizontal_vectorization_strategy(sumfacts, width)


# Backward-compatible alias for the original, misspelled public name.
diagonal_vectorization_stratget = diagonal_vectorization_strategy
def decide_vectorization_strategy():
    """ Decide how to vectorize!
    Note that the vectorization of the quadrature loop is independent of this,
    as it is implemented through a post-processing (== loopy transformation) step.

    The strategy is selected by the options ``vectorize_grads``,
    ``vectorize_slice`` and ``vectorize_diagonal``, checked in that order. If
    none is set, every kernel is mapped to a non-vectorized copy. Each
    resulting old -> new mapping is registered via ``_cache_vectorization_info``.
    """
    from dune.perftool.generation import retrieve_cache_items
    sumfacts = list(retrieve_cache_items("kernel_default and sumfactnodes"))

    sfdict = {}
    width = get_vcl_type_size(np.float64)

    # Kernel objects are not strings, so the strategy results are merged with
    # dict.update(mapping) and not update(**mapping), which raises TypeError.
    if get_option("vectorize_grads"):
        # Currently we base our idea here on the fact that we only group sum
        # factorization kernels with the same input.
        groups = {}
        for sf in sumfacts:
            groups.setdefault(sf.input_key, []).append(sf)
        for group in groups.values():
            sfdict.update(horizontal_vectorization_strategy(group, width))
    elif get_option("vectorize_slice"):
        for sumfact in sumfacts:
            sfdict.update(vertical_vectorization_strategy(sumfact, width))
    elif get_option("vectorize_diagonal"):
        sfdict.update(diagonal_vectorization_stratget(sumfacts, width))
    else:
        # Fallback only: running this unconditionally would overwrite the
        # results of the strategies above.
        sfdict.update(no_vectorization(sumfacts))

    # Register the results
    for old, new in sfdict.items():
        _cache_vectorization_info(old, new)