Skip to content
Snippets Groups Projects
Commit 62c236a2 authored by René Heß's avatar René Heß
Browse files

[skip ci] Cleanup and documentation

parent 12cac856
No related branches found
No related tags found
No related merge requests found
...@@ -171,7 +171,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -171,7 +171,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
from dune.codegen.sumfact.basis import lfs_inames from dune.codegen.sumfact.basis import lfs_inames
return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction) return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction)
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# TODO: This should happen in stage 2 and not in stage 3 # TODO: This should happen in stage 2 and not in stage 3
shape = permute_backward(shape, self.cost_permutation) shape = permute_backward(shape, self.cost_permutation)
inames = permute_backward(inames, self.cost_permutation) inames = permute_backward(inames, self.cost_permutation)
...@@ -179,10 +179,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -179,10 +179,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
# Get a temporary that interprets the base storage of the input # Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the matrix loop # as a column-major matrix. In later iteration of the matrix loop
# this reinterprets the output of the previous iteration. # this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l), inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape, shape=shape + vec_shape,
dim_tags=ftags, dim_tags=ftags,
) )
# The input temporary will only be read from, so we need to silence # The input temporary will only be read from, so we need to silence
# the loopy warning # the loopy warning
......
...@@ -175,17 +175,17 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -175,17 +175,17 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
return insn_dep.union(frozenset({insn})) return insn_dep.union(frozenset({insn}))
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# Note: Here we do not need to reverse any permutation since this is # Note: Here we do not need to reverse any permutation since this is
# already done in the setup_input method above! # already done in the setup_input method above!
# Get a temporary that interprets the base storage of the input # Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the matrix loop # as a column-major matrix. In later iteration of the matrix loop
# this reinterprets the output of the previous iteration. # this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l), inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape, shape=shape + vec_shape,
dim_tags=ftags, dim_tags=ftags,
) )
# The input temporary will only be read from, so we need to silence # The input temporary will only be read from, so we need to silence
# the loopy warning # the loopy warning
......
...@@ -189,14 +189,14 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -189,14 +189,14 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
return insn_dep.union(frozenset({insn})) return insn_dep.union(frozenset({insn}))
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# Get a temporary that interprets the base storage of the input # Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the matrix loop # as a column-major matrix. In later iteration of the matrix loop
# this reinterprets the output of the previous iteration. # this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l), inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape, shape=shape + vec_shape,
dim_tags=ftags, dim_tags=ftags,
) )
# The input temporary will only be read from, so we need to silence # The input temporary will only be read from, so we need to silence
# the loopy warning # the loopy warning
......
...@@ -145,6 +145,9 @@ def sumfact_quadrature_permutation_strategy(dim, restriction): ...@@ -145,6 +145,9 @@ def sumfact_quadrature_permutation_strategy(dim, restriction):
if restriction == Restriction.POSITIVE: if restriction == Restriction.POSITIVE:
return _order_on_self(restriction) return _order_on_self(restriction)
else: else:
# Still do normal direction first. The other two directions need to
# be done in reverse order to go through the quadrature points in
# the same order as on self (draw cubes!).
assert restriction == Restriction.NEGATIVE assert restriction == Restriction.NEGATIVE
l = list(_order_on_self(restriction)) l = list(_order_on_self(restriction))
return (l[0], l[2], l[1]) return (l[0], l[2], l[1])
...@@ -201,7 +201,6 @@ def realize_sumfact_kernel_function(sf): ...@@ -201,7 +201,6 @@ def realize_sumfact_kernel_function(sf):
vec_shape, vec_shape,
buffer, buffer,
ftags, ftags,
l,
) )
else: else:
# Get a temporary that interprets the base storage of the input # Get a temporary that interprets the base storage of the input
......
...@@ -35,23 +35,43 @@ class SumfactKernelInterfaceBase(object): ...@@ -35,23 +35,43 @@ class SumfactKernelInterfaceBase(object):
In stage 1, this represents the input object, in stage 3 the output object. In stage 1, this represents the input object, in stage 3 the output object.
""" """
def setup_input(self, sf, insn_dep, index=0): def setup_input(self, sf, insn_dep, index=0):
"""Create and fill an input array for sumfact kernel function (non fastdg) """Create and fill an input array for a stage 1 sumfact kernel function (non fastdg)
This happens before the function call. This happens before the function call. The input will be quadrature
(for unstructured grids) and cost permuted.
Parameters
----------
sf : SumfactKernel or VectorizedSumfactKernel
insn_dep : frozenset
Instructions this setup depends on.
index : int
Vectorization index, SIMD lane.
TODO: Add note about permutation
TODO: Document input arguments
""" """
raise NotImplementedError raise NotImplementedError
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
"""Interpret the input of sumfact kernel function in the right way (non fastdg) """Interpret the input of sumfact kernel function in the right way (non fastdg)
This happens inside the sumfact kernel function. This happens inside the sumfact kernel function.
TODO: Cleanup arguments Stage 1: Input is already permuted the right way in setup_input.
TODO: Add note about permutation
TODO: Document input arguments Stage 3: TODO -> Check permutation in accumulation.py
Parameters
----------
inames : tuple of pymbolic.primitives.Variable
shape : tuple of int
vec_iname : tuple of pymbolic.primitives.Variable
In case of vectorized kernel provide vectorization iname.
vec_shape : tuple of int
In case of vectorized kernel provide the number of vectorized kernels.
buf : dune.codegen.sumfact.realization.BufferSwitcher
Provides the input variable.
ftags : str
dim_tags needed to access input variable correctly.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -66,7 +86,7 @@ class SumfactKernelInterfaceBase(object): ...@@ -66,7 +86,7 @@ class SumfactKernelInterfaceBase(object):
raise NotImplementedError raise NotImplementedError
def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()): def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
"""Generate accumulate instruction after sumfact kernel function (non fastdg) """Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg)
This happens after the function call. This happens after the function call.
...@@ -80,8 +100,13 @@ class SumfactKernelInterfaceBase(object): ...@@ -80,8 +100,13 @@ class SumfactKernelInterfaceBase(object):
This happens inside the sumfact kernel function. This happens inside the sumfact kernel function.
Stage 1: Reverse cost permutation, output should only be quadrature
permuted.
Stage 3: Reverse cost and quadrature permutation. The output will be
sorted according to dof/residual vector.
TODO: Cleanup arguments TODO: Cleanup arguments
TODO: Add note about permutation
TODO: Document input arguments TODO: Document input arguments
""" """
inames = permute_backward(inames, self.cost_permutation) inames = permute_backward(inames, self.cost_permutation)
...@@ -223,16 +248,16 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): ...@@ -223,16 +248,16 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
dep = dep.union(inp.setup_input(sf, dep, index=i)) dep = dep.union(inp.setup_input(sf, dep, index=i))
return dep return dep
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# TODO: vector_cost_permutation not used! # TODO: vector_cost_permutation not used!
# Get a temporary that interprets the base storage of the input # Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the matrix loop # as a column-major matrix. In later iteration of the matrix loop
# this reinterprets the output of the previous iteration. # this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l), inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape, shape=shape + vec_shape,
dim_tags=ftags, dim_tags=ftags,
) )
# The input temporary will only be read from, so we need to silence # The input temporary will only be read from, so we need to silence
# the loopy warning # the loopy warning
...@@ -332,7 +357,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): ...@@ -332,7 +357,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
return prim.Call(prim.Variable(hadd_function), (result,)) return prim.Call(prim.Variable(hadd_function), (result,))
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# The input for stage 3 is quadrature permuted. The inames and shape # The input for stage 3 is quadrature permuted. The inames and shape
# passed to this method are quadrature and cost permuted. This means we # passed to this method are quadrature and cost permuted. This means we
# need to take the cost permutation back to get the right inames and # need to take the cost permutation back to get the right inames and
...@@ -342,10 +367,10 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): ...@@ -342,10 +367,10 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
# Get a temporary that interprets the base storage of the input as a # Get a temporary that interprets the base storage of the input as a
# column-major matrix. # column-major matrix.
inp = buffer.get_temporary("buff_step{}_in".format(l), inp = buf.get_temporary("buff_step0_in",
shape=shape + vec_shape, shape=shape + vec_shape,
dim_tags=ftags, dim_tags=ftags,
) )
# The input temporary will only be read from, so we need to silence # The input temporary will only be read from, so we need to silence
# the loopy warning # the loopy warning
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment