Skip to content
Snippets Groups Projects
Commit 1afb64a5 authored by René Heß's avatar René Heß
Browse files

[skip ci] Cleanup

parent 4118f2f5
No related branches found
No related tags found
No related merge requests found
......@@ -258,8 +258,8 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
return frozenset({dep})
def realize_direct_output(self, result, inames, shape, which=0, permute=True, **args):
if permute:
def realize_direct_output(self, result, inames, shape, which=0, reverse_cost_permutation=True, **args):
if reverse_cost_permutation:
inames = permute_backward(inames, self.cost_permutation)
inames = permute_backward(inames, self.quadrature_permutation)
......
......@@ -101,7 +101,7 @@ class SumfactKernelInterfaceBase(object):
def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
"""Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg)
This happens after the function call. After stage 2 the result should
This happens after the function call. After stage 3 the result should
be ordered x, y, z,..., no permutations necessary.
Parameters
......@@ -118,13 +118,12 @@ class SumfactKernelInterfaceBase(object):
"""
raise NotImplementedError
def realize_direct_output(self, result, iname, shape, which=0, **kwargs):
def realize_direct_output(self, result, iname, shape, which=0, reverse_cost_permutation=True, **kwargs):
"""Accumulate results directly in the sumfact kernel function (fastdg)
This happens inside the sumfact kernel function.
TODO: Add note about permutation
TODO: Document input arguments
Parameters
----------
......@@ -134,6 +133,8 @@ class SumfactKernelInterfaceBase(object):
shape : tuple of ints
which : int
TODO Doc me!
reverse_cost_permutation : tuple of ints
TODO Doc me!
**kwargs :
Key word arguments passed to loopy instruction
"""
......@@ -398,7 +399,12 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
for o in outputs:
hadd_result = self._add_hadd(o, result)
which = tuple(remove_duplicates(self.interfaces)).index(o)
deps = deps.union(o.realize_direct_output(hadd_result, inames, shape, which=which, permute=False, **args))
deps = deps.union(o.realize_direct_output(hadd_result,
inames,
shape,
which=which,
reverse_cost_permutation=False,
**args))
return deps
......@@ -603,9 +609,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
# TODO: For now we do not vectorize SumfactKernels with different
# quadrature_permutation. This should be handled like upper/lower
# vectorization
return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,)
# return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames)
# return tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames)
return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible, self.interface.quadrature_permutation)
# return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,)
@property
def cache_key(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment