From 9ebda32583301d25e27579fe99ffaed77c76b80d Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 26 Apr 2017 13:50:30 +0200
Subject: [PATCH] Also precompute quadrature positions

---
 .../loopy/transformations/collect_rotate.py   |  2 +-
 python/dune/perftool/sumfact/geometry.py      |  2 +-
 python/dune/perftool/sumfact/quadrature.py    | 63 +++++++++++++------
 3 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py
index 20bc4130..167de058 100644
--- a/python/dune/perftool/loopy/transformations/collect_rotate.py
+++ b/python/dune/perftool/loopy/transformations/collect_rotate.py
@@ -244,7 +244,7 @@ def collect_vector_data_rotate(knl):
                 shape=(ceildiv(product(s for s in arg.shape), vec_size), vec_size)
                 name = loopy_class_member(quantity,
                                           shape=shape,
-                                          dim_tags="c,vec",
+                                          dim_tags="f,vec",
                                           potentially_vectorized=True,
                                           classtag="operator",
                                           dtype=np.float64,
diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py
index 9a8d4f1f..fce0ec9b 100644
--- a/python/dune/perftool/sumfact/geometry.py
+++ b/python/dune/perftool/sumfact/geometry.py
@@ -149,6 +149,6 @@ def pymbolic_spatial_coordinate_axiparallel(visitor_indices):
         if face is not None and index > face:
             iindex = iindex - 1
         from dune.perftool.sumfact.quadrature import pymbolic_quadrature_position
-        x = prim.Subscript(pymbolic_quadrature_position(), (iindex,))
+        x = pymbolic_quadrature_position(iindex)
 
     return prim.Subscript(prim.Variable(lowcorner), (index,)) + x * prim.Subscript(prim.Variable(meshwidth), (index,)), None
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index 0cff5516..d12b7aeb 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -93,7 +93,8 @@ def constructor_quad_iname(name, d, bound):
     return name
 
 
-def constructor_quadrature_inames(name):
+def constructor_quadrature_inames(dim, num1d):
+    name = "quadiname_dim{}_num{}".format(dim, num1d)
     return tuple(constructor_quad_iname(name, d, quadrature_points_per_direction()) for d in range(local_dimension()))
 
 
@@ -135,37 +136,63 @@ def quadrature_weight():
                        dtype=np.float64,
                        shape=(num1d,) * dim,
                        classtag="operator",
-                       dim_tags=",".join(["c"] * dim),
+                       dim_tags=",".join(["f"] * dim),
                        managed=True,
                        potentially_vectorized=True,
                        )
 
     # Precompute it in the constructor
-    instruction(assignee=prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in constructor_quadrature_inames(name))),
-                expression=prim.Product(tuple(Subscript(Variable(name_oned_quadrature_weights()), (prim.Variable(i),)) for i in constructor_quadrature_inames(name))),
-                within_inames=frozenset(constructor_quadrature_inames(name)),
+    inames = constructor_quadrature_inames(dim, num1d)
+    instruction(assignee=prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in inames)),
+                expression=prim.Product(tuple(Subscript(Variable(name_oned_quadrature_weights()), (prim.Variable(i),)) for i in inames)),
+                within_inames=frozenset(inames),
                 kernel="operator",
                 )
 
     return prim.Subscript(lp.symbolic.TaggedVariable(name, "operator_precomputed"), tuple(prim.Variable(i) for i in quadrature_inames()))
 
 
-def define_quadrature_position(name):
-    for i in range(local_dimension()):
-        instruction(expression=Subscript(Variable(name_oned_quadrature_points()), (Variable(quadrature_inames()[i]),)),
-                    assignee=Subscript(Variable(name), (i,)),
-                    forced_iname_deps=frozenset(quadrature_inames()),
-                    forced_iname_deps_is_final=True,
-                    tags=frozenset({"quad"}),
-                    )
+def define_quadrature_position(name, index):
+    instruction(expression=Subscript(Variable(name_oned_quadrature_points()), (Variable(quadrature_inames()[index]),)),
+                assignee=Subscript(Variable(name), (index,)),
+                forced_iname_deps=frozenset(quadrature_inames()),
+                forced_iname_deps_is_final=True,
+                tags=frozenset({"quad"}),
+                )
 
 
 @backend(interface="quad_pos", name="sumfact")
-def pymbolic_quadrature_position():
-    name = 'pos'
-    temporary_variable(name, shape=(local_dimension(),), shape_impl=("fv",))
-    define_quadrature_position(name)
-    return Variable(name)
+def pymbolic_quadrature_position(index):
+    # Return the non-precomputed version
+    if not get_option("precompute_quadrature_info"):
+        name = 'pos'
+        temporary_variable(name, shape=(local_dimension(),), shape_impl=("fv",))
+        define_quadrature_position(name, index)
+        return prim.Subscript(prim.Variable(name), (index,))
+
+    assert isinstance(index, int)
+    dim = local_dimension()
+    num1d = quadrature_points_per_direction()
+    name = "quad_points_dim{}_num{}_dir{}".format(dim, num1d, index)
+
+    loopy_class_member(name,
+                       dtype=np.float64,
+                       shape=(num1d,) * dim,
+                       classtag="operator",
+                       dim_tags=",".join(["f"] * dim),
+                       managed=True,
+                       potentially_vectorized=True,
+                       )
+
+    # Precompute it in the constructor
+    inames = constructor_quadrature_inames(dim, num1d)
+    instruction(assignee=prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in inames)),
+                expression=Subscript(Variable(name_oned_quadrature_points()), (prim.Variable(inames[index]))),
+                within_inames=frozenset(inames),
+                kernel="operator",
+                )
+
+    return prim.Subscript(lp.symbolic.TaggedVariable(name, "operator_precomputed"), tuple(prim.Variable(i) for i in quadrature_inames()))
 
 
 @backend(interface="qp_in_cell", name="sumfact")
-- 
GitLab