From 4b2db15f2cd52ad2cd3e11f48926375e31153ca5 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 31 Mar 2017 14:38:02 +0200
Subject: [PATCH] Refactor output shape

---
 python/dune/perftool/loopy/symbolic.py      | 23 +++++++++++++++++++++
 python/dune/perftool/sumfact/basis.py       |  8 ++-----
 python/dune/perftool/sumfact/realization.py | 20 +++---------------
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py
index 345819e5..402c8940 100644
--- a/python/dune/perftool/loopy/symbolic.py
+++ b/python/dune/perftool/loopy/symbolic.py
@@ -109,6 +109,29 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
             shape = shape + (4,)
         return shape
 
+    @property
+    def output_shape(self):
+        """ The shape of the output temporary, ready to be fed into loopy """
+        # In stage 1, the output may be of reduced dimensionality
+        if self.stage == 1:
+            shape = tuple(mat.rows for mat in self.a_matrices if mat.face is None)
+        else:
+            shape = tuple(mat.rows for mat in self.a_matrices)
+        if self.vectorized:
+            shape = shape + (4,)
+        return shape
+
+    @property
+    def output_dimtags(self):
+        """ The dim_tags of the output temporary, ready to be fed into loopy """
+        tags = ["f"] *  len(self.output_shape)
+        if self.vectorized:
+            if self.stage == 1:
+                tags[-1] = 'c'
+            else:
+                tags[-1] = 'vec'
+        return ",".join(tags)
+
 
 class FusedMultiplyAdd(prim.Expression):
     """ Represents an FMA operation """
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 1f15491e..cf148a7f 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -97,9 +97,7 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
         # evaluation of the gradients of basis functions at quadrature
         # points (stage 1)
         from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
-        var, insn_dep = realize_sum_factorization_kernel(sf,
-                                                 outshape=tuple(mat.rows for mat in sf.a_matrices if mat.face is None),
-                                                 )
+        var, insn_dep = realize_sum_factorization_kernel(sf)
 
         buffers.append(var)
 
@@ -144,9 +142,7 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
     # Add a sum factorization kernel that implements the evaluation of
     # the basis functions at quadrature points (stage 1)
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
-    var, _ = realize_sum_factorization_kernel(sf,
-                                              outshape=tuple(mat.rows for mat in sf.a_matrices if mat.face is None),
-                                              )
+    var, _ = realize_sum_factorization_kernel(sf)
 
     if sf.index:
         index = (sf.index,)
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 0ed1641e..8da4339f 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -69,7 +69,7 @@ def _realize_input(sf, insn_dep):
 @generator_factory(item_tags=("sumfactkernel",),
                    context_tags=("kernel",),
                    cache_key_generator=lambda s, **kw: s.cache_key)
-def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, direct_output=None):
+def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), direct_output=None):
     # Unify the insn_dep parameter to be a frozenset
     if isinstance(insn_dep, str):
         insn_dep = frozenset({insn_dep})
@@ -279,23 +279,9 @@ def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, d
             insn_dep = instruction(code="HP_TIMER_START({});".format(qp_timer_name),
                                    depends_on=insn_dep)
 
-    if outshape is None:
-        assert sf.stage == 3
-        outshape = tuple(mat.rows for mat in a_matrices)
-
-    dim_tags = ",".join(['f'] * len(outshape))
-
-    if sf.vectorized:
-        outshape = outshape + vec_shape
-        # This is a 'bit' hacky: In stage 3 we need to return something with vectag, in stage 1 not.
-        if sf.stage == 1:
-            dim_tags = dim_tags + ",c"
-        else:
-            dim_tags = dim_tags + ",vec"
-
     out = get_buffer_temporary(sf.buffer,
-                               shape=outshape,
-                               dim_tags=dim_tags,
+                               shape=sf.output_shape,
+                               dim_tags=sf.output_dimtags,
                                )
     silenced_warning('read_no_write({})'.format(out))
 
-- 
GitLab