From 120ee4f26b9ccc57e8efc86f391ef63cd7c84e55 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 28 Jul 2017 14:15:16 +0200
Subject: [PATCH] Enforce giving an element to quadrature_inames

---
 python/dune/perftool/sumfact/__init__.py   |  2 +-
 python/dune/perftool/sumfact/geometry.py   |  6 +++---
 python/dune/perftool/sumfact/quadrature.py | 16 +++++++++-------
 3 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py
index 05e0137c..123c4e1e 100644
--- a/python/dune/perftool/sumfact/__init__.py
+++ b/python/dune/perftool/sumfact/__init__.py
@@ -73,6 +73,6 @@ class SumFactInterface(PDELabInterface):
 
     def pymbolic_spatial_coordinate(self):
         import dune.perftool.sumfact.geometry
-        ret, indices = get_backend(interface="spatial_coordinate", selector=option_switch("diagonal_transformation_matrix"))(self.visitor.indices, self.visitor.do_predicates)
+        ret, indices = get_backend(interface="spatial_coordinate", selector=option_switch("diagonal_transformation_matrix"))(self.visitor.indices, self.visitor.do_predicates, self.visitor)
         self.visitor.indices = indices
         return ret
diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py
index c991aeaa..c698a2be 100644
--- a/python/dune/perftool/sumfact/geometry.py
+++ b/python/dune/perftool/sumfact/geometry.py
@@ -68,7 +68,7 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
 
 @kernel_cached
 @backend(interface="spatial_coordinate", name="default")
-def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates):
+def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visitor):
     assert len(visitor_indices) == 1
 
     # Construct the matrix sequence for the evaluation of the global coordinate.
@@ -127,7 +127,7 @@ def name_meshwidth():
 
 @kernel_cached
 @backend(interface="spatial_coordinate", name="diagonal_transformation_matrix")
-def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates):
+def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates, visitor):
     assert len(visitor_indices) == 1
     index, = visitor_indices
 
@@ -156,6 +156,6 @@ def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates):
         if face is not None and index > face:
             iindex = iindex - 1
         from dune.perftool.sumfact.quadrature import pymbolic_quadrature_position
-        x = pymbolic_quadrature_position(iindex)
+        x = pymbolic_quadrature_position(iindex, visitor)
 
     return prim.Subscript(prim.Variable(lowcorner), (index,)) + x * prim.Subscript(prim.Variable(meshwidth), (index,)), None
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index 17d4405e..a5ede367 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -76,15 +76,12 @@ def pymbolic_base_weight():
 
 @backend(interface="quad_inames", name="sumfact")
 @kernel_cached
-def quadrature_inames(element=None):
+def quadrature_inames(element):
     if element is None:
         names = tuple("quad_{}".format(d) for d in range(local_dimension()))
     else:
         from ufl import FiniteElement
-        try:
-            assert isinstance(element, FiniteElement)
-        except:
-            from pudb import set_trace; set_trace()
+        assert isinstance(element, FiniteElement)
         from dune.perftool.pdelab.driver import FEM_name_mangling
         names = tuple("quad_{}_{}".format(FEM_name_mangling(element), d) for d in range(local_dimension()))
     domain(names, quadrature_points_per_direction())
@@ -177,7 +174,7 @@ def define_quadrature_position(name, index):
 
 
 @backend(interface="quad_pos", name="sumfact")
-def pymbolic_quadrature_position(index):
+def pymbolic_quadrature_position(index, visitor):
     # Return the non-precomputed version
     if not get_option("precompute_quadrature_info"):
         name = 'pos'
@@ -207,7 +204,12 @@ def pymbolic_quadrature_position(index):
                 kernel="operator",
                 )
 
-    return prim.Subscript(lp.symbolic.TaggedVariable(name, "operator_precomputed"), tuple(prim.Variable(i) for i in quadrature_inames()))
+    info = visitor.current_info[1]
+    if info is None:
+        element = None
+    else:
+        element = info.element
+    return prim.Subscript(lp.symbolic.TaggedVariable(name, "operator_precomputed"), tuple(prim.Variable(i) for i in quadrature_inames(element)))
 
 
 @backend(interface="qp_in_cell", name="sumfact")
-- 
GitLab