Skip to content
Snippets Groups Projects
Commit 1afb64a5 authored by René Heß's avatar René Heß
Browse files

[skip ci] Cleanup

parent 4118f2f5
No related branches found
No related tags found
No related merge requests found
...@@ -258,8 +258,8 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -258,8 +258,8 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
return frozenset({dep}) return frozenset({dep})
def realize_direct_output(self, result, inames, shape, which=0, permute=True, **args): def realize_direct_output(self, result, inames, shape, which=0, reverse_cost_permutation=True, **args):
if permute: if reverse_cost_permutation:
inames = permute_backward(inames, self.cost_permutation) inames = permute_backward(inames, self.cost_permutation)
inames = permute_backward(inames, self.quadrature_permutation) inames = permute_backward(inames, self.quadrature_permutation)
......
...@@ -101,7 +101,7 @@ class SumfactKernelInterfaceBase(object): ...@@ -101,7 +101,7 @@ class SumfactKernelInterfaceBase(object):
def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()): def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
"""Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg) """Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg)
This happens after the function call. After stage 2 the result should This happens after the function call. After stage 3 the result should
be ordered x, y, z,..., no permutations necessary. be ordered x, y, z,..., no permutations necessary.
Parameters Parameters
...@@ -118,13 +118,12 @@ class SumfactKernelInterfaceBase(object): ...@@ -118,13 +118,12 @@ class SumfactKernelInterfaceBase(object):
""" """
raise NotImplementedError raise NotImplementedError
def realize_direct_output(self, result, iname, shape, which=0, **kwargs): def realize_direct_output(self, result, iname, shape, which=0, reverse_cost_permutation=True, **kwargs):
"""Accumulate results directly in the sumfact kernel function (fastdg) """Accumulate results directly in the sumfact kernel function (fastdg)
This happens inside the sumfact kernel function. This happens inside the sumfact kernel function.
TODO: Add note about permutation TODO: Add note about permutation
TODO: Document input arguments
Parameters Parameters
---------- ----------
...@@ -134,6 +133,8 @@ class SumfactKernelInterfaceBase(object): ...@@ -134,6 +133,8 @@ class SumfactKernelInterfaceBase(object):
shape : tuple of ints shape : tuple of ints
which : int which : int
TODO Doc me! TODO Doc me!
reverse_cost_permutation : tuple of ints
TODO Doc me!
**kwargs : **kwargs :
Key word arguments passed to loopy instruction Key word arguments passed to loopy instruction
""" """
...@@ -398,7 +399,12 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): ...@@ -398,7 +399,12 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
for o in outputs: for o in outputs:
hadd_result = self._add_hadd(o, result) hadd_result = self._add_hadd(o, result)
which = tuple(remove_duplicates(self.interfaces)).index(o) which = tuple(remove_duplicates(self.interfaces)).index(o)
deps = deps.union(o.realize_direct_output(hadd_result, inames, shape, which=which, permute=False, **args)) deps = deps.union(o.realize_direct_output(hadd_result,
inames,
shape,
which=which,
reverse_cost_permutation=False,
**args))
return deps return deps
...@@ -603,9 +609,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -603,9 +609,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
# TODO: For now we do not vectorize SumfactKernels with different # TODO: For now we do not vectorize SumfactKernels with different
# quadrature_permutation. This should be handled like upper/lower # quadrature_permutation. This should be handled like upper/lower
# vectorization # vectorization
return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,) return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible, self.interface.quadrature_permutation)
# return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) # return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,)
# return tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames)
@property @property
def cache_key(self): def cache_key(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment