Skip to content
Snippets Groups Projects
Commit 43148454 authored by René Heß's avatar René Heß
Browse files

[skip ci] Rename sumfact interface methods

parent 5cd4058c
No related branches found
No related tags found
No related merge requests found
......@@ -154,7 +154,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
from dune.codegen.sumfact.basis import lfs_inames
return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction)
def realize(self, sf, result, insn_dep, inames=None, additional_inames=()):
def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
trial_leaf_element = get_leaf(self.trial_element, self.trial_element_index) if self.trial_element is not None else None
basis_size = tuple(mat.basis_size for mat in sf.matrix_sequence_quadrature_permuted)
......@@ -212,7 +212,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
return frozenset({dep})
def realize_direct(self, result, inames, shape, which=0, **args):
def realize_direct_output(self, result, inames, shape, which=0, **args):
direct_output = "fastdg{}".format(which)
ftags = ",".join(["f"] * len(shape))
......@@ -563,4 +563,4 @@ def generate_accumulation_instruction(expr, visitor):
result, insn_dep = realize_sum_factorization_kernel(vsf.copy(insn_dep=vsf.insn_dep.union(insn_dep)))
if not get_form_option("fastdg"):
insn_dep = vsf.interface.realize(vsf, result, insn_dep)
insn_dep = vsf.interface.setup_output(vsf, result, insn_dep)
......@@ -161,7 +161,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
return insn_dep.union(frozenset({insn}))
def realize_direct(self, shape, inames, which=0):
def realize_direct_input(self, shape, inames, which=0):
# If the input comes directly from a global data structure inames are
# ordered x,y,z,...
#
......
......@@ -192,7 +192,7 @@ def realize_sumfact_kernel_function(sf):
input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
if l == 0 and sf.stage == 1 and sf.interface.direct_is_possible:
input_summand = sf.interface.realize_direct(inp_shape, input_inames)
input_summand = sf.interface.realize_direct_input(inp_shape, input_inames)
elif l == 0:
# TODO: Simplify arguments!
input_summand = sf.interface.realize_input(inp_shape,
......@@ -244,7 +244,7 @@ def realize_sumfact_kernel_function(sf):
if sf.vectorized:
insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name})
insn_dep = sf.interface.realize_direct(matprod, output_inames, out_shape, **insn_args)
insn_dep = sf.interface.realize_direct_output(matprod, output_inames, out_shape, **insn_args)
elif l == len(matrix_sequence) - 1:
# TODO: Move permutations to interface!
output_inames = permute_backward(output_inames, sf.cost_permutation)
......
......@@ -33,18 +33,70 @@ class SumfactKernelInterfaceBase(object):
""" A base class for the input/output of a sum factorization kernel
In stage 1, this represents the input object, in stage 3 the output object.
"""
def realize(self, *a, **kw):
def setup_input(self, sf, insn_dep, index=0):
"""Create and fill an input array for sumfact kernel function (non fastdg)
This happens before the function call.
TODO: Add note about permutation
TODO: Document input arguments
"""
raise NotImplementedError
def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
"""Interpret the input of sumfact kernel function in the right way (non fastdgg)
This happens inside the sumfact kernel function.
TODO: Cleanup input
TODO: Add note about permutation
TODO: Document input arguments
"""
raise NotImplementedError
def realize_direct_input(self, shape, inames, which=0):
"""Interpret the input of sumfact kernel function in the right way (fastdg)
This happens inside the sumfact kernel function.
TODO: Add note about permutation
TODO: Document input arguments
"""
raise NotImplementedError
def realize_direct(self, *a, **kw):
def realize_direct_output(self, result, iname, shape, which=0, **args):
"""Accumulate results directly in the sumfact kernel function (fastdg)
This happens inside the sumfact kernel function.
TODO: Add note about permutation
TODO: Document input arguments
"""
def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
"""Generate accumulate instruction after sumfact kernel function (non fastdg)
This happens after the function call.
TODO: Add note about permutation
TODO: Document input arguments
"""
raise NotImplementedError
@property
def quadrature_permutation(self):
"""Order of local coordinate axis
On unstructured grids we sometimes need to go through the directions in
different order to make sure that we visit the (global) quadrature
points on self and neighbor in the same order.
"""
raise NotImplementedError
@property
def cost_permutation(self):
"""Permutation that minimizes flops
"""
raise NotImplementedError
@property
......@@ -141,7 +193,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
dep = dep.union(inp.setup_input(sf, dep, index=i))
return dep
def realize_direct(self, shape, inames):
def realize_direct_input(self, shape, inames):
# TODO: vector_cost_permutation not used!
# Check whether the input exhibits a favorable structure
......@@ -154,15 +206,15 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
# All input coefficients use the exact same input coefficient.
# We implement this by broadcasting it into a SIMD register
return prim.Call(ExplicitVCLCast(dtype_floatingpoint()),
(self.interfaces[0].realize_direct(shape, inames),)
(self.interfaces[0].realize_direct_input(shape, inames),)
)
elif len(total) == 2 and len(lower) == 1 and len(upper) == 1:
# The lower and the upper part of the SIMD register use
# the same input coefficient, we combine the SIMD register
# from two shorter SIMD types
return prim.Call(VCLLowerUpperLoad(dtype_floatingpoint()),
(self.interfaces[0].realize_direct(shape, inames),
self.interfaces[len(self.interfaces) // 2].realize_direct(shape, inames, which=1),
(self.interfaces[0].realize_direct_input(shape, inames),
self.interfaces[len(self.interfaces) // 2].realize_direct_input(shape, inames, which=1),
)
)
else:
......@@ -241,7 +293,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
return prim.Call(prim.Variable(hadd_function), (result,))
def realize(self, sf, result, insn_dep):
def setup_output(self, sf, result, insn_dep):
# TODO: vector_cost_permutation not used!
outputs = set(self.interfaces)
......@@ -258,11 +310,11 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
deps = frozenset()
for o in outputs:
hadd_result = self._add_hadd(o, maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,))))
deps = deps.union(o.realize(sf, hadd_result, insn_dep, inames=inames, additional_inames=(veciname,)))
deps = deps.union(o.setup_output(sf, hadd_result, insn_dep, inames=inames, additional_inames=(veciname,)))
return deps
def realize_direct(self, result, inames, shape, **args):
def realize_direct_output(self, result, inames, shape, **args):
# TODO: vector_cost_permutation not used!
outputs = set(self.interfaces)
......@@ -279,7 +331,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
for o in outputs:
hadd_result = self._add_hadd(o, result)
which = tuple(remove_duplicates(self.interfaces)).index(o)
deps = deps.union(o.realize_direct(hadd_result, inames, shape, which=which, **args))
deps = deps.union(o.realize_direct_output(hadd_result, inames, shape, which=which, **args))
return deps
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment