From 25807a973a59e4fee926004f5d15c3052335fb54 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 8 Dec 2016 17:48:33 +0100
Subject: [PATCH] Shrink the quadrature loop on intersection by one iname

---
 python/dune/perftool/sumfact/basis.py      | 10 ++---
 python/dune/perftool/sumfact/quadrature.py |  8 ++--
 python/dune/perftool/sumfact/sumfact.py    | 43 +++++++++++++++-------
 3 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index b0f95ec1..7274e58a 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -19,7 +19,8 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
                                            name_theta,
                                            quadrature_points_per_direction,
                                            )
-from dune.perftool.sumfact.sumfact import (setup_theta,
+from dune.perftool.sumfact.sumfact import (get_facedir,
+                                           setup_theta,
                                            SumfactKernel,
                                            sumfact_iname,
                                            sum_factorization_kernel,
@@ -62,7 +63,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
     insn_dep = None
     for i in range(dim):
         # Construct the matrix sequence for this sum factorization
-        a_matrices = construct_amatrix_sequence(derivative=i)
+        a_matrices = construct_amatrix_sequence(derivative=i, face=get_facedir(restriction))
 
         # Get the vectorization info. If this happens during the dry run, we get dummies
         from dune.perftool.sumfact.vectorization import get_vectorization_info
@@ -126,7 +127,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
     dim = world_dimension()
 
     # Construct the matrix sequence for this sum factorization
-    a_matrices = construct_amatrix_sequence()
+    a_matrices = construct_amatrix_sequence(face=get_facedir(restriction))
 
     # Get the vectorization info. If this happens during the dry run, we get dummies
     from dune.perftool.sumfact.vectorization import get_vectorization_info
@@ -153,6 +154,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
                                       1,
                                       preferred_position=None,
                                       insn_dep=frozenset({Writes(input)}),
+                                      outshape=tuple(mat.rows for mat in a_matrices if mat.rows != 1),
                                       )
 
     if index:
@@ -230,9 +232,7 @@ def evaluate_reference_gradient(element, name, restriction):
         calls[i] = prim.Subscript(prim.Variable(dtheta), (prim.Variable(quad_inames[i]), prim.Variable(inames[i])))
         calls = tuple(calls)
 
-        # assignee = prim.Subscript(prim.Variable(name), tuple(prim.Variable(0)))
         assignee = prim.Subscript(prim.Variable(name), (i,))
-        # assignee = prim.Variable(name)
         expression = prim.Product(calls)
 
         instruction(assignee=assignee,
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index dbc870e1..08a5d7e5 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -71,15 +71,15 @@ def pymbolic_base_weight():
 
 
 @iname
-def sumfact_quad_iname(d, context):
-    name = "quad_{}_{}".format(context, d)
+def sumfact_quad_iname(d, bound):
+    name = "quad_{}".format(d)
     domain(name, quadrature_points_per_direction())
     return name
 
 
 @backend(interface="quad_inames", name="sumfact")
-def quadrature_inames(context=''):
-    return tuple(sumfact_quad_iname(d, context) for d in range(local_dimension()))
+def quadrature_inames():
+    return tuple(sumfact_quad_iname(d, quadrature_points_per_direction()) for d in range(local_dimension()))
 
 
 def define_recursive_quadrature_weight(name, dir):
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 9a29deb3..a4377817 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -34,6 +34,9 @@ from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.pdelab.spaces import (name_lfs,
                                          name_lfs_bound,
                                          )
+from dune.perftool.pdelab.geometry import (local_dimension,
+                                           world_dimension,
+                                           )
 from dune.perftool.sumfact.amatrix import (AMatrix,
                                            LargeAMatrix,
                                            quadrature_points_per_direction,
@@ -59,6 +62,17 @@ import pymbolic.primitives as prim
 from pytools import product
 
 
+def get_facedir(restriction):
+    from dune.perftool.pdelab.restriction import Restriction
+    if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet":
+        return get_global_context_value("facedir_s")
+    if restriction == Restriction.POSITIVE:
+        return get_global_context_value("facedir_n")
+    if restriction == Restriction.NONE:
+        return None
+    assert False
+
+
 @iname
 def _sumfact_iname(bound, _type, count):
     name = "sf_{}_{}".format(_type, str(count))
@@ -101,9 +115,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     if pymbolic_expr == 0:
         return
 
-    # Get geometric dimension
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
+    dim = world_dimension()
+    facedir = get_facedir(accterm.argument.restriction)
 
     # Collect buffers we need
     buffers = []
@@ -125,6 +138,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         # Construct the matrix sequence for this sum factorization
         a_matrices = construct_amatrix_sequence(transpose=True,
                                                 derivative=i if accterm.argument.index else None,
+                                                face=facedir,
                                                 )
 
         # Get the vectorization info. If this happens during the dry run, we get dummies
@@ -132,8 +146,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         a_matrices, buffer, input, index = get_vectorization_info(a_matrices)
 
         # Initialize a base storage for this buffer and get a temporay pointing to it
-        shape = tuple(mat.cols for mat in a_matrices)
-        dim_tags = ",".join(['f'] * dim)
+        shape = tuple(mat.cols for mat in a_matrices if mat.cols != 1)
+        dim_tags = ",".join(['f'] * local_dimension())
         if index is not None:
             shape = shape + (4,)
             dim_tags = dim_tags + ",c"
@@ -226,7 +240,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
 
 
 @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s))
-def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), additional_inames=frozenset({}), preferred_position=None):
+def sum_factorization_kernel(a_matrices, buf, stage,
+                             insn_dep=frozenset({}),
+                             additional_inames=frozenset({}),
+                             preferred_position=None,
+                             outshape=None):
     """
     Calculate a sum factorization matrix product.
 
@@ -310,19 +328,16 @@ def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), add
                                           )
                               })
 
-    # Get geometric dimension
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
-
-    out_shape = tuple(mat.rows for mat in a_matrices)
-    dim_tags = ",".join(['f'] * dim)
+    if outshape is None:
+        outshape = tuple(mat.rows for mat in a_matrices)
+    dim_tags = ",".join(['f'] * len(outshape))
 
     if next(iter(a_matrices)).vectorized:
-        out_shape = out_shape + vec_shape
+        outshape = outshape + vec_shape
         dim_tags = dim_tags + ",c"
 
     out = get_buffer_temporary(buf,
-                               shape=out_shape,
+                               shape=outshape,
                                dim_tags=dim_tags,
                                )
     silenced_warning('read_no_write({})'.format(out))
-- 
GitLab