From cb740aff339dd3a07f19ba94086aceb450575e75 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 27 Oct 2016 14:55:49 +0200
Subject: [PATCH] Implement quadrature point evaulation in sumfact case

---
 python/dune/perftool/pdelab/basis.py       | 13 ++++-----
 python/dune/perftool/pdelab/geometry.py    |  5 ++--
 python/dune/perftool/pdelab/quadrature.py  |  9 ++++--
 python/dune/perftool/sumfact/__init__.py   |  2 +-
 python/dune/perftool/sumfact/amatrix.py    |  1 +
 python/dune/perftool/sumfact/quadrature.py | 34 ++++++++++++++++++++--
 6 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index 09ecddb0..aacbcac0 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -3,6 +3,7 @@
 from dune.perftool.generation import (cached,
                                       class_member,
                                       generator_factory,
+                                      get_backend,
                                       include_file,
                                       instruction,
                                       preamble,
@@ -65,9 +66,8 @@ def evaluate_basis(leaf_element, name, restriction):
     temporary_variable(name, shape=(name_lfs_bound(lfs),), decl_method=None)
     declare_cache_temporary(leaf_element, restriction, name, which='Function')
     cache = name_localbasis_cache(leaf_element)
-    qp = name_quadrature_position_in_cell(restriction)
-    instruction(inames=(quadrature_iname(),
-                        ),
+    qp = get_backend("qp_in_cell")(restriction)
+    instruction(inames=get_backend("quad_inames")(),
                 code='{} = {}.evaluateFunction({}, {}.finiteElement().localBasis());'.format(name,
                                                                                              cache,
                                                                                              qp,
@@ -94,8 +94,7 @@ def evaluate_reference_gradient(leaf_element, name, restriction):
     declare_cache_temporary(leaf_element, restriction, name, which='Jacobian')
     cache = name_localbasis_cache(leaf_element)
     qp = name_quadrature_position_in_cell(restriction)
-    instruction(inames=(quadrature_iname(),
-                        ),
+    instruction(inames=get_backend("quad_inames")(),
                 code='{} = {}.evaluateJacobian({}, {}.finiteElement().localBasis());'.format(name,
                                                                                              cache,
                                                                                              qp,
@@ -154,7 +153,7 @@ def evaluate_coefficient(element, name, container, restriction, component):
     reduction_expr = Product((coeff, Subscript(Variable(basis), Variable(index))))
     instruction(expression=Reduction("sum", index, reduction_expr, allow_simultaneous=True),
                 assignee=assignee,
-                forced_iname_deps=frozenset({quadrature_iname()}).union(frozenset(idims)),
+                forced_iname_deps=frozenset(get_backend("quad_inames")()).union(frozenset(idims)),
                 forced_iname_deps_is_final=True,
                 )
 
@@ -193,6 +192,6 @@ def evaluate_coefficient_gradient(element, name, container, restriction, compone
 
     instruction(expression=Reduction("sum", index, reduction_expr, allow_simultaneous=True),
                 assignee=assignee,
-                forced_iname_deps=frozenset({quadrature_iname()}).union(frozenset(idims)),
+                forced_iname_deps=frozenset(get_backend("quad_inames")()).union(frozenset(idims)),
                 forced_iname_deps_is_final=True,
                 )
diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py
index 8dd41c12..1210a6a0 100644
--- a/python/dune/perftool/pdelab/geometry.py
+++ b/python/dune/perftool/pdelab/geometry.py
@@ -2,6 +2,7 @@ from dune.perftool.ufl.modified_terminals import Restriction
 from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.generation import (cached,
                                       domain,
+                                      get_backend,
                                       get_global_context_value,
                                       iname,
                                       preamble,
@@ -254,7 +255,7 @@ def define_jacobian_inverse_transposed(name, restriction):
     dim = name_dimension()
     temporary_variable(name, decl_method=define_jacobian_inverse_transposed_temporary(restriction), shape=(dim, dim))
     geo = name_cell_geometry(restriction)
-    pos = name_quadrature_position_in_cell(restriction)
+    pos = get_backend("qp_in_cell")(restriction)
     return quadrature_preamble("{} = {}.jacobianInverseTransposed({});".format(name,
                                                                                geo,
                                                                                pos,
@@ -273,7 +274,7 @@ def name_jacobian_inverse_transposed(restriction):
 def define_jacobian_determinant(name):
     temporary_variable(name, shape=())
     geo = name_geometry()
-    pos = name_quadrature_position()
+    pos = get_backend("quad_pos")()
     code = "{} = {}.integrationElement({});".format(name,
                                                     geo,
                                                     pos,
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index ea55c546..030c467c 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -1,5 +1,7 @@
-from dune.perftool.generation import (cached,
+from dune.perftool.generation import (backend,
+                                      cached,
                                       domain,
+                                      get_backend,
                                       get_global_context_value,
                                       iname,
                                       include_file,
@@ -18,12 +20,13 @@ def quadrature_iname():
     return "q"
 
 
+@backend(interface="quad_inames")
 def quadrature_inames():
     return (quadrature_iname(),)
 
 
 def quadrature_preamble(code, **kw):
-    return instruction(inames=quadrature_iname(), code=code, **kw)
+    return instruction(inames=get_backend(interface="quad_inames")(), code=code, **kw)
 
 
 def name_quadrature_point():
@@ -37,6 +40,7 @@ def define_quadrature_position(name):
                                )
 
 
+@backend(interface="quad_pos")
 def name_quadrature_position():
     name = "pos"
     # To determine the shape, I do query global information here for lack of good alternatives
@@ -52,6 +56,7 @@ def name_quadrature_position():
     return name
 
 
+@backend(interface="qp_in_cell")
 def name_quadrature_position_in_cell(restriction):
     if restriction == Restriction.NONE:
         return name_quadrature_position()
diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py
index a4abda6f..e263c742 100644
--- a/python/dune/perftool/sumfact/__init__.py
+++ b/python/dune/perftool/sumfact/__init__.py
@@ -15,4 +15,4 @@ class SumFactInterface(PDELabInterface):
         return quadrature_inames()
 
     def pymbolic_quadrature_weight(self):
-        return quadrature_weight()
\ No newline at end of file
+        return quadrature_weight()
diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py
index 9ff78641..1d878add 100644
--- a/python/dune/perftool/sumfact/amatrix.py
+++ b/python/dune/perftool/sumfact/amatrix.py
@@ -139,6 +139,7 @@ def define_oned_quadrature_points(name):
 
 def name_oned_quadrature_points():
     name = "qp"
+    globalarg(name, shape=(quadrature_points_per_direction(),), dtype=NumpyType(numpy.float64))
     define_oned_quadrature_points(name)
     return name
 
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index 786de465..ea7104a4 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -1,4 +1,5 @@
-from dune.perftool.generation import (domain,
+from dune.perftool.generation import (backend,
+                                      domain,
                                       function_mangler,
                                       get_global_context_value,
                                       iname,
@@ -7,9 +8,11 @@ from dune.perftool.generation import (domain,
                                       )
 
 from dune.perftool.sumfact.amatrix import (name_number_of_basis_functions_per_direction,
+                                           name_oned_quadrature_points,
                                            name_oned_quadrature_weights,
                                            )
 from dune.perftool.pdelab.argument import name_accumulation_variable
+from dune.perftool.pdelab.geometry import dimension_iname
 
 from loopy import CallMangleInfo
 from loopy.symbolic import FunctionIdentifier
@@ -49,6 +52,7 @@ def sumfact_quad_iname(d, context):
     return name
 
 
+@backend(interface="quad_inames", name="sumfact")
 def quadrature_inames(context=''):
     formdata = get_global_context_value('formdata')
     dim = formdata.geometric_dimension
@@ -81,4 +85,30 @@ def recursive_quadrature_weight(dir=0):
 
 
 def quadrature_weight():
-    return recursive_quadrature_weight()
\ No newline at end of file
+    return recursive_quadrature_weight()
+
+
+def define_quadrature_position(name):
+    formdata = get_global_context_value('formdata')
+    dim = formdata.geometric_dimension
+    for i in range(dim):
+        instruction(expression=Subscript(Variable(name_oned_quadrature_points()), (Variable(quadrature_inames()[i]),)),
+                    assignee=Subscript(Variable(name), (i,)),
+                    forced_iname_deps=frozenset(quadrature_inames()),
+                    forced_iname_deps_is_final=True,
+                    )
+
+
+@backend(interface="quad_pos", name="sumfact")
+def name_quadrature_position():
+    formdata = get_global_context_value('formdata')
+    dim = formdata.geometric_dimension
+    name = 'pos'
+    temporary_variable(name, shape=(dim,), shape_impl=("fv",))
+    define_quadrature_position(name)
+    return name
+
+
+@backend(interface="qp_in_cell", name="sumfact")
+def name_quadrature_position_in_cell(restriction):
+    return name_quadrature_position()
-- 
GitLab