Skip to content
Snippets Groups Projects
Commit f1382d61 authored by René Heß's avatar René Heß
Browse files

Restructure where permutation happens for sumfact vectorization

Non-fastdg: The input is permuted before the sum factorization kernel runs,
namely when we set up the input. This is done by the setup_input method of the
corresponding interface class.

Fastdg: In this case the input will always be ordered according to x,y,...
This means the permutation needs to happen inside the sumfact kernel. Since we
want to vectorize sumfact kernels with different input permutations in an
upper/lower way, we need to do this permutation in the corresponding interface
class. This is done in the realize_direct method; in the vectorized case the
corresponding methods of the scalar sumfact kernels are called.
parent 10b3bbde
No related branches found
No related tags found
No related merge requests found
......@@ -143,7 +143,10 @@ def pymbolic_coefficient(container, lfs, index):
if not isinstance(lfs, Expression):
lfs = Variable(lfs)
return Call(CoefficientAccess(container), (lfs, Variable(index),))
if isinstance(index, str):
index = Variable(index)
return Call(CoefficientAccess(container), (lfs, index,))
def type_coefficientcontainer():
......
......@@ -15,6 +15,7 @@ from dune.codegen.generation import (backend,
instruction,
post_include,
kernel_cached,
silenced_warning,
temporary_variable,
transform,
valuearg
......@@ -30,7 +31,8 @@ from dune.codegen.pdelab.restriction import restricted_name
from dune.codegen.pdelab.signatures import assembler_routine_name
from dune.codegen.pdelab.geometry import world_dimension
from dune.codegen.pdelab.spaces import name_lfs
from dune.codegen.sumfact.permutation import (permute_forward,
from dune.codegen.sumfact.permutation import (permute_backward,
permute_forward,
sumfact_cost_permutation_strategy,
sumfact_quadrature_permutation_strategy,
)
......@@ -239,6 +241,25 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
tags=frozenset({"sumfact_stage3"}),
**args)})
def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
    """Return the subscript expression that reads the kernel input in step ``l``.

    ``shape`` and ``inames`` are permuted back with ``self.cost_permutation``
    before a temporary named ``buff_step<l>_in`` is requested from ``buffer``
    and indexed with the permuted inames followed by ``vec_iname``.
    """
    # TODO: This should happen in stage 2 and not in stage 3
    cost_perm = self.cost_permutation
    unpermuted_shape = permute_backward(shape, cost_perm)
    unpermuted_inames = permute_backward(inames, cost_perm)

    # The temporary views the base storage of the input as a column-major
    # matrix. In later iterations of the matrix loop the same storage holds
    # the output of the previous iteration and is reinterpreted here.
    temporary_name = "buff_step{}_in".format(l)
    input_temporary = buffer.get_temporary(
        temporary_name,
        shape=unpermuted_shape + vec_shape,
        dim_tags=ftags,
    )

    # This temporary is only ever read from, so the corresponding loopy
    # warning has to be silenced.
    silenced_warning('read_no_write({})'.format(input_temporary))

    index = unpermuted_inames + vec_iname
    return prim.Subscript(prim.Variable(input_temporary), index)
@property
def function_name_suffix(self):
if get_form_option("fastdg"):
......
......@@ -14,8 +14,10 @@ from dune.codegen.generation import (backend,
iname,
instruction,
kernel_cached,
silenced_warning,
temporary_variable,
)
from dune.codegen.loopy.flatten import flatten_index
from dune.codegen.loopy.target import type_floatingpoint
from dune.codegen.sumfact.tabulation import (basis_functions_per_direction,
construct_basis_matrix_sequence,
......@@ -107,24 +109,50 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
def direct_is_possible(self):
return get_form_option("fastdg")
def realize(self, sf, insn_dep, index=0):
def setup_input(self, sf, insn_dep, index=0):
"""Setup input for a sum factorization kernel function
Write the coefficients into an array that can be passed to a sum
factorization kernel function (necessary if direct input "fastdg" is
not possible).
index: Vectorization index
"""
# Inames for iterating over the coefficients. We take them from the
# cost permuted matrix sequence. In order to get the inames in order
# x,y,... we need to undo the permutations.
shape_cost_permuted = tuple(mat.basis_size for mat in sf.matrix_sequence_cost_permuted)
shape_ordered = permute_backward(shape_cost_permuted, self.cost_permutation)
shape_ordered = permute_backward(shape_ordered, self.quadrature_permutation)
inames_cost_permuted = tuple(sumfact_iname(length, "setup_inames_" + str(k)) for k, length in enumerate(shape_cost_permuted))
inames_ordered = permute_backward(inames_cost_permuted, self.cost_permutation)
inames_ordered = permute_backward(inames_ordered, self.quadrature_permutation)
# The coefficient needs to be accessed with a flat index of inames ordered x,y,...
flat_index = flatten_index(tuple(prim.Variable(i) for i in inames_ordered),
shape_ordered,
order="f")
# Get the coefficient container
lfs = name_lfs(self.element, self.restriction, self.element_index)
basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
container = self.coeff_func(self.restriction)
from dune.codegen.pdelab.argument import pymbolic_coefficient as pc
coeff = pc(container, lfs, basisiname)
coeff = pc(container, lfs, flat_index)
# Get the input temporary!
# The array that will be passed to the sum factorization kernel
# function should contain the coefficients in the cost permuted order!
from dune.codegen.sumfact.realization import name_buffer_storage
name = "input_{}".format(sf.buffer)
ftags = ",".join(["f"] * (sf.length + 1))
temporary_variable(name,
shape=(product(mat.basis_size for mat in sf.matrix_sequence_quadrature_permuted), sf.vector_width),
shape=(sf.vector_width,) + shape_cost_permuted,
custom_base_storage=name_buffer_storage(sf.buffer, 0),
managed=True,
dim_tags=ftags,
)
assignee = prim.Subscript(prim.Variable(name),
(prim.Variable(basisiname),) + (index,))
(index,) + tuple(prim.Variable(i) for i in inames_cost_permuted))
insn = instruction(assignee=assignee,
expression=coeff,
depends_on=sf.insn_dep.union(insn_dep),
......@@ -134,6 +162,16 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
return insn_dep.union(frozenset({insn}))
def realize_direct(self, shape, inames, which=0):
# If the input comes directly from a global data structure inames are
# ordered x,y,z,...
#
# The inames and shape passed to this method come from the cost
# permuted matrix sequence so we need to permute them back
shape = permute_backward(shape, self.cost_permutation)
shape = permute_backward(shape, self.quadrature_permutation)
inames = permute_backward(inames, self.cost_permutation)
inames = permute_backward(inames, self.quadrature_permutation)
arg = "fastdg{}".format(which)
from dune.codegen.sumfact.accumulation import _dof_offset
......@@ -145,16 +183,23 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
return prim.Subscript(prim.Variable(arg), inames)
def realize_input(self, shape, inames, which=0):
if self.direct_is_possible:
shape = permute_backward(shape, self.cost_permutation)
shape = permute_backward(shape, self.quadrature_permutation)
inames = permute_backward(inames, self.cost_permutation)
inames = permute_backward(inames, self.quadrature_permutation)
def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
# Note: Here we do not need to reverse any permutation since this is
# already done in the setup_input method above!
return self.realize_direct(shape, inames)
else:
raise NotImplementedError("TODO")
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the matrix loop
# this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l),
shape=shape + vec_shape,
dim_tags=ftags,
)
# The input temporary will only be read from, so we need to silence
# the loopy warning
silenced_warning('read_no_write({})'.format(inp))
return prim.Subscript(prim.Variable(inp), inames + vec_iname)
@property
def function_name_suffix(self):
......
......@@ -15,6 +15,7 @@ from dune.codegen.generation import (backend,
globalarg,
valuearg,
)
from dune.codegen.loopy.flatten import flatten_index
from dune.codegen.options import get_option
from dune.codegen.pdelab.geometry import (local_dimension,
world_dimension,
......@@ -26,9 +27,10 @@ from dune.codegen.pdelab.localoperator import (name_ansatz_gfs_constructor_param
lop_template_range_field,
)
from dune.codegen.pdelab.restriction import restricted_name
from dune.codegen.sumfact.accumulation import basis_sf_kernels
from dune.codegen.sumfact.accumulation import basis_sf_kernels, sumfact_iname
from dune.codegen.sumfact.basis import construct_basis_matrix_sequence
from dune.codegen.sumfact.permutation import (permute_forward,
from dune.codegen.sumfact.permutation import (permute_backward,
permute_forward,
sumfact_cost_permutation_strategy,
sumfact_quadrature_permutation_strategy,
)
......@@ -113,15 +115,36 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
def direct_is_possible(self):
return False
def realize(self, sf, insn_dep, index=0):
# Note: world_dimension, since we only do evaluation of cell geometry mappings
def setup_input(self, sf, insn_dep, index=0):
# Inames for iterating over the coefficients (in this case the
# coordinate of the component 'self.direction' of the corners). We take
# them from the cost permuted matrix sequence. In order to get the
# inames in order x,y,... we need to undo the permutations.
shape_cost_permuted = tuple(mat.basis_size for mat in sf.matrix_sequence_cost_permuted)
shape_ordered = permute_backward(shape_cost_permuted, self.cost_permutation)
shape_ordered = permute_backward(shape_ordered, self.quadrature_permutation)
inames_cost_permuted = tuple(sumfact_iname(length, "corner_setup_inames_" + str(k)) for k, length in enumerate(shape_cost_permuted))
inames_ordered = permute_backward(inames_cost_permuted, self.cost_permutation)
inames_ordered = permute_backward(inames_ordered, self.quadrature_permutation)
# Flat indices needed to access pdelab corner method
flat_index_ordered = flatten_index(tuple(prim.Variable(i) for i in inames_ordered),
shape_ordered,
order="f")
flat_index_cost_permuted = flatten_index(tuple(prim.Variable(i) for i in inames_cost_permuted),
shape_cost_permuted,
order="f")
# The array that will be passed to the sum factorization kernel
# function should contain the coefficients in the cost permuted order!
name = "input_{}".format(sf.buffer)
ftags = ",".join(["f"] * (sf.length + 1))
temporary_variable(name,
shape=(2 ** world_dimension(), sf.vector_width),
shape=(sf.vector_width,) + shape_cost_permuted,
custom_base_storage=name_buffer_storage(sf.buffer, 0),
managed=True,
dim_tags=ftags,
)
ciname = global_corner_iname(self.restriction)
if self.restriction == 0:
geo = name_geometry()
......@@ -132,23 +155,38 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
# method does return a non-scalar, which does not fit into the current
# loopy philosophy for function calls. This problem will be solved once
# #11 is resolved. Admittedly, the code looks *really* ugly until that happens.
code = "{}[{}*{}+{}] = {}.corner({})[{}];".format(name,
sf.vector_width,
ciname,
index,
geo,
ciname,
self.direction,
)
code = "{}[{}*({})+{}] = {}.corner({})[{}];".format(name,
sf.vector_width,
str(flat_index_cost_permuted),
index,
geo,
str(flat_index_ordered),
self.direction,
)
insn = instruction(code=code,
within_inames=frozenset({ciname}),
within_inames=frozenset(inames_cost_permuted),
assignees=(name,),
tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
)
return insn_dep.union(frozenset({insn}))
def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
    """Return the subscript expression that reads the kernel input in step ``l``.

    No permutation is applied here: ``shape`` and ``inames`` are used as
    passed in. A temporary named ``buff_step<l>_in`` is requested from
    ``buffer`` and indexed with ``inames + vec_iname``.
    """
    # The temporary views the base storage of the input as a column-major
    # matrix. In later iterations of the matrix loop the same storage holds
    # the output of the previous iteration and is reinterpreted here.
    full_shape = shape + vec_shape
    input_name = buffer.get_temporary("buff_step{}_in".format(l),
                                      shape=full_shape,
                                      dim_tags=ftags)

    # Nothing writes to this temporary in the kernel, so silence the loopy
    # warning about a variable that is read but never written.
    warning_id = 'read_no_write({})'.format(input_name)
    silenced_warning(warning_id)

    return prim.Subscript(prim.Variable(input_name), inames + vec_iname)
@backend(interface="spatial_coordinate", name="default")
def pymbolic_spatial_coordinate_multilinear(do_predicates, visitor):
......
......@@ -75,8 +75,8 @@ def _realize_sum_factorization_kernel(sf):
for buf in buffers:
# Determine the necessary size of the buffer. We assume that we do not
# underintegrate the form!!!
size = max(product(m.quadrature_size for m in sf.matrix_sequence_quadrature_permuted) * sf.vector_width,
product(m.basis_size for m in sf.matrix_sequence_quadrature_permuted) * sf.vector_width)
size = max(product(m.quadrature_size for m in sf.matrix_sequence_cost_permuted) * sf.vector_width,
product(m.basis_size for m in sf.matrix_sequence_cost_permuted) * sf.vector_width)
temporary_variable("{}_dummy".format(buf),
shape=(size,),
custom_base_storage=buf,
......@@ -85,7 +85,7 @@ def _realize_sum_factorization_kernel(sf):
# Realize the input if it is not direct
if sf.stage == 1 and not sf.interface.direct_is_possible:
insn_dep = insn_dep.union(sf.interface.realize(sf, insn_dep))
insn_dep = insn_dep.union(sf.interface.setup_input(sf, insn_dep))
# Trigger generation of the sum factorization kernel function
qp = quadrature_points_per_direction()
......@@ -190,24 +190,20 @@ def realize_sumfact_kernel_function(sf):
# * a value from a global data structure, broadcasted to a vector type
# (vectorized + FastDGGridOperator)
input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
if l == 0 and sf.stage == 1 and sf.interface.direct_is_possible:
input_summand = sf.interface.realize_input(inp_shape, input_inames)
input_summand = sf.interface.realize_direct(inp_shape, input_inames)
elif l == 0:
# TODO: Simplify arguments!
input_summand = sf.interface.realize_input(inp_shape,
input_inames,
vec_shape,
vec_iname,
buffer,
ftags,
l,
)
else:
# If we did permute the order of a matrices above we also
# permuted the order of out_inames. Unfortunately the
# order of our input is from 0 to d-1. This means we need
# to permute _back_ to get the right coefficients.
if l == 0:
inp_shape = permute_backward(inp_shape, sf.cost_permutation)
input_inames = permute_backward(input_inames, sf.cost_permutation)
if sf.stage == 1:
# In the unstructured case the sf.matrix_sequence_quadrature_permuted could
# already be permuted according to
# sf.interface.quadrature_permutation. We also need to reverse this
# permutation to get the input from 0 to d-1.
inp_shape = permute_backward(inp_shape, sf.interface.quadrature_permutation)
input_inames = permute_backward(input_inames, sf.interface.quadrature_permutation)
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the matrix loop
# this reinterprets the output of the previous iteration.
......@@ -237,6 +233,7 @@ def realize_sumfact_kernel_function(sf):
# end and reverse permutation
output_inames = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),)
if l == len(matrix_sequence) - 1:
# TODO: Move permutations to interface!
output_inames = permute_backward(output_inames, sf.cost_permutation)
if sf.stage == 3:
output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation)
......@@ -262,6 +259,7 @@ def realize_sumfact_kernel_function(sf):
# If we are in the last step we reverse the permutation.
output_shape = tuple(out_shape[1:]) + (out_shape[0],)
if l == len(matrix_sequence) - 1:
# TODO: Move permutations to interface
output_shape = permute_backward(output_shape, sf.cost_permutation)
if sf.stage == 3:
output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation)
......
......@@ -2,6 +2,7 @@
from dune.codegen.options import get_form_option, get_option
from dune.codegen.generation import (get_counted_variable,
silenced_warning,
subst_rule,
transform,
)
......@@ -99,9 +100,10 @@ class SumfactKernelInterfaceBase(object):
class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
def __init__(self, interfaces):
def __init__(self, interfaces, perm):
assert(isinstance(interfaces, tuple))
self.interfaces = interfaces
self.vector_cost_permutation = perm
def __repr__(self):
return "_".join(repr(i) for i in self.interfaces)
......@@ -124,7 +126,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
for i in self.interfaces:
assert i.cost_permutation == cost_permutation
return cost_permutation
return vector_cost_permutation
@property
def stage(self):
......@@ -134,12 +136,14 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
def direct_is_possible(self):
return all(i.direct_is_possible for i in self.interfaces)
def realize(self, sf, dep):
def setup_input(self, sf, dep):
for i, inp in enumerate(self.interfaces):
dep = dep.union(inp.realize(sf, dep, index=i))
dep = dep.union(inp.setup_input(sf, dep, index=i))
return dep
def realize_direct(self, shape, inames):
# TODO: vector_cost_permutation not used!
# Check whether the input exhibits a favorable structure
# (whether we can broadcast scalar values into SIMD registers)
total = set(self.interfaces)
......@@ -166,16 +170,22 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
# need to load scalars into the SIMD vector.
raise NotImplementedError("SIMD loads from scalars not implemented!")
def realize_input(self, shape, inames):
if self.direct_is_possible:
shape = permute_backward(shape, self.cost_permutation)
shape = permute_backward(shape, self.quadrature_permutation)
inames = permute_backward(inames, self.cost_permutation)
inames = permute_backward(inames, self.quadrature_permutation)
def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
# TODO: vector_cost_permutation not used!
return self.realize_direct(shape, inames)
else:
raise NotImplementedError("TODO")
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the matrix loop
# this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l),
shape=shape + vec_shape,
dim_tags=ftags,
)
# The input temporary will only be read from, so we need to silence
# the loopy warning
silenced_warning('read_no_write({})'.format(inp))
return prim.Subscript(prim.Variable(inp), inames + vec_iname)
@property
def function_args(self):
......@@ -198,8 +208,9 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
def __init__(self, interfaces):
def __init__(self, interfaces, perm):
self.interfaces = interfaces
self.vector_cost_permutation = perm
def __repr__(self):
return "_".join(repr(o) for o in self.interfaces)
......@@ -231,6 +242,8 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
return prim.Call(prim.Variable(hadd_function), (result,))
def realize(self, sf, result, insn_dep):
# TODO: vector_cost_permutation not used!
outputs = set(self.interfaces)
trial_element, = set(o.trial_element for o in self.interfaces)
......@@ -250,6 +263,8 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
return deps
def realize_direct(self, result, inames, shape, **args):
# TODO: vector_cost_permutation not used!
outputs = set(self.interfaces)
# If multiple horizontal_add's are to be performed with 'result'
......@@ -268,6 +283,25 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
return deps
def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
    """Return the subscript expression that reads the kernel input in step ``l``.

    ``shape`` and ``inames`` are permuted back with
    ``self.vector_cost_permutation``; the result indexes a temporary named
    ``buff_step<l>_in`` obtained from ``buffer``, followed by ``vec_iname``.
    """
    # TODO: Include permutations of scalar kernels as soon as they could be different
    shape, inames = (
        permute_backward(shape, self.vector_cost_permutation),
        permute_backward(inames, self.vector_cost_permutation),
    )

    # The temporary views the base storage of the input as a column-major
    # matrix. In later iterations of the matrix loop the same storage holds
    # the output of the previous iteration and is reinterpreted here.
    temporary_kwargs = dict(shape=shape + vec_shape, dim_tags=ftags)
    inp = buffer.get_temporary("buff_step{}_in".format(l), **temporary_kwargs)

    # The input temporary is read-only in this kernel; silence the loopy
    # warning that would otherwise be raised for it.
    silenced_warning('read_no_write({})'.format(inp))

    subscript_index = inames + vec_iname
    return prim.Subscript(prim.Variable(inp), subscript_index)
@property
def function_args(self):
if get_form_option("fastdg"):
......@@ -759,6 +793,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property
def matrix_sequence_quadrature_permuted(self):
# TODO: This should be turned into a RuntimeError
return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence_quadrature_permuted[i] for k in self.kernels),
width=self.vector_width,
)
......@@ -819,10 +854,11 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property
def interface(self):
perm = self.cost_permutation
if self.stage == 1:
return VectorSumfactKernelInput(tuple(k.interface for k in self.kernels))
return VectorSumfactKernelInput(tuple(k.interface for k in self.kernels), perm)
else:
return VectorSumfactKernelOutput(tuple(k.interface for k in self.kernels))
return VectorSumfactKernelOutput(tuple(k.interface for k in self.kernels), perm)
@property
def cache_key(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment