From 0a3f6ad3ce1f05b708883e565d375bd6d0f8ab35 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 24 Apr 2017 16:30:06 +0200
Subject: [PATCH] Handle sumfact geometry evaluation the new way

---
 python/dune/perftool/pdelab/__init__.py      |  2 +-
 python/dune/perftool/sumfact/__init__.py     |  6 ++-
 python/dune/perftool/sumfact/accumulation.py |  2 +-
 python/dune/perftool/sumfact/geometry.py     | 49 +++++++-------------
 python/dune/perftool/ufl/visitor.py          |  2 +-
 5 files changed, 23 insertions(+), 38 deletions(-)

diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py
index 4afe9a97..ac9ad0e4 100644
--- a/python/dune/perftool/pdelab/__init__.py
+++ b/python/dune/perftool/pdelab/__init__.py
@@ -99,7 +99,7 @@ class PDELabInterface(object):
     # Geometry related generator functions
     #
 
-    def pymbolic_spatial_coordinate(self):
+    def pymbolic_spatial_coordinate(self, visitor=None):
         return to_global(pymbolic_quadrature_position())
 
     def name_facet_jacobian_determinant(self):
diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py
index e0ec47de..6a07678a 100644
--- a/python/dune/perftool/sumfact/__init__.py
+++ b/python/dune/perftool/sumfact/__init__.py
@@ -59,6 +59,8 @@ class SumFactInterface(PDELabInterface):
 #        from dune.perftool.pdelab.geometry import to_global
 #        return to_global(pymbolic_quadrature_position())
 
-    def pymbolic_spatial_coordinate(self):
+    def pymbolic_spatial_coordinate(self, visitor=None):
         from dune.perftool.sumfact.geometry import pymbolic_spatial_coordinate
-        return pymbolic_spatial_coordinate()
+        ret, indices = pymbolic_spatial_coordinate(visitor.indices)
+        visitor.indices = indices
+        return ret
diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index ca1b52f5..b508071d 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -205,7 +205,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                                   forced_iname_deps=frozenset(quadrature_inames() + jacobian_inames),
                                   forced_iname_deps_is_final=True,
                                   tags=frozenset({"quadvec"}).union(vectag),
-                                  depends_on=frozenset({deps}).union(timer_dep)
+                                  depends_on=frozenset({deps}).union(timer_dep).union(frozenset({lp.match.Tagged("sumfact_stage1")}))
                                   )
 
         if insn_dep is None:
diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py
index afbf0edb..02e7a75b 100644
--- a/python/dune/perftool/sumfact/geometry.py
+++ b/python/dune/perftool/sumfact/geometry.py
@@ -63,7 +63,9 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
 
 
 @kernel_cached
-def pymbolic_spatial_coordinate():
+def pymbolic_spatial_coordinate(visitor_indices):
+    assert len(visitor_indices) == 1
+
     # Construct the matrix sequence for the evaluation of the global coordinate.
     # We need to manually construct this one, because on facets, we want to use the
     # geometry embedding of the facet into the global space directly without going
@@ -72,37 +74,18 @@ def pymbolic_spatial_coordinate():
     from dune.perftool.sumfact.tabulation import quadrature_points_per_direction, BasisTabulationMatrix
     quadrature_size = quadrature_points_per_direction()
     matrix_sequence = (BasisTabulationMatrix(quadrature_size=quadrature_size, basis_size=2),) * local_dimension()
+    inp = GeoCornersInput(visitor_indices[0])
 
-    expressions = []
-    insn_dep = frozenset()
-    for i in range(world_dimension()):
-        inp = GeoCornersInput(i)
-
-        from dune.perftool.sumfact.symbolic import SumfactKernel
-        sf = SumfactKernel(matrix_sequence=matrix_sequence,
-                           input=inp,
-                           )
-
-        vsf = attach_vectorization_info(sf)
-
-        # Add a sum factorization kernel that implements the evaluation of
-        # the basis functions at quadrature points (stage 1)
-        from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
-        var, insn_dep = realize_sum_factorization_kernel(vsf.copy(insn_dep=vsf.insn_dep.union(insn_dep)))
-
-        expressions.append(prim.Subscript(var, vsf.quadrature_index(sf)))
-
-    # Return an indexable temporary with the results!
-    name = "pos_global"
-    temporary_variable(name, shape=(world_dimension(),))
-    for i, expr in enumerate(expressions):
-        assignee = prim.Subscript(prim.Variable(name), (i,))
-        instruction(assignee=assignee,
-                    expression=expr,
-                    within_inames=frozenset(get_backend("quad_inames")()),
-                    within_inames_is_final=True,
-                    depends_on=insn_dep,
-                    tags=frozenset({"quad"}),
-                    )
+    from dune.perftool.sumfact.symbolic import SumfactKernel
+    sf = SumfactKernel(matrix_sequence=matrix_sequence,
+                       input=inp,
+                       )
+
+    vsf = attach_vectorization_info(sf)
+
+    # Add a sum factorization kernel that implements the evaluation of
+    # the basis functions at quadrature points (stage 1)
+    from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
+    var, _ = realize_sum_factorization_kernel(vsf)
 
-    return prim.Variable(name)
+    return prim.Subscript(var, vsf.quadrature_index(sf)), None
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index 280c1d88..6c7ddb1a 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -355,7 +355,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         if get_global_context_value("driver", False):
             return prim.Variable("x")
         else:
-            return self.interface.pymbolic_spatial_coordinate()
+            return self.interface.pymbolic_spatial_coordinate(self)
 
     def facet_normal(self, o):
         # The normal must be restricted to be well-defined
-- 
GitLab