From 2b043378f7c675a66d5f6666608b7daf61bb811a Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 26 Mar 2018 14:26:12 +0200
Subject: [PATCH] Move function_name generation onto symbolic representation

---
 python/dune/perftool/sumfact/realization.py | 49 +++++----------------
 python/dune/perftool/sumfact/symbolic.py    | 31 ++++++++-----
 2 files changed, 31 insertions(+), 49 deletions(-)

diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 4ea2e8a5..703f9a6b 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -29,6 +29,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
                                                permute_backward,
                                                permute_forward,
                                                )
+from dune.perftool.sumfact.quadrature import quadrature_points_per_direction
 from dune.perftool.sumfact.symbolic import (get_input_output_tuple,
                                             SumfactKernel,
                                             VectorizedSumfactKernel,
@@ -47,39 +48,9 @@ import numpy as np
 import pymbolic.primitives as prim
 
 
-necessary_kernel_implementations = generator_factory(item_tags=("kernelimpl",), no_deco=True)
-
-
-@generator_factory(cache_key_generator=lambda s, qp: (s.function_key, qp))
-def _name_kernel_implementation_function(sf, qp):
-    name = "sfimpl_{}".format("_".join(str(m) for m in sf.matrix_sequence))
-    if get_form_option("fastdg"):
-        if sf.stage == 1:
-            if isinstance(sf, SumfactKernel):
-                fastdg = "{}comp{}".format(FEM_name_mangling(sf.input.element), sf.input.element_index)
-            if isinstance(sf, VectorizedSumfactKernel):
-                fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(sf.input.inputs))
-        if sf.stage == 3:
-            if isinstance(sf, SumfactKernel):
-                fastdg = "{}comp{}".format(FEM_name_mangling(sf.output.test_element), sf.output.test_element_index)
-                if sf.within_inames:
-                    fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(sf.output.trial_element), sf.output.trial_element_index)
-            if isinstance(sf, VectorizedSumfactKernel):
-                fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(sf.output.outputs))
-                if sf.within_inames:
-                    fastdg = "{}x{}".format(fastdg,
-                                            "_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(sf.output.outputs))
-                                            )
-
-        name = "{}_fastdg{}_{}".format(name, sf.stage, fastdg)
-    necessary_kernel_implementations((sf, qp))
-    return name
-
-
-def name_kernel_implementation_function(sf):
-    from dune.perftool.sumfact.quadrature import quadrature_points_per_direction
-    qp = quadrature_points_per_direction()
-    return _name_kernel_implementation_function(sf, qp)
+# Have a generator function store the necessary sum factorization kernel implementations
+# This way then can easily be extracted at the end of the form visiting process
+necessary_kernel_implementations = generator_factory(item_tags=("kernelimpl",), cache_key_generator=lambda a: a[0].function_name, no_deco=True)
 
 
 def realize_sum_factorization_kernel(sf, **kwargs):
@@ -125,7 +96,6 @@ def _realize_sum_factorization_kernel(sf):
         insn_dep = insn_dep.union(timer_dep)
 
     # Get all the necessary pieces for a function call
-    funcname = name_kernel_implementation_function(sf)
     buffers = tuple(name_buffer_storage(sf.buffer, i) for i in range(2))
 
     # Make sure that the storage is allocated and has a certain minimum size
@@ -153,8 +123,12 @@ def _realize_sum_factorization_kernel(sf):
     if sf.stage == 3:
         fastdg_args = sf.output.fastdg_args
 
+    # Trigger generation of the sum factorization kernel function
+    qp = quadrature_points_per_direction()
+    necessary_kernel_implementations((sf, qp))
+
     # Call the function
-    code = "{}({});".format(funcname, ", ".join(buffers + fastdg_args))
+    code = "{}({});".format(sf.function_name, ", ".join(buffers + fastdg_args))
     tag = "sumfact_stage{}".format(sf.stage)
     insn_dep = frozenset({instruction(code=code,
                                       depends_on=insn_dep,
@@ -334,7 +308,6 @@ def realize_sumfact_kernel_function(sf):
                                   })
 
     # Construct a loopy kernel object
-    name = name_kernel_implementation_function(sf)
     from dune.perftool.pdelab.localoperator import extract_kernel_from_cache
     args = ["const char* buffer0", "const char* buffer1"]
     if get_form_option('fastdg'):
@@ -344,7 +317,7 @@ def realize_sumfact_kernel_function(sf):
             if sf.within_inames:
                 args.append("unsigned int jacobian_offset{}".format(i))
 
-    signature = "void {}({}) const".format(name, ", ".join(args))
-    kernel = extract_kernel_from_cache("kernel_default", name, [signature], add_timings=False)
+    signature = "void {}({}) const".format(sf.function_name, ", ".join(args))
+    kernel = extract_kernel_from_cache("kernel_default", sf.function_name, [signature], add_timings=False)
     delete_cache_items("kernel_default")
     return kernel
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index c3be20dd..5ef99849 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -5,6 +5,7 @@ from dune.perftool.generation import (get_counted_variable,
                                       subst_rule,
                                       transform,
                                       )
+from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.geometry import local_dimension, world_dimension
 from dune.perftool.sumfact.quadrature import quadrature_inames
 from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
@@ -288,15 +289,18 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
     # Watch out for the documentation to see which key is used unter what circumstances
     #
     @property
-    def function_key(self):
-        """ Kernels sharing this key may use the same kernel implementation function """
-        fastdg = ()
+    def function_name(self):
+        """ The name of the function that implements this kernel """
+        name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence))
         if get_form_option("fastdg"):
             if self.stage == 1:
-                fastdg = (self.input.element, self.input.element_index)
+                fastdg = "{}comp{}".format(FEM_name_mangling(self.input.element), self.input.element_index)
             if self.stage == 3:
-                fastdg = (self.output.test_element, self.output.test_element_index, self.output.trial_element, self.output.trial_element_index)
-        return tuple(str(m) for m in self.matrix_sequence) + fastdg
+                fastdg = "{}comp{}".format(FEM_name_mangling(self.output.test_element), self.output.test_element_index)
+                if self.within_inames:
+                    fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(self.output.trial_element), self.output.trial_element_index)
+            name = "{}_fastdg{}_{}".format(name, self.stage, fastdg)
+        return name
 
     @property
     def parallel_key(self):
@@ -556,14 +560,19 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     # Watch out for the documentation to see which key is used unter what circumstances
     #
     @property
-    def function_key(self):
-        fastdg = ()
+    def function_name(self):
+        name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence))
         if get_form_option("fastdg"):
             if self.stage == 1:
-                fastdg = sum(((i.element, i.element_index) for i in remove_duplicates(self.input.inputs)), ())
+                fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(self.input.inputs))
             if self.stage == 3:
-                fastdg = sum(((o.test_element, o.test_element_index, o.trial_element, o.trial_element_index) for o in remove_duplicates(self.output.outputs)), ())
-        return tuple(str(m) for m in self.matrix_sequence) + fastdg
+                fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(self.output.outputs))
+                if self.within_inames:
+                    fastdg = "{}x{}".format(fastdg,
+                                            "_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(self.output.outputs))
+                                            )
+            name = "{}_fastdg{}_{}".format(name, self.stage, fastdg)
+        return name
 
     @property
     def cache_key(self):
-- 
GitLab