Skip to content
Snippets Groups Projects
Commit 62c236a2 authored by René Heß
Browse files

[skip ci] Cleanup and documentation

parent 12cac856
No related branches found
No related tags found
No related merge requests found
......@@ -171,7 +171,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
from dune.codegen.sumfact.basis import lfs_inames
return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction)
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# TODO: This should happen in stage 2 and not in stage 3
shape = permute_backward(shape, self.cost_permutation)
inames = permute_backward(inames, self.cost_permutation)
......@@ -179,10 +179,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iterations of the matrix loop
# this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l),
shape=shape + vec_shape,
dim_tags=ftags,
)
inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape,
dim_tags=ftags,
)
# The input temporary will only be read from, so we need to silence
# the loopy warning
......
......@@ -175,17 +175,17 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
return insn_dep.union(frozenset({insn}))
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# Note: Here we do not need to reverse any permutation since this is
# already done in the setup_input method above!
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iterations of the matrix loop
# this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l),
shape=shape + vec_shape,
dim_tags=ftags,
)
inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape,
dim_tags=ftags,
)
# The input temporary will only be read from, so we need to silence
# the loopy warning
......
......@@ -189,14 +189,14 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
return insn_dep.union(frozenset({insn}))
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iterations of the matrix loop
# this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l),
shape=shape + vec_shape,
dim_tags=ftags,
)
inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape,
dim_tags=ftags,
)
# The input temporary will only be read from, so we need to silence
# the loopy warning
......
......@@ -145,6 +145,9 @@ def sumfact_quadrature_permutation_strategy(dim, restriction):
if restriction == Restriction.POSITIVE:
return _order_on_self(restriction)
else:
# Still do normal direction first. The other two directions need to
# be done in reverse order to go through the quadrature points in
# the same order as on self (draw cubes!).
assert restriction == Restriction.NEGATIVE
l = list(_order_on_self(restriction))
return (l[0], l[2], l[1])
......@@ -201,7 +201,6 @@ def realize_sumfact_kernel_function(sf):
vec_shape,
buffer,
ftags,
l,
)
else:
# Get a temporary that interprets the base storage of the input
......
......@@ -35,23 +35,43 @@ class SumfactKernelInterfaceBase(object):
In stage 1, this represents the input object, in stage 3 the output object.
"""
def setup_input(self, sf, insn_dep, index=0):
"""Create and fill an input array for sumfact kernel function (non fastdg)
"""Create and fill an input array for a stage 1 sumfact kernel function (non fastdg)
This happens before the function call.
This happens before the function call. The input will be quadrature
permuted (for unstructured grids) and cost permuted.
Parameters
----------
sf : SumfactKernel or VectorizedSumfactKernel
insn_dep : frozenset
Instructions this setup depends on.
index : int
Vectorization index, SIMD lane.
TODO: Add note about permutation
TODO: Document input arguments
"""
raise NotImplementedError
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
"""Interpret the input of sumfact kernel function in the right way (non fastdg)
This happens inside the sumfact kernel function.
TODO: Cleanup arguments
TODO: Add note about permutation
TODO: Document input arguments
Stage 1: Input is already permuted the right way in setup_input.
Stage 3: TODO -> Check permutation in accumulation.py
Parameters
----------
inames : tuple of pymbolic.primitives.Variable
shape : tuple of int
vec_iname : tuple of pymbolic.primitives.Variable
In case of vectorized kernel provide vectorization iname.
vec_shape : tuple of int
In case of vectorized kernel provide the number of vectorized kernels.
buf : dune.codegen.sumfact.realization.BufferSwitcher
Provides the input variable.
ftags : str
dim_tags needed to access input variable correctly.
"""
raise NotImplementedError
......@@ -66,7 +86,7 @@ class SumfactKernelInterfaceBase(object):
raise NotImplementedError
def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
"""Generate accumulate instruction after sumfact kernel function (non fastdg)
"""Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg)
This happens after the function call.
......@@ -80,8 +100,13 @@ class SumfactKernelInterfaceBase(object):
This happens inside the sumfact kernel function.
Stage 1: Reverse cost permutation, output should only be quadrature
permuted.
Stage 3: Reverse cost and quadrature permutation. The output will be
sorted according to dof/residual vector.
TODO: Cleanup arguments
TODO: Add note about permutation
TODO: Document input arguments
"""
inames = permute_backward(inames, self.cost_permutation)
......@@ -223,16 +248,16 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
dep = dep.union(inp.setup_input(sf, dep, index=i))
return dep
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# TODO: vector_cost_permutation not used!
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iterations of the matrix loop
# this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l),
shape=shape + vec_shape,
dim_tags=ftags,
)
inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape,
dim_tags=ftags,
)
# The input temporary will only be read from, so we need to silence
# the loopy warning
......@@ -332,7 +357,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
return prim.Call(prim.Variable(hadd_function), (result,))
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# The input for stage 3 is quadrature permuted. The inames and shape
# passed to this method are quadrature and cost permuted. This means we
# need to take the cost permutation back to get the right inames and
......@@ -342,10 +367,10 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
# Get a temporary that interprets the base storage of the input as a
# column-major matrix.
inp = buffer.get_temporary("buff_step{}_in".format(l),
shape=shape + vec_shape,
dim_tags=ftags,
)
inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape,
dim_tags=ftags,
)
# The input temporary will only be read from, so we need to silence
# the loopy warning
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment