From a37601db85ce1cf27cd8f7380543fa4dc79770a5 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 13 Jan 2017 11:28:12 +0100
Subject: [PATCH] Refactor trialfunction evaluation to be reusable for apply
 function

---
 python/dune/perftool/sumfact/__init__.py | 11 +++++++----
 python/dune/perftool/sumfact/basis.py    | 12 ++++++------
 python/dune/perftool/sumfact/sumfact.py  |  4 ++--
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py
index 1900e7b5..520fa7d1 100644
--- a/python/dune/perftool/sumfact/__init__.py
+++ b/python/dune/perftool/sumfact/__init__.py
@@ -1,3 +1,6 @@
+from dune.perftool.pdelab.argument import (name_applycontainer,
+                                           name_coefficientcontainer,
+                                           )
 from dune.perftool.sumfact.quadrature import (quadrature_inames,
                                               quadrature_weight,
                                               pymbolic_quadrature_position_in_cell,
@@ -6,8 +9,8 @@ from dune.perftool.sumfact.quadrature import (quadrature_inames,
 from dune.perftool.sumfact.basis import (lfs_inames,
                                          pymbolic_basis,
                                          pymbolic_reference_gradient,
-                                         pymbolic_trialfunction,
-                                         pymbolic_trialfunction_gradient,
+                                         pymbolic_coefficient,
+                                         pymbolic_coefficient_gradient,
                                          )
 import dune.perftool.sumfact.switch
 
@@ -25,12 +28,12 @@ class SumFactInterface(PDELabInterface):
         return pymbolic_reference_gradient(element, restriction, number)
 
     def pymbolic_trialfunction_gradient(self, element, restriction, component, visitor=None):
-        ret, indices = pymbolic_trialfunction_gradient(element, restriction, component, visitor)
+        ret, indices = pymbolic_coefficient_gradient(element, restriction, component, name_coefficientcontainer, visitor)
         visitor.indices = indices
         return ret
 
     def pymbolic_trialfunction(self, element, restriction, component, visitor=None):
-        ret, indices = pymbolic_trialfunction(element, restriction, component, visitor)
+        ret, indices = pymbolic_coefficient(element, restriction, component, name_coefficientcontainer, visitor)
         visitor.indices = indices
         return ret
 
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 3c04cd09..e9b5a6ce 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -53,7 +53,7 @@ def name_sumfact_base_buffer():
 
 
 @kernel_cached
-def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
+def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, visitor):
     rawname = "gradu" + "_".join(str(c) for c in component)
     name = restricted_name(rawname, restriction)
 
@@ -92,11 +92,11 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
 
         if get_option('fastdg'):
             # Name of direct input, shape and globalarg is set in sum_factorization_kernel
-            direct_input = name_coefficientcontainer(restriction)
+            direct_input = coeff_func(restriction)
         else:
             direct_input = None
             # Setup the input!
-            setup_theta(inp, element, restriction, component, index)
+            setup_theta(inp, element, restriction, component, index, coeff_func)
 
         # Add a sum factorization kernel that implements the
         # evaluation of the gradients of basis functions at quadrature
@@ -136,7 +136,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
 
 
 @kernel_cached
-def pymbolic_trialfunction(element, restriction, component, visitor):
+def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
     # Get geometric dimension
     dim = world_dimension()
 
@@ -161,11 +161,11 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
 
     if get_option('fastdg'):
         # Name of direct input, shape and globalarg is set in sum_factorization_kernel
-        direct_input = name_coefficientcontainer(restriction)
+        direct_input = coeff_func(restriction)
     else:
         direct_input = None
         # Setup the input!
-        setup_theta(inp, element, restriction, component, index)
+        setup_theta(inp, element, restriction, component, index, coeff_func)
 
     # Add a sum factorization kernel that implements the evaluation of
     # the basis functions at quadrature points (stage 1)
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 82aa38d8..c4890794 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -95,7 +95,7 @@ def accum_iname(restriction, bound, i):
     return sumfact_iname(bound, "accum")
 
 
-def setup_theta(inp, element, restriction, component, index):
+def setup_theta(inp, element, restriction, component, index, coeff_func):
     if index is None:
         index = ()
     else:
@@ -103,7 +103,7 @@ def setup_theta(inp, element, restriction, component, index):
     # Write initial coefficients into buffer
     lfs = name_lfs(element, restriction, component)
     basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
-    container = name_coefficientcontainer(restriction)
+    container = coeff_func(restriction)
     coeff = pymbolic_coefficient(container, lfs, basisiname)
     assignee = Subscript(Variable(inp), (Variable(basisiname),) + index)
     return instruction(assignee=assignee,
-- 
GitLab