From 1afb64a55ca28fcba8420c9fe11b349f4869c1e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Mon, 28 Jan 2019 11:32:31 +0100 Subject: [PATCH] [skip ci] Cleanup --- python/dune/codegen/sumfact/accumulation.py | 4 ++-- python/dune/codegen/sumfact/symbolic.py | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py index 3264a62c..e360c847 100644 --- a/python/dune/codegen/sumfact/accumulation.py +++ b/python/dune/codegen/sumfact/accumulation.py @@ -258,8 +258,8 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): return frozenset({dep}) - def realize_direct_output(self, result, inames, shape, which=0, permute=True, **args): - if permute: + def realize_direct_output(self, result, inames, shape, which=0, reverse_cost_permutation=True, **args): + if reverse_cost_permutation: inames = permute_backward(inames, self.cost_permutation) inames = permute_backward(inames, self.quadrature_permutation) diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 3f4e91f4..2b91aafa 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -101,7 +101,7 @@ class SumfactKernelInterfaceBase(object): def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()): """Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg) - This happens after the function call. After stage 2 the result should + This happens after the function call. After stage 3 the result should be ordered x, y, z,..., no permutations necessary. Parameters @@ -118,13 +118,12 @@ class SumfactKernelInterfaceBase(object): """ raise NotImplementedError - def realize_direct_output(self, result, iname, shape, which=0, **kwargs): + def realize_direct_output(self, result, iname, shape, which=0, reverse_cost_permutation=True, **kwargs): """Accumulate results directly in the sumfact kernel function (fastdg) This happens inside the sumfact kernel function. TODO: Add note about permutation - TODO: Document input arguments Parameters ---------- @@ -134,6 +133,8 @@ class SumfactKernelInterfaceBase(object): shape : tuple of ints which : int TODO Doc me! + reverse_cost_permutation : tuple of ints + TODO Doc me! **kwargs : Key word arguments passed to loopy instruction """ @@ -398,7 +399,12 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): for o in outputs: hadd_result = self._add_hadd(o, result) which = tuple(remove_duplicates(self.interfaces)).index(o) - deps = deps.union(o.realize_direct_output(hadd_result, inames, shape, which=which, permute=False, **args)) + deps = deps.union(o.realize_direct_output(hadd_result, + inames, + shape, + which=which, + reverse_cost_permutation=False, + **args)) return deps @@ -603,9 +609,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): # TODO: For now we do not vectorize SumfactKernels with different # quadrature_permutation. This should be handled like upper/lower # vectorization - return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,) - # return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) - # return tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible, self.interface.quadrature_permutation) + # return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,) @property def cache_key(self): -- GitLab