From 1afb64a55ca28fcba8420c9fe11b349f4869c1e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Mon, 28 Jan 2019 11:32:31 +0100
Subject: [PATCH] [skip ci] Cleanup

---
 python/dune/codegen/sumfact/accumulation.py |  4 ++--
 python/dune/codegen/sumfact/symbolic.py     | 19 ++++++++++++-------
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py
index 3264a62c..e360c847 100644
--- a/python/dune/codegen/sumfact/accumulation.py
+++ b/python/dune/codegen/sumfact/accumulation.py
@@ -258,8 +258,8 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
 
         return frozenset({dep})
 
-    def realize_direct_output(self, result, inames, shape, which=0, permute=True, **args):
-        if permute:
+    def realize_direct_output(self, result, inames, shape, which=0, reverse_cost_permutation=True, **args):
+        if reverse_cost_permutation:
             inames = permute_backward(inames, self.cost_permutation)
         inames = permute_backward(inames, self.quadrature_permutation)
 
diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py
index 3f4e91f4..2b91aafa 100644
--- a/python/dune/codegen/sumfact/symbolic.py
+++ b/python/dune/codegen/sumfact/symbolic.py
@@ -101,7 +101,7 @@ class SumfactKernelInterfaceBase(object):
     def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
         """Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg)
 
-        This happens after the function call. After stage 2 the result should
+        This happens after the function call. After stage 3 the result should
         be ordered x, y, z,..., no permutations necessary.
 
         Parameters
@@ -118,13 +118,12 @@ class SumfactKernelInterfaceBase(object):
         """
         raise NotImplementedError
 
-    def realize_direct_output(self, result, iname, shape, which=0, **kwargs):
+    def realize_direct_output(self, result, iname, shape, which=0, reverse_cost_permutation=True, **kwargs):
         """Accumulate results directly in the sumfact kernel function (fastdg)
 
         This happens inside the sumfact kernel function.
 
         TODO: Add note about permutation
-        TODO: Document input arguments
 
         Parameters
         ----------
@@ -134,6 +133,8 @@ class SumfactKernelInterfaceBase(object):
         shape : tuple of ints
         which : int
             TODO Doc me!
+        reverse_cost_permutation : tuple of ints
+            TODO Doc me!
         **kwargs :
             Key word arguments passed to loopy instruction
         """
@@ -398,7 +399,12 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
         for o in outputs:
             hadd_result = self._add_hadd(o, result)
             which = tuple(remove_duplicates(self.interfaces)).index(o)
-            deps = deps.union(o.realize_direct_output(hadd_result, inames, shape, which=which, permute=False, **args))
+            deps = deps.union(o.realize_direct_output(hadd_result,
+                                                      inames,
+                                                      shape,
+                                                      which=which,
+                                                      reverse_cost_permutation=False,
+                                                      **args))
 
         return deps
 
@@ -603,9 +609,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         # TODO: For now we do not vectorize SumfactKernels with different
         # quadrature_permutation. This should be handled like upper/lower
         # vectorization
-        return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,)
-        # return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames)
-        # return tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames)
+        return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible, self.interface.quadrature_permutation)
+        # return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,)
 
     @property
     def cache_key(self):
-- 
GitLab