From 1529cb7363dcf1d0290b33d1d48d12d6cd610a47 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 26 Apr 2017 11:18:45 +0200
Subject: [PATCH] Precompute quadrature weights on a localoperator level

---
 python/dune/perftool/options.py              |  1 +
 python/dune/perftool/pdelab/localoperator.py | 13 ------
 python/dune/perftool/sumfact/quadrature.py   | 45 ++++++++++++++++++--
 3 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/python/dune/perftool/options.py b/python/dune/perftool/options.py
index 7d7712d9..8dff610b 100644
--- a/python/dune/perftool/options.py
+++ b/python/dune/perftool/options.py
@@ -65,6 +65,7 @@ class PerftoolOptionsArray(ImmutableRecord):
     # Arguments that are mainly to be set by logic depending on other options
     max_vector_width = PerftoolOption(default=256, helpstr=None)
     unroll_dimension_loops = PerftoolOption(default=False, helpstr="whether loops over the gemetric dimension should be unrolled.")
+    precompute_quadrature_info = PerftoolOption(default=True, helpstr="whether loops over the gemetric dimension should be unrolled.")
 
 
 # Until more sophisticated logic is needed, we keep the actual option data in this module
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 3bae1ea9..642ab87b 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -414,19 +414,6 @@ def visit_integrals(integrals):
                 if name.startswith("cse"):
                     set_subst_rule(name, expr)
 
-            # Ensure CSE on detjac * quadrature weight
-            domain = accterm.term.ufl_domain()
-            if measure == "cell":
-                set_subst_rule("integration_factor_cell1",
-                               uc.QuadratureWeight(domain) * uc.Abs(uc.JacobianDeterminant(domain)))
-                set_subst_rule("integration_factor_cell2",
-                               uc.Abs(uc.JacobianDeterminant(domain)) * uc.QuadratureWeight(domain))
-            else:
-                set_subst_rule("integration_factor_facet1",
-                               uc.FacetJacobianDeterminant(domain) * uc.QuadratureWeight(domain))
-                set_subst_rule("integration_factor_facet2",
-                               uc.QuadratureWeight(domain) * uc.FacetJacobianDeterminant(domain))
-
             get_backend(interface="accum_insn")(visitor, accterm, measure, subdomain_id)
 
 
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index 5e6935c3..e3d58fba 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -4,6 +4,7 @@ from dune.perftool.generation import (backend,
                                       get_global_context_value,
                                       iname,
                                       instruction,
+                                      loopy_class_member,
                                       temporary_variable,
                                       )
 
@@ -15,6 +16,7 @@ from dune.perftool.pdelab.argument import name_accumulation_variable
 from dune.perftool.pdelab.geometry import (dimension_iname,
                                            local_dimension,
                                            )
+from dune.perftool.options import get_option
 
 from loopy import CallMangleInfo
 from loopy.symbolic import FunctionIdentifier
@@ -26,7 +28,8 @@ from pymbolic.primitives import (Call,
                                  Variable,
                                  )
 
-import numpy
+import pymbolic.primitives as prim
+import numpy as np
 
 
 def nest_quadrature_loops(kernel, inames):
@@ -58,7 +61,7 @@ class BaseWeight(FunctionIdentifier):
 @function_mangler
 def base_weight_function_mangler(target, func, dtypes):
     if isinstance(func, BaseWeight):
-        return CallMangleInfo(func.name, (NumpyType(numpy.float64),), ())
+        return CallMangleInfo(func.name, (NumpyType(np.float64),), ())
 
 
 def pymbolic_base_weight():
@@ -82,6 +85,17 @@ def quadrature_inames():
     return tuple(sumfact_quad_iname(d, quadrature_points_per_direction()) for d in range(local_dimension()))
 
 
+@iname(kernel="operator")
+def constructor_quad_iname(name, d, bound):
+    name = "{}_{}".format(name, d)
+    domain(name, quadrature_points_per_direction(), kernel="operator")
+    return name
+
+
+def constructor_quadrature_inames(name):
+    return tuple(constructor_quad_iname(name, d, quadrature_points_per_direction()) for d in range(local_dimension()))
+
+
 def define_recursive_quadrature_weight(name, dir):
     iname = quadrature_inames()[dir]
     temporary_variable(name, shape=(), shape_impl=())
@@ -107,7 +121,32 @@ def recursive_quadrature_weight(dir=0):
 
 
 def quadrature_weight():
-    return recursive_quadrature_weight()
+    # Return non-precomputed version
+    if not get_option("precompute_quadrature_info"):
+        return recursive_quadrature_weight()
+
+    dim = local_dimension()
+    num1d = quadrature_points_per_direction()
+    name = "quad_weights_dim{}_num{}".format(dim, num1d)
+
+    # Add a class member
+    loopy_class_member(name,
+                       dtype=np.float64,
+                       shape=(num1d,) * dim,
+                       classtag="operator",
+                       dim_tags=",".join(["c"] * dim),
+                       managed=True,
+                       potentially_vectorized=True,
+                       )
+
+    # Precompute it in the constructor
+    instruction(assignee=prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in constructor_quadrature_inames(name))),
+                expression=prim.Product(tuple(Subscript(Variable(name_oned_quadrature_weights()), (prim.Variable(i),)) for i in constructor_quadrature_inames(name))),
+                within_inames=frozenset(constructor_quadrature_inames(name)),
+                kernel="operator",
+                )
+
+    return prim.Subscript(prim.Variable(name), tuple(prim.Variable(i) for i in quadrature_inames()))
 
 
 def define_quadrature_position(name):
-- 
GitLab