diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py index 3264a62c96844b372b7366b4260ef985f9f76904..e360c84775a05fdbd30591dc9f8d7c4b9b26d068 100644 --- a/python/dune/codegen/sumfact/accumulation.py +++ b/python/dune/codegen/sumfact/accumulation.py @@ -258,8 +258,8 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): return frozenset({dep}) - def realize_direct_output(self, result, inames, shape, which=0, permute=True, **args): - if permute: + def realize_direct_output(self, result, inames, shape, which=0, reverse_cost_permutation=True, **args): + if reverse_cost_permutation: inames = permute_backward(inames, self.cost_permutation) inames = permute_backward(inames, self.quadrature_permutation) diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 3f4e91f4ce194a1d354dc98a09e4b56afffff270..2b91aafa7ea2597e5b547d99671be3ac85949a81 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -101,7 +101,7 @@ class SumfactKernelInterfaceBase(object): def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()): """Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg) - This happens after the function call. After stage 2 the result should + This happens after the function call. After stage 3 the result should be ordered x, y, z,..., no permutations necessary. Parameters @@ -118,13 +118,12 @@ class SumfactKernelInterfaceBase(object): """ raise NotImplementedError - def realize_direct_output(self, result, iname, shape, which=0, **kwargs): + def realize_direct_output(self, result, iname, shape, which=0, reverse_cost_permutation=True, **kwargs): """Accumulate results directly in the sumfact kernel function (fastdg) This happens inside the sumfact kernel function. TODO: Add note about permutation - TODO: Document input arguments Parameters ---------- @@ -134,6 +133,8 @@ class SumfactKernelInterfaceBase(object): shape : tuple of ints which : int TODO Doc me! + reverse_cost_permutation : tuple of ints + TODO Doc me! **kwargs : Key word arguments passed to loopy instruction """ @@ -398,7 +399,12 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): for o in outputs: hadd_result = self._add_hadd(o, result) which = tuple(remove_duplicates(self.interfaces)).index(o) - deps = deps.union(o.realize_direct_output(hadd_result, inames, shape, which=which, permute=False, **args)) + deps = deps.union(o.realize_direct_output(hadd_result, + inames, + shape, + which=which, + reverse_cost_permutation=False, + **args)) return deps @@ -603,9 +609,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): # TODO: For now we do not vectorize SumfactKernels with different # quadrature_permutation. This should be handled like upper/lower # vectorization - return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,) - # return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) - # return tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible, self.interface.quadrature_permutation) + # return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,) @property def cache_key(self):