From 63886624fc03010c34223c0053382f06c59d07d7 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 20 Apr 2017 15:57:29 +0200
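Subject: [PATCH] Fix FastDG

Let the kernel input object decide whether a sum factorization kernel
gets a direct input. SumfactKernelInputBase now provides a direct_input
property (None by default; the flat_shape property is removed) and a
no-op realize(). LFSSumfactKernelInput overrides direct_input to return
coeff_func(restriction) when the fastdg option is set.
_realize_sum_factorization_kernel queries sf.input.direct_input instead
of checking sf.stage and the fastdg option itself, and only realizes the
input when direct_input is None.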
---
 python/dune/perftool/sumfact/basis.py       |  7 +++++++
 python/dune/perftool/sumfact/realization.py | 11 +++--------
 python/dune/perftool/sumfact/symbolic.py    |  7 +++++--
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index de9c30ae..04bf64ba 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -110,6 +110,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
                     tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
                     )
 
+    @property
+    def direct_input(self):
+        """coeff_func(restriction) if the fastdg option is set, else None"""
+        if not get_option("fastdg"):
+            return None
+        return self.coeff_func(self.restriction)
+
 
 def _basis_functions_per_direction(element, component):
     """Number of basis functions per direction of a given component of an element"""
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index f939feae..8922a0b0 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -57,10 +57,10 @@ def _realize_sum_factorization_kernel(sf):
                                                          ),
                                              }))
 
-    # Set up the input for stage 1
-    if sf.stage == 1 and not get_option("fastdg"):
-        assert sf.input
+    direct_input = sf.input.direct_input
 
+    # Set up the input for stage 1
+    if direct_input is None:
         if sf.vectorized:
             for i, inputsf in enumerate(sf.kernels):
                 inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep))
@@ -69,11 +69,6 @@ def _realize_sum_factorization_kernel(sf):
 
         insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))}))
 
-    # Construct the direct_input for the FastDG case
-    direct_input = None
-    if get_option('fastdg') and sf.stage == 1:
-        direct_input = sf.input.coeff_func(sf.input.restriction)
-
     direct_output = None
     if get_option('fastdg') and sf.stage == 3:
         direct_output = sf.accumvar + ".data()"
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index a8a42ed1..9a3011cb 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -16,8 +16,11 @@ import inspect
 
 class SumfactKernelInputBase(object):
     @property
-    def flat_shape(self):
-        return False
+    def direct_input(self):
+        return None  # base default; LFSSumfactKernelInput overrides this
+
+    def realize(self, sf, i, dep):
+        pass  # base default: no-op
 
 
 class SumfactKernelBase(object):
-- 
GitLab