Newer
Older
""" Sum factorization vectorization """
from dune.perftool.loopy.vcl import get_vcl_type_size
from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel
from dune.perftool.generation import (generator_factory,
get_counted_variable,
from dune.perftool.pdelab.restriction import (Restriction,
restricted_name,
)
from dune.perftool.sumfact.tabulation import (BasisTabulationMatrixArray,
quadrature_points_per_direction,
set_quadrature_points,
from dune.perftool.error import PerftoolError
from dune.perftool.options import get_option
from dune.perftool.tools import add_to_frozendict,round_to_multiple
from frozendict import frozendict
import itertools as it
import loopy as lp
import numpy as np
@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o)
def _cache_vectorization_info(old, new):
    """ Hand back the vectorization info ``new`` for the kernel ``old``.

    The cache key generator passed to the decorator only looks at the first
    argument, so ``old`` alone determines the key.

    :raises PerftoolError: if ``new`` is None
    """
    if new is not None:
        return new
    raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!")
# Callable built by generator_factory for collecting sum factorization nodes.
# Items are tagged "sumfactnodes" and "dryrundata", the context tag is "kernel".
_collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True)
def get_all_sumfact_nodes():
    """ Return a list of all cache items matching "kernel_default and sumfactnodes" """
    from dune.perftool.generation import retrieve_cache_items

    nodes = retrieve_cache_items("kernel_default and sumfactnodes")
    return list(nodes)
def attach_vectorization_info(sf):
assert isinstance(sf, SumfactKernel)
if get_global_context_value("dry_run"):
return sf.copy(buffer=get_counted_variable("buffer"))
def no_vectorization(sumfacts):
    """ Map every given sum factorization kernel to its no_vec() counterpart

    :param sumfacts: iterable of sum factorization kernels
    :returns: dict mapping each kernel to the result of no_vec(kernel)
    """
    mapping = {}
    for kernel in sumfacts:
        mapping[kernel] = no_vec(kernel)
    return mapping
def vertical_vectorization_strategy(sumfact, depth):
# If depth is 1, there is nothing to do
if depth == 1:
if isinstance(sumfact, SumfactKernel):
return {sumfact: sumfact}
else:
return {k: sumfact for k in sumfact.kernels}
# Assert that this is not already sliced
assert all(mat.slice_size is None for mat in sumfact.matrix_sequence)
# Determine which of the matrices in the kernel should be sliced
def determine_slice_direction(sf):
for i, mat in enumerate(sf.matrix_sequence):
if mat.quadrature_size % depth == 0:
return i
elif get_option("vectorize_allow_quadrature_changes") and mat.quadrature_size != 1:
quad = list(quadrature_points_per_direction())
quad[i] = round_to_multiple(quad[i], depth)
set_quadrature_points(tuple(quad))
elif mat.quadrature_size != 1:
raise PerftoolError("Vertical vectorization is not possible!")
if isinstance(sumfact, SumfactKernel):
kernels = [sumfact]
else:
assert isinstance(sumfact, VectorizedSumfactKernel)
kernels = sumfact.kernels
newkernels = []
for sf in kernels:
sliced = determine_slice_direction(sf)
oldtab = sf.matrix_sequence[sliced]
for i in range(depth):
seq = list(sf.matrix_sequence)
seq[sliced] = oldtab.copy(slice_size=depth,
slice_index=i)
newkernels.append(sf.copy(matrix_sequence=tuple(seq)))
if isinstance(sumfact, SumfactKernel):
buffer = get_counted_variable("vertical_buffer")
result[sumfact] = VectorizedSumfactKernel(kernels=tuple(newkernels),
buffer=buffer,
vertical_width=depth,
)
else:
assert isinstance(sumfact, VectorizedSumfactKernel)
for sf in kernels:
result[sf] = sumfact.copy(kernels=tuple(newkernels),
vertical_width=depth,
)
def horizontal_vectorization_strategy(sumfacts, width, allow_padding=1):
result = {}
todo = set(sumfacts)
while todo:
kernels = []
for sf in sorted(todo, key=lambda s: s.position_priority):
if len(kernels) < width:
kernels.append(sf)
todo.discard(sf)
buffer = get_counted_variable("joined_buffer")
if len(kernels) in range(width - allow_padding, width + 1):
for sf in kernels:
result[sf] = VectorizedSumfactKernel(kernels=tuple(kernels),
horizontal_width=width,
buffer=buffer,
)
# Read explicitly set values
horizontal = get_option("vectorize_horizontal")
vertical = get_option("vectorize_vertical")
padding = get_option("vectorize_padding")
if horizontal is None:
horizontal = 2
if vertical is None:
vertical = 2
if padding is None:
padding = 0
if horizontal is None:
horizontal = 4
if vertical is None:
vertical = 2
if padding is None:
padding = 1
horizontal = int(horizontal)
vertical = int(vertical)
padding = int(padding)
horizontal_kernels = horizontal_vectorization_strategy(sumfacts, horizontal, allow_padding=padding)
vert = vertical_vectorization_strategy(horizontal_kernels[sf], width // horizontal_kernels[sf].horizontal_width)
for k in vert:
result[k] = vert[k]
def greedy_vectorization_strategy(sumfacts, width):
sumfacts = set(sumfacts)
horizontal = width
vertical = 1
result = {}
while horizontal > 0:
if horizontal > 1:
horizontal_kernels = horizontal_vectorization_strategy(sumfacts, horizontal, allow_padding=allowed_padding)
else:
horizontal_kernels = {sf: sf for sf in sumfacts}
for sf in horizontal_kernels:
if horizontal_kernels[sf].horizontal_width == horizontal:
vert = vertical_vectorization_strategy(horizontal_kernels[sf],
vertical)
for k in vert:
result[k] = vert[k]
sumfacts.discard(sf)
# We heuristically allow padding only on the full SIMD width
allowed_padding = 0
def decide_vectorization_strategy():
""" Decide how to vectorize!
Note that the vectorization of the quadrature loop is independent of this,
as it is implemented through a post-processing (== loopy transformation) step.
"""
logger = logging.getLogger(__name__)
# Retrieve all sum factorization kernels for stage 1 and 3
from dune.perftool.generation import retrieve_cache_items
all_sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")]
# Stage 1 sumfactorizations that were actually used
basis_sumfacts = [i for i in retrieve_cache_items('kernel_default and basis_sf_kernels')]
# This means we can have sum factorizations that will not get used
inactive_sumfacts = [i for i in all_sumfacts if i.stage == 1 and i not in basis_sumfacts]
# All sum factorization kernels that get used
active_sumfacts = [i for i in all_sumfacts if i.stage == 3 or i in basis_sumfacts]
# We map inactive sum factorization kernels to 0
for sf in inactive_sumfacts:
sfdict[sf] = 0
logger.debug("decide_vectorization_strategy: Found {} active sum factorization nodes"
.format(len(active_sumfacts)))
if get_option("vectorize_grads"):
# Currently we base our idea here on the fact that we only group sum
# factorization kernels with the same input.
inputkeys = set(sf.input_key for sf in active_sumfacts)
width = get_vcl_type_size(np.float64)
sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
for old, new in horizontal_vectorization_strategy(sumfact_filter, width).items():
sfdict[old] = new
elif get_option("vectorize_slice"):
for sumfact in active_sumfacts:
width = get_vcl_type_size(np.float64)
for old, new in vertical_vectorization_strategy(sumfact, width).items():
sfdict[old] = new
elif get_option("vectorize_diagonal"):
inputkeys = set(sf.input_key for sf in active_sumfacts)
for inputkey in inputkeys:
width = get_vcl_type_size(np.float64)
sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
for old, new in diagonal_vectorization_strategy(sumfact_filter, width).items():
sfdict[old] = new
elif get_option("vectorize_greedy"):
inputkeys = set(sf.input_key for sf in active_sumfacts)
for inputkey in inputkeys:
width = get_vcl_type_size(np.float64)
sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
for old, new in greedy_vectorization_strategy(sumfact_filter, width).items():
sfdict[old] = new
for old, new in no_vectorization(active_sumfacts).items():
for sf in all_sumfacts:
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def vectorization_opportunity_generator(sumfacts, width):
    """ Generator that yields all vectorization opportunities for the given
    sum factorization kernels as tuples of quadrature point tuple and vectorization
    dictionary

    :param sumfacts: the sum factorization kernels to vectorize
    :param width: upper bound for the depths (1, 2, 4, ...) that the quadrature
        point number is rounded to
    """
    #
    # Collect the quadrature point tuples to try, starting with the current one
    #
    candidates = [quadrature_points_per_direction()]

    # NOTE: The leading 'True' short-circuits this condition, so the option
    # is never queried and the branch is always taken.
    if True or get_option("vectorize_allow_quadrature_changes"):
        # One kernel decides which direction gets rounded
        representative = next(iter(sumfacts))
        depth = 1
        while depth <= width:
            direction = 1 if representative.matrix_sequence[0].face is not None else 0
            rounded = list(quadrature_points_per_direction())
            rounded[direction] = round_to_multiple(rounded[direction], depth)
            candidates.append(tuple(rounded))
            depth *= 2

        # Different depths may round to the same tuple: remove duplicates
        candidates = list(set(candidates))

    #
    # For each quadrature point tuple, yield every vectorization opportunity
    # that is possible with that fixed number of quadrature points
    #
    for qp in candidates:
        opportunities = fixed_quad_vectorization_opportunity_generator(frozenset(sumfacts), width, qp)
        for opp in opportunities:
            yield qp, opp
def fixed_quad_vectorization_opportunity_generator(sumfacts, width, qp, already=frozendict()):
    """ Generator that recursively yields all vectorization dictionaries for the
    given sum factorization kernels and a fixed quadrature point tuple

    :param sumfacts: set of kernels that still need a decision. It must support
        ``difference`` (the caller in this file passes a frozenset).
    :param width: upper bound for the horizontal widths (1, 2, 4, ...) to try;
        the vertical width passed on is ``width // horizontal``
    :param qp: the quadrature point tuple, passed on to get_vectorization_dict
    :param already: frozendict with the decisions taken further up in the recursion
    """
    if len(sumfacts) == 0:
        # We have gone into recursion deep enough to have all sum factorization nodes
        # assigned their vectorized counterpart. We can yield the result now!
        yield already
        return

    # Otherwise we pick an arbitrary sum factorization kernel and construct all the
    # vectorization opportunities realizing this particular kernel and go into recursion.
    sf_to_decide = next(iter(sumfacts))

    # Have "unvectorized" as an option, although it is not good
    for opp in fixed_quad_vectorization_opportunity_generator(sumfacts.difference({sf_to_decide}),
                                                              width,
                                                              qp,
                                                              add_to_frozendict(already, {sf_to_decide: sf_to_decide}),
                                                              ):
        yield opp

    # Find all the sum factorization kernels that the chosen kernel can be parallelized with
    #
    # TODO: Right now we check for same input, which is not actually needed in order
    #       to be a suitable candidate! We should relax this concept at some point!
    #
    # The candidates are materialized into a tuple: a lazy filter() object is a
    # one-shot iterator on Python 3, which the first call to it.permutations would
    # drain, leaving an empty pool for all later calls in the loop below.
    candidates = tuple(sf for sf in sumfacts if sf.input_key == sf_to_decide.input_key)

    horizontal = 1
    while horizontal <= width:
        # Iterate over the possible combinations of sum factorization kernels
        # taking into account all the permutations of kernels. This also includes
        # combinations which use a padding of 1.
        for combo in it.chain(it.permutations(candidates, horizontal),
                              it.permutations(candidates, horizontal - 1)):
            # The chosen kernel must be part of the combination for recursion
            # to work correctly
            if sf_to_decide not in combo:
                continue

            # Set up the vectorization dict for this combo
            vecdict = get_vectorization_dict(combo, width // horizontal, horizontal, qp)
            if vecdict is None:
                # This particular choice was rejected for some reason.
                # Possible reasons:
                # * the quadrature point tuple not being suitable
                #   for this vectorization strategy
                continue

            # Go into recursion to also vectorize all kernels not in this combo
            for opp in fixed_quad_vectorization_opportunity_generator(sumfacts.difference(combo),
                                                                      width,
                                                                      qp,
                                                                      add_to_frozendict(already, vecdict),
                                                                      ):
                yield opp

        horizontal = horizontal * 2
def get_vectorization_dict(sumfacts, vertical, horizontal, qp):
    """ Construct the vectorization dictionary that joins the given kernels

    :param sumfacts: the sum factorization kernels to join
    :param vertical: number of slices each kernel is split into (1 means no slicing)
    :param horizontal: passed on as horizontal_width of the joined kernel
    :param qp: quadrature point tuple; the entry in slicing direction needs to
        be divisible by ``vertical``
    :returns: dict mapping each given kernel to a VectorizedSumfactKernel, or
        None if a kernel cannot be sliced with the given quadrature points
    """
    if vertical == 1:
        # No slicing needed in the pure horizontal case
        joined = list(sumfacts)
    else:
        # Enhance the list of sumfact nodes by adding vertical splittings
        joined = []
        for sf in sumfacts:
            # Determine the slicing direction
            direction = 1 if sf.matrix_sequence[0].face is not None else 0
            if qp[direction] % vertical != 0:
                return None

            # Split the basis tabulation matrix into one slice per kernel copy
            tabulation = sf.matrix_sequence[direction]
            for index in range(vertical):
                matrices = list(sf.matrix_sequence)
                matrices[direction] = tabulation.copy(slice_size=vertical,
                                                      slice_index=index)
                joined.append(sf.copy(matrix_sequence=tuple(matrices)))

    # Join the new kernels into a sum factorization node
    buffer = get_counted_variable("joined_buffer")
    vecdict = {}
    for sf in sumfacts:
        vecdict[sf] = VectorizedSumfactKernel(kernels=tuple(joined),
                                              horizontal_width=horizontal,
                                              vertical_width=vertical,
                                              buffer=buffer,
                                              )
    return vecdict