From 8881feb64a94669512c0e641ca2745202913f0f9 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 31 Mar 2017 14:02:24 +0200
Subject: [PATCH] Cleanup stage 1 calling scope

---
 python/dune/perftool/sumfact/basis.py       | 99 ++++-----------------
 python/dune/perftool/sumfact/realization.py |  7 +-
 python/dune/perftool/sumfact/sumfact.py     |  4 +-
 3 files changed, 25 insertions(+), 85 deletions(-)

diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 5ed69496..892cfd12 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -6,6 +6,7 @@ multiplication with the test function is part of the sum factorization kernel.
 
 from dune.perftool.generation import (backend,
                                       domain,
+                                      get_backend,
                                       get_counted_variable,
                                       get_counter,
                                       get_global_context_value,
@@ -65,15 +66,19 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
     shape_impl = ('arr',) * rank
     temporary_variable(name, shape=shape, shape_impl=shape_impl)
 
-    dim = world_dimension()
+    # Whether direct indexing into the output is possible. This happens
+    # if the position within a SIMD vector coincides with the index!
+    direct_indexing_is_possible = True
+
     buffers = []
-    for i in range(dim):
+    for i in range(world_dimension()):
         # Construct the matrix sequence for this sum factorization
         a_matrices = construct_amatrix_sequence(derivative=i,
                                                 facedir=get_facedir(restriction),
                                                 facemod=get_facemod(restriction),
                                                 )
 
+        # The sum factorization kernel object gathering all relevant information
         sf = SumfactKernel(a_matrices=a_matrices,
                            restriction=restriction,
                            preferred_position=i,
@@ -85,43 +90,12 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
         from dune.perftool.sumfact.vectorization import attach_vectorization_info
         sf = attach_vectorization_info(sf)
 
-        # Extract again, for compatibility
-        # TODO away!
-        a_matrices = sf.a_matrices
-        buf = sf.buffer
-        inp = sf.input
-        index = sf.index
-        padding = sf.padding
-
-#        if buf is None:
-#            buf = get_counted_variable("buffer")
-#        if inp is None:
-#            inp = get_counted_variable("input")
-#
-#        # Initialize the buffer for the sum fact kernel
-#        shape = (product(mat.cols for mat in a_matrices),)
-#        if index is not None:
-#            shape = shape + (4,)
-#        inp = initialize_buffer(buf).get_temporary(shape=shape,
-#                                                   name=inp,
-#                                                   )
-#        insn_dep = frozenset({Writes(inp)})
-
-        if get_option('fastdg'):
-            # Name of direct input, shape and globalarg is set in sum_factorization_kernel
-            direct_input = coeff_func(restriction)
-        else:
-            direct_input = None
-            # Setup the input!
-            #setup_theta(inp, element, restriction, component, index, coeff_func)
-
         # Add a sum factorization kernel that implements the
         # evaluation of the gradients of basis functions at quadrature
         # points (stage 1)
         from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
         var, insn_dep = realize_sum_factorization_kernel(sf,
-                                                 outshape=tuple(mat.rows for mat in a_matrices if mat.face is None),
-                                                 direct_input=direct_input,
+                                                 outshape=tuple(mat.rows for mat in sf.a_matrices if mat.face is None),
                                                  )
 
         buffers.append(var)
@@ -129,17 +103,15 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
     # Check whether we want to return early with something that has the indexing
     # already handled! This happens with vectorization when the index coincides
     # with the position in the vector register.
-    if index:
+    if direct_indexing_is_possible:
         assert len(visitor.indices) == 1
         return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + visitor.indices), None
 
     # TODO this should be quite conditional!!!
     for i, buf in enumerate(buffers):
         # Write solution from sumfactorization to gradient variable
-        from pymbolic.primitives import Subscript, Variable
-        from dune.perftool.generation import get_backend
-        assignee = Subscript(Variable(name), i)
-        expression = Subscript(buf, tuple(Variable(i) for i in quadrature_inames()))
+        assignee = prim.Subscript(prim.Variable(name), i)
+        expression = prim.Subscript(buf, tuple(prim.Variable(i) for i in quadrature_inames()))
         instruction(assignee=assignee,
                     expression=expression,
                     forced_iname_deps=frozenset(get_backend("quad_inames")()),
@@ -151,9 +123,6 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
 
 @kernel_cached
 def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
-    # Get geometric dimension
-    dim = world_dimension()
-
     # Construct the matrix sequence for this sum factorization
     a_matrices = construct_amatrix_sequence(facedir=get_facedir(restriction),
                                             facemod=get_facemod(restriction),)
@@ -169,49 +138,15 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
     from dune.perftool.sumfact.vectorization import attach_vectorization_info
     sf = attach_vectorization_info(sf)
 
-    # Extract again, for compatibility
-    # TODO away!
-    a_matrices = sf.a_matrices
-    buf = sf.buffer
-    inp = sf.input
-    index = sf.index
-    padding = sf.padding
-#
-#    if buf is None:
-#        buf = get_counted_variable("buffer")
-#    if inp is None:
-#        inp = get_counted_variable("input")
-#
-#    # Flip flop buffers for sumfactorization
-#    shape = (product(mat.cols for mat in a_matrices),)
-#    if index is not None:
-#        shape = shape + (4,)
-#    initialize_buffer(buf).get_temporary(shape=shape,
-#                                      name=inp,
-#                                      )
-
-    if get_option('fastdg'):
-        # Name of direct input, shape and globalarg is set in sum_factorization_kernel
-        direct_input = coeff_func(restriction)
-    else:
-        direct_input = None
-        # Setup the input!
-        # setup_theta(inp, element, restriction, component, index, coeff_func)
-
     # Add a sum factorization kernel that implements the evaluation of
     # the basis functions at quadrature points (stage 1)
-    if not get_global_context_value("dry_run", False):
-        from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
-        var, _ = realize_sum_factorization_kernel(sf,
-#                                                  insn_dep=frozenset({Writes(inp)}),
-                                                  outshape=tuple(mat.rows for mat in a_matrices if mat.face is None),
-                                                  direct_input=direct_input,
-                                                  )
-    else:
-        var = sf
+    from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
+    var, _ = realize_sum_factorization_kernel(sf,
+                                              outshape=tuple(mat.rows for mat in sf.a_matrices if mat.face is None),
+                                              )
 
-    if index:
-        index = (index,)
+    if sf.index:
+        index = (sf.index,)
     else:
         index = ()
 
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 2d789802..0ed1641e 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -69,7 +69,7 @@ def _realize_input(sf, insn_dep):
 @generator_factory(item_tags=("sumfactkernel",),
                    context_tags=("kernel",),
                    cache_key_generator=lambda s, **kw: s.cache_key)
-def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, direct_input=None, direct_output=None):
+def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, direct_output=None):
     # Unify the insn_dep parameter to be a frozenset
     if isinstance(insn_dep, str):
         insn_dep = frozenset({insn_dep})
@@ -91,6 +91,11 @@ def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, d
     if sf.input:
         insn_dep = insn_dep.union(frozenset({lp.match.Writes(sf.input)}))
 
+    # Construct the direct_input for the FastDG case
+    direct_input = None
+    if get_option('fastdg') and sf.stage == 1:
+        direct_input = sf.coeff_func(sf.restriction)
+
     # Prepare some dim_tags/shapes for later use
     ftags = ",".join(["f"] * sf.length)
     novec_ftags = ftags
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 50ac3c02..6106246c 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -172,7 +172,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
             inp = get_counted_variable("input")
 
         # Initialize a base storage for this buffer and get a temporay pointing to it
-        shape = tuple(mat.cols for mat in a_matrices if mat.face is None)
+        shape = tuple(mat.cols for mat in sf.a_matrices if mat.face is None)
         dim_tags = ",".join(['f'] * local_dimension())
         if index is not None:
             shape = shape + (4,)
@@ -242,7 +242,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                                               within_inames=frozenset(visitor.inames))})
 
         inames = tuple(accum_iname((accterm.argument.restriction, restriction), mat.rows, i)
-                       for i, mat in enumerate(a_matrices))
+                       for i, mat in enumerate(sf.a_matrices))
 
         # Collect the lfs and lfs indices for the accumulate call
         test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure)
-- 
GitLab