From a84d4edaa13a03a1fcfaeeb375dc18172ec1d586 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 3 Apr 2017 11:47:33 +0200
Subject: [PATCH] Move the output generation on the sumfact node

---
 python/dune/perftool/sumfact/realization.py | 2 +-
 python/dune/perftool/sumfact/symbolic.py    | 7 +++++++
 python/dune/perftool/sumfact/tabulation.py  | 6 ------
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index c063c7e1..dcc06af6 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -272,4 +272,4 @@ def _realize_sum_factorization_kernel(sf):
                                )
     silenced_warning('read_no_write({})'.format(out))
 
-    return next(iter(matrix_sequence)).output_to_pymbolic(out), insn_dep
+    return sf.output_to_pymbolic(out), insn_dep
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index 42cf3382..0017d467 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -3,6 +3,7 @@
 from pytools import ImmutableRecord
 
 import pymbolic.primitives as prim
+import loopy as lp
  
 
 class SumfactKernel(ImmutableRecord, prim.Variable):
@@ -243,3 +244,9 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
             return self.quadrature_dimtags
         else:
             return self.dof_dimtags
+
+    def output_to_pymbolic(self, name):
+        if self.vectorized:
+            return lp.TaggedVariable(name, "vector")
+        else:
+            return lp.TaggedVariable(name, "sumfac")
diff --git a/python/dune/perftool/sumfact/tabulation.py b/python/dune/perftool/sumfact/tabulation.py
index 03eb0915..c6fa99e3 100644
--- a/python/dune/perftool/sumfact/tabulation.py
+++ b/python/dune/perftool/sumfact/tabulation.py
@@ -57,9 +57,6 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase):
     def vectorized(self):
         return False
 
-    def output_to_pymbolic(self, name):
-        return lp.TaggedVariable(name, "sumfac")
-
 
 class BasisTabulationMatrixArray(BasisTabulationMatrixBase):
     def __init__(self, rows, cols, transpose, derivative, face):
@@ -95,9 +92,6 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase):
     def vectorized(self):
         return True
 
-    def output_to_pymbolic(self, name):
-        return lp.TaggedVariable(name, "vector")
-
 
 def quadrature_points_per_direction():
     # Quadrature order
-- 
GitLab