From 9076c157bfaad823ed76366ca30c4ef486791eea Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 23 Mar 2018 10:22:20 +0100
Subject: [PATCH] First steps towards fastdg jacobians

---
 python/dune/perftool/sumfact/realization.py | 20 +++++++++++++++++++-
 python/dune/perftool/sumfact/symbolic.py    | 12 +++++++++---
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index de515848..cd387b1e 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -18,6 +18,7 @@ from dune.perftool.generation import (barrier,
                                       )
 from dune.perftool.pdelab.argument import pymbolic_coefficient
 from dune.perftool.pdelab.basis import shape_as_pymbolic
+from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.geometry import world_dimension
 from dune.perftool.options import (get_form_option,
                                    get_option,
@@ -27,7 +28,10 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
                                                permute_backward,
                                                permute_forward,
                                                )
-from dune.perftool.sumfact.symbolic import get_input_output_tuple
+from dune.perftool.sumfact.symbolic import (get_input_output_tuple,
+                                            SumfactKernel,
+                                            VectorizedSumfactKernel,
+                                            )
 from dune.perftool.sumfact.vectorization import attach_vectorization_info
 from dune.perftool.sumfact.accumulation import sumfact_iname
 from dune.perftool.loopy.target import dtype_floatingpoint
@@ -47,6 +51,20 @@ necessary_kernel_implementations = generator_factory(item_tags=("kernelimpl",),
 @generator_factory(cache_key_generator=lambda s, qp: (s.function_key, qp))
 def _name_kernel_implementation_function(sf, qp):
     name = "sfimpl_{}".format("_".join(str(m) for m in sf.matrix_sequence))
+    if get_form_option("fastdg"):  # in fastdg mode the mangled element(s) and component index become part of the name
+        if sf.stage == 1:
+            if isinstance(sf, SumfactKernel):
+                fastdg = "{}comp{}".format(FEM_name_mangling(sf.input.element), sf.input.element_index)
+            if isinstance(sf, VectorizedSumfactKernel):
+                raise NotImplementedError("fastdg naming is not implemented for vectorized stage 1 sum factorization kernels")
+        if sf.stage == 3:
+            if isinstance(sf, SumfactKernel):
+                fastdg = "{}comp{}".format(FEM_name_mangling(sf.output.test_element), sf.output.test_element_index)
+                if sf.output.trial_element:  # a trial element, when present, is appended after the test element
+                    fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(sf.output.trial_element), sf.output.trial_element_index)
+            if isinstance(sf, VectorizedSumfactKernel):
+                raise NotImplementedError("fastdg naming is not implemented for vectorized stage 3 sum factorization kernels")
+        name = "{}_fastdg{}_{}".format(name, sf.stage, fastdg)
     necessary_kernel_implementations((sf, qp))
     return name
 
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index 5dc395f8..be90a1dc 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -1,6 +1,6 @@
 """ A pymbolic node representing a sum factorization kernel """
 
-from dune.perftool.options import get_option
+from dune.perftool.options import get_form_option, get_option
 from dune.perftool.generation import (get_counted_variable,
                                       subst_rule,
                                       transform,
@@ -281,7 +281,12 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
     @property
     def function_key(self):
         """ Kernels sharing this key may use the same kernel implementation function """
-        return tuple(str(m) for m in self.matrix_sequence)
+        fastdg = ()  # stays empty unless the fastdg form option is set, matching the generated function name
+        if get_form_option("fastdg") and self.stage == 1:
+            fastdg = (self.input.element, self.input.element_index)
+        if get_form_option("fastdg") and self.stage == 3:
+            fastdg = (self.output.test_element, self.output.test_element_index, self.output.trial_element, self.output.trial_element_index)
+        return tuple(str(m) for m in self.matrix_sequence) + fastdg
 
     @property
     def parallel_key(self):
@@ -542,7 +547,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     #
     @property
     def function_key(self):
-        return tuple(str(m) for m in self.matrix_sequence)
+        fastdg = self.inout_key if get_form_option("fastdg") else ()  # inout_key only extends the key when the fastdg form option is set
+        return tuple(str(m) for m in self.matrix_sequence) + fastdg
 
     @property
     def cache_key(self):
-- 
GitLab