Skip to content
Snippets Groups Projects
Commit 4118f2f5 authored by René Heß's avatar René Heß
Browse files

Add documentation and remove realize_output interface methods

The realize_output interface method was not really necessary and just added a
layer of indirection.
parent 88a7c1e1
No related branches found
No related tags found
No related merge requests found
......@@ -193,7 +193,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
return prim.Subscript(prim.Variable(inp), inames + vec_iname)
def realize_direct_input(self, shape, inames, which=0):
def realize_direct_input(self, inames, shape, which=0):
# If the input comes directly from a global data structure inames are
# ordered x,y,z,...
#
......
......@@ -192,7 +192,7 @@ def realize_sumfact_kernel_function(sf):
input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
if l == 0 and sf.stage == 1 and sf.interface.direct_is_possible:
input_summand = sf.interface.realize_direct_input(inp_shape, input_inames)
input_summand = sf.interface.realize_direct_input(input_inames, inp_shape)
elif l == 0:
# TODO: Simplify arguments!
input_summand = sf.interface.realize_input(input_inames,
......@@ -241,17 +241,30 @@ def realize_sumfact_kernel_function(sf):
insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name})
insn_dep = sf.interface.realize_direct_output(matprod, output_inames, out_shape, **insn_args)
elif l == len(matrix_sequence) - 1:
# Handle output of the last tensor contraction
#
# Stage 1: Reverse cost permutation, keep quadrature permutation
# Stage 3: Reverse cost and quadrature permutation
output_shape = tuple(out_shape[1:]) + (out_shape[0],)
insn_dep = sf.interface.realize_output(matprod,
output_inames,
output_shape,
vec_iname,
vec_shape,
buffer,
ftags,
l,
**insn_args,
)
output_inames = permute_backward(output_inames, sf.interface.cost_permutation)
output_shape = permute_backward(output_shape, sf.interface.cost_permutation)
if sf.interface.stage == 3:
output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation)
output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation)
out = buffer.get_temporary("buff_step{}_out".format(l),
shape=output_shape + vec_shape,
dim_tags=ftags,
)
# Issue the reduction instruction that implements the multiplication
# at the same time store the instruction ID for the next instruction to depend on
insn_dep = frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), output_inames + vec_iname),
expression=matprod,
**insn_args,
)
})
else:
output_shape = tuple(out_shape[1:]) + (out_shape[0],)
out = buffer.get_temporary("buff_step{}_out".format(l),
......
......@@ -47,7 +47,6 @@ class SumfactKernelInterfaceBase(object):
Instructions this setup depends on.
index : int
Vectorization index, SIMD lane.
"""
raise NotImplementedError
......@@ -56,14 +55,16 @@ class SumfactKernelInterfaceBase(object):
This happens inside the sumfact kernel function.
Stage 1: Input is already permuted the right way in setup_input.
Stage 1 : Input is already permuted the right way in setup_input.
Stage 3: TODO -> Check permutation in accumulation.py
Stage 3 : TODO -> Check permutation in accumulation.py
Parameters
----------
inames : tuple of pymbolic.primitives.Variable
Inames for accessing the input. Ordered according to permuted matrix sequence.
shape : tuple of int
Shape of input. Ordered according to permuted matrix sequence.
vec_iname : tuple of pymbolic.primitives.Variable
In case of vectorized kernel provide vectorization iname.
vec_shape : tuple of int
......@@ -75,66 +76,66 @@ class SumfactKernelInterfaceBase(object):
"""
raise NotImplementedError
def realize_direct_input(self, shape, inames, which=0):
def realize_direct_input(self, inames, shape, which=0):
"""Interpret the input of sumfact kernel function in the right way (fastdg)
This happens inside the sumfact kernel function.
TODO: Add note about permutation
TODO: Document input arguments
Stage 1: The input to the sum factorization kernel will be ordered x,
y, z,... The shape and inames from this method come from the cost
permuted matrix sequence. Make sure to permute them back when accessing
the input.
Parameters
----------
inames : tuple of pymbolic.primitives.Variable
Inames for accessing the input. Ordered according to permuted matrix sequence.
shape : tuple of int
Shape of input. Ordered according to permuted matrix sequence.
which : int
In case of VectorizedSumfactKernel this might specify if the lower or upper
part of the SIMD register is for this input.
"""
raise NotImplementedError
def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
"""Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg)
This happens after the function call.
This happens after the function call. After stage 2 the result should
be ordered x, y, z,..., no permutations necessary.
TODO: Add note about permutation
TODO: Document input arguments
Parameters
----------
sf : SumfactKernel or VectorizedSumfactKernel
result : SumfactKernel or some pymbolic stuff
Result of a sum factorization
insn_dep : frozenset
Instructions this setup depends on.
inames : tuple of pymbolic.primitives.Variable
additional_inames : tuple of pymbolic.primitives.Variable
Additional inames the accumulation instruction depends on (eg. loop over
ansatz functions for jacobians).
"""
raise NotImplementedError
def realize_output(self, result, inames, shape, vec_iname, vec_shape, buffer, ftags, l, **args):
"""Handle the output of the last tensor contraction in the sumfact kernel function the right way
This happens inside the sumfact kernel function.
Stage 1: Reverse cost permutation, output should only be quadrature
permuted.
Stage 3: Reverse cost and quadrature permutation. The output will be
sorted according to dof/residual vector.
TODO: Cleanup arguments
TODO: Document input arguments
"""
inames = permute_backward(inames, self.cost_permutation)
shape = permute_backward(shape, self.cost_permutation)
if self.stage == 3:
inames = permute_backward(inames, self.quadrature_permutation)
shape = permute_backward(shape, self.quadrature_permutation)
out = buffer.get_temporary("buff_step{}_out".format(l),
shape=shape + vec_shape,
dim_tags=ftags,
)
# Issue the reduction instruction that implements the multiplication
# at the same time store the instruction ID for the next instruction to depend on
return frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), inames + vec_iname),
expression=result,
**args
)
})
def realize_direct_output(self, result, iname, shape, which=0, **args):
def realize_direct_output(self, result, iname, shape, which=0, **kwargs):
"""Accumulate results directly in the sumfact kernel function (fastdg)
This happens inside the sumfact kernel function.
TODO: Add note about permutation
TODO: Document input arguments
Parameters
----------
result : pymbolic stuff
Result of the sum factorization
iname : tuple of pymbolic.primitives.Variable
shape : tuple of ints
which : int
TODO Doc me!
**kwargs :
Keyword arguments passed to the loopy instruction
"""
raise NotImplementedError
......@@ -265,7 +266,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
return prim.Subscript(prim.Variable(inp), inames + vec_iname)
def realize_direct_input(self, shape, inames):
def realize_direct_input(self, inames, shape):
# TODO: vector_cost_permutation not used!
# Check whether the input exhibits a favorable structure
......@@ -278,15 +279,15 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
# All input coefficients use the exact same input coefficient.
# We implement this by broadcasting it into a SIMD register
return prim.Call(ExplicitVCLCast(dtype_floatingpoint()),
(self.interfaces[0].realize_direct_input(shape, inames),)
(self.interfaces[0].realize_direct_input(inames, shape),)
)
elif len(total) == 2 and len(lower) == 1 and len(upper) == 1:
# The lower and the upper part of the SIMD register use
# the same input coefficient, we combine the SIMD register
# from two shorter SIMD types
return prim.Call(VCLLowerUpperLoad(dtype_floatingpoint()),
(self.interfaces[0].realize_direct_input(shape, inames),
self.interfaces[len(self.interfaces) // 2].realize_direct_input(shape, inames, which=1),
(self.interfaces[0].realize_direct_input(inames, shape),
self.interfaces[len(self.interfaces) // 2].realize_direct_input(inames, shape, which=1),
)
)
else:
......@@ -1025,9 +1026,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
# that derivatives are the same.
from copy import deepcopy
sf_interface = deepcopy(sf.interface)
sf_interface._cost_permutation=None
sf_interface._cost_permutation = None
k_interface = deepcopy(k.interface)
k_interface._cost_permutation=None
k_interface._cost_permutation = None
if repr(sf_interface) == repr(k_interface):
if tuple(mat.derivative for mat in sf.matrix_sequence_quadrature_permuted) == tuple(mat.derivative for mat in k.matrix_sequence_quadrature_permuted):
return i
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment