From 5504e7d71e751ffcfe7beb09f547155e6b18c123 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 2 Dec 2016 16:53:12 +0100
Subject: [PATCH] Adjust infrastructure for preprocessing vectorization info
 collection

---
 python/dune/perftool/generation/__init__.py   |  4 +-
 python/dune/perftool/generation/counter.py    |  5 +
 python/dune/perftool/loopy/buffer.py          |  9 +-
 python/dune/perftool/loopy/symbolic.py        | 21 +---
 python/dune/perftool/pdelab/localoperator.py  | 20 +++-
 python/dune/perftool/sumfact/basis.py         | 61 +++++++++---
 python/dune/perftool/sumfact/sumfact.py       | 68 ++++++-------
 python/dune/perftool/sumfact/vectorization.py | 98 +++++++++++++++++++
 8 files changed, 209 insertions(+), 77 deletions(-)
 create mode 100644 python/dune/perftool/sumfact/vectorization.py

diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index 93d3981f..3e043645 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -4,7 +4,9 @@ from dune.perftool.generation.backend import (backend,
                                               get_backend,
                                               )
 
-from dune.perftool.generation.counter import get_counter
+from dune.perftool.generation.counter import (get_counter,
+                                              get_counted_variable,
+                                              )
 
 from dune.perftool.generation.cache import (cached,
                                             generator_factory,
diff --git a/python/dune/perftool/generation/counter.py b/python/dune/perftool/generation/counter.py
index 607d7eaa..77bc25b1 100644
--- a/python/dune/perftool/generation/counter.py
+++ b/python/dune/perftool/generation/counter.py
@@ -7,3 +7,8 @@ def get_counter(identifier):
     count = _counts.setdefault(identifier, 0)
     _counts[identifier] = _counts[identifier] + 1
     return count
+
+
+def get_counted_variable(identifier):
+    """ Return identifier, suffixed by '_' and its zero padded (4 digits) counter value """
+    return "{}_{}".format(identifier, str(get_counter(identifier)).zfill(4))
diff --git a/python/dune/perftool/loopy/buffer.py b/python/dune/perftool/loopy/buffer.py
index af0fc697..84076012 100644
--- a/python/dune/perftool/loopy/buffer.py
+++ b/python/dune/perftool/loopy/buffer.py
@@ -1,5 +1,6 @@
 from dune.perftool.error import PerftoolLoopyError
 from dune.perftool.generation import (generator_factory,
+                                      get_counted_variable,
                                       get_global_context_value,
                                       temporary_variable,
                                       )
@@ -14,9 +15,6 @@ class FlipFlopBuffer(object):
         # Initialize the counter that switches between the base storages!
         self._current = 0
 
-        # Initialize a total counter for the issued temporaries
-        self._counter = 0
-
         # Generate the base storage names
         self.base_storage = tuple("{}_base_{}".format(self.identifier, i) for i in range(self.num))
 
@@ -31,8 +29,9 @@ class FlipFlopBuffer(object):
         base = self.base_storage[self._current]
 
         # Construct a temporary name
-        name = "{}_{}".format(self.identifier, self._counter)
-        self._counter = self._counter + 1
+        name = kwargs.pop("name", None)
+        if name is None:  # no name requested by the caller: issue a fresh counted one
+            name = get_counted_variable(self.identifier)
 
         # Get geometric dimension
         formdata = get_global_context_value('formdata')
diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py
index 23415aac..1ba315a3 100644
--- a/python/dune/perftool/loopy/symbolic.py
+++ b/python/dune/perftool/loopy/symbolic.py
@@ -18,34 +18,23 @@ class SumfactKernel(prim.Variable):
     def __init__(self,
                  a_matrices,
                  buffer,
-                 stage=1,
-                 insn_dep=frozenset({}),
-                 additional_inames=frozenset({}),
-                 preferred_interleaving_position=0,
-                 setup_method=None,
-                 input_temporary=None,
+                 stage,
+                 preferred_position,
                  ):
         self.a_matrices = a_matrices
         self.buffer = buffer
         self.stage = stage
-        self.insn_dep = insn_dep
-        self.additional_inames = additional_inames
-        self.preferred_interleaving_position = preferred_interleaving_position
-        self.setup_method = setup_method
-        self.input_temporary = input_temporary
-
-        if setup_method is not None:
-            assert isinstance(setup_method, tuple) and len(setup_method) == 2
+        self.preferred_position = preferred_position
 
         prim.Variable.__init__(self, "SUMFACT")
 
     def __getinitargs__(self):
-        return (self.a_matrices, self.buffer, self.stage, self.insn_dep, self.additional_inames, self.preferred_interleaving_position, self.setup_method, self.input_temporary)
+        return (self.a_matrices, self.buffer, self.stage, self.preferred_position)
 
     def stringifier(self):
         return lp.symbolic.StringifyMapper
 
-    init_arg_names = ("a_matrices", "buffer", "stage", "insn_dep", "additional_inames", "preferred_interleaving_position", "setup_method", "input_temporary")
+    init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position")
 
     mapper_method = "map_sumfact_kernel"
 
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 8610952b..467a8115 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -423,7 +423,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                 )
 
 
-def generate_kernel(integrals):
+def visit_integrals(integrals):
     for integral in integrals:
         integrand = integral.integrand()
         measure = integral.integral_type()
@@ -477,11 +477,25 @@ def generate_kernel(integrals):
             visitor = UFL2LoopyVisitor(interface, measure, indexmap)
             get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
 
-    knl = extract_kernel_from_cache("kernel_default")
 
-    # All items with the kernel tags can be destroyed once a kernel has been generated
+def generate_kernel(integrals):
+    """ Generate the kernel for the given integrals in two visits: a dry run and the real one """
+    with global_context(dry_run=True):  # first visit: only collects information
+        visit_integrals(integrals)
+
+    # Decide on the vectorization strategy from the instructions the dry run left in the cache
+    from dune.perftool.sumfact.vectorization import decide_vectorization_strategy
+    decide_vectorization_strategy()
+
+    # Delete the cache contents and do the real thing!
     from dune.perftool.generation import delete_cache_items
     delete_cache_items("kernel_default")
+    visit_integrals(integrals)
+    knl = extract_kernel_from_cache("kernel_default")
+    delete_cache_items("kernel_default")
+
+    # Clean the cache from the data that was stored for the real visit after the dry run
+    delete_cache_items("dryrundata")
 
     return knl
 
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 390acb59..fb3c3ff3 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -21,6 +21,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
 from dune.perftool.sumfact.sumfact import (setup_theta,
                                            SumfactKernel,
                                            sumfact_iname,
+                                           sum_factorization_kernel,
                                            )
 from dune.perftool.sumfact.quadrature import quadrature_inames
 from dune.perftool.loopy.buffer import initialize_buffer
@@ -29,6 +30,8 @@ from dune.perftool.pdelab.restriction import restricted_name
 
 from pytools import product
 
+from loopy.match import Writes
+
 import pymbolic.primitives as prim
 
 
@@ -66,23 +69,37 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
         a_matrices[i] = dtheta_matrix
         a_matrices = tuple(a_matrices)
 
-        buffer_name = name_sumfact_base_buffer()
-        initialize_buffer(buffer_name,
+        # Get the vectorization info. If this happens during the dry run, we get dummies
+        from dune.perftool.sumfact.vectorization import get_vectorization_info
+        a_matrices, buffer, input, index = get_vectorization_info(a_matrices)
+
+        # Initialize the buffer for the sum fact kernel
+        shape = (product(mat.cols for mat in a_matrices),)
+        if index is not None:  # position 0 is a valid index, so truthiness must not be tested
+            shape = shape + (4,)
+        initialize_buffer(buffer,
                           base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
                           num=2
-                          )
+                          ).get_temporary(shape=shape,
+                                          name=input,
+                                          )
+
+        # Setup the input!
+        setup_theta(input, element, restriction, component, index)
 
         # Add a sum factorization kernel that implements the
         # evaluation of the gradients of basis functions at quadrature
         # points (stage 1)
-        var = SumfactKernel(a_matrices,
-                            buffer_name,
-                            preferred_interleaving_position=i,
-                            setup_method=(setup_theta, (element, restriction, component))
-                            )
+        var, _ = sum_factorization_kernel(a_matrices,
+                                          buffer,
+                                          1,
+                                          preferred_position=i,
+                                          insn_dep=frozenset({Writes(input)}),
+                                          )
 
         buffers.append(var)
 
+    # TODO this should be quite conditional!!!
     for i, buf in enumerate(buffers):
         # Write solution from sumfactorization to gradient variable
         from pymbolic.primitives import Subscript, Variable
@@ -116,20 +133,32 @@ def pymbolic_trialfunction(element, restriction, component):
     a_matrix = AMatrix(rows, cols)
     a_matrices = (a_matrix,) * dim
 
+    # Get the vectorization info. If this happens during the dry run, we get dummies
+    from dune.perftool.sumfact.vectorization import get_vectorization_info
+    a_matrices, buffer, input, index = get_vectorization_info(a_matrices)
+
     # Flip flop buffers for sumfactorization
-    buffer_name = name_sumfact_base_buffer()
-    initialize_buffer(buffer_name,
+    shape = (product(mat.cols for mat in a_matrices),)
+    if index is not None:  # a position was assigned, the input gets a trailing dimension
+        shape = shape + (4,)
+    initialize_buffer(buffer,
                       base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
                       num=2
-                      )
+                      ).get_temporary(shape=shape,
+                                      name=input,
+                                      )
+
+    # Setup the input!
+    setup_theta(input, element, restriction, component, index)
 
     # Add a sum factorization kernel that implements the evaluation of
     # the basis functions at quadrature points (stage 1)
-    var = SumfactKernel(a_matrices,
-                        buffer_name,
-                        preferred_interleaving_position=dim,
-                        setup_method=(setup_theta, (element, restriction, component))
-                        )
+    var, _ = sum_factorization_kernel(a_matrices,
+                                      buffer,  # the buffer that was initialized above
+                                      1,
+                                      preferred_position=None,
+                                      insn_dep=frozenset({Writes(input)}),
+                                      )
 
     return prim.Subscript(var,
                           tuple(prim.Variable(i) for i in quadrature_inames())
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index c6bc6254..6cf4bd7b 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -49,6 +49,7 @@ from pymbolic.primitives import (Call,
                                  Variable,
                                  )
 from dune.perftool.sumfact.quadrature import quadrature_inames
+from dune.perftool.sumfact.vectorization import find_sumfact
 from loopy import Reduction, GlobalArg
 from loopy.symbolic import FunctionIdentifier, IdentityMapper
 
@@ -58,27 +59,6 @@ import pymbolic.primitives as prim
 from pytools import product
 
 
-class HasSumfactMapper(lp.symbolic.CombineMapper):
-    def combine(self, *args):
-        return frozenset().union(*tuple(*args))
-
-    def map_constant(self, expr):
-        return frozenset()
-
-    def map_algebraic_leaf(self, expr):
-        return frozenset()
-
-    def map_loopy_function_identifier(self, expr):
-        return frozenset()
-
-    def map_sumfact_kernel(self, expr):
-        return frozenset({expr})
-
-
-def find_sumfact(expr):
-    return HasSumfactMapper()(expr)
-
-
 class IndexFiddleMapper(IdentityMapper):
     def __init__(self, var, index, pos):
         assert isinstance(var, str)
@@ -264,13 +244,17 @@ def sumfact_iname(bound, _type):
     return name
 
 
-def setup_theta(inp, element, restriction, component, additional_indices=()):
+def setup_theta(inp, element, restriction, component, index):
+    # Turn the optional position into a (possibly empty) tuple of indices,
+    # which is appended to the basis iname in the subscript of the assignee.
+    # None stands for "no additional index".
+    index = () if index is None else (index,)
     # Write initial coefficients into buffer
     lfs = name_lfs(element, restriction, component)
     basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
     container = name_coefficientcontainer(restriction)
     coeff = pymbolic_coefficient(container, lfs, basisiname)
-    assignee = Subscript(Variable(inp), (Variable(basisiname),) + additional_indices)
+    assignee = Subscript(Variable(inp), (Variable(basisiname),) + index)
     return instruction(assignee=assignee,
                        expression=coeff,
                        )
@@ -330,14 +314,23 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
             pref_pos = i
         else:
             a_matrices = (theta_matrix,) * dim
-            pref_pos = dim
+            pref_pos = None
+
+        # Get the vectorization info. If this happens during the dry run, we get dummies
+        from dune.perftool.sumfact.vectorization import get_vectorization_info
+        a_matrices, buffer, input, index = get_vectorization_info(a_matrices)
 
         # Initialize a base storage for this buffer and get a temporay pointing to it
-        temp = initialize_buffer(buf,
+#        shape = product(mat.cols for mat in a_matrices)
+#        if vec:
+#            shape = (shape, 4)
+        temp = initialize_buffer(buffer,
                                  base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
                                  num=2
                                  ).get_temporary(shape=(quadrature_points_per_direction(),) * dim,
-                                                 dim_tags=",".join(['f'] * dim))
+                                                 dim_tags=",".join(['f'] * dim),
+                                                 name=input,
+                                                 )
 
         # Replace gradient iname with correct index for assignement
         replace_dict = {}
@@ -365,14 +358,13 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
 
         # Add a sum factorization kernel that implements the multiplication
         # with the test function (stage 3)
-        result = SumfactKernel(a_matrices,
-                               buf,
-                               insn_dep=frozenset({contrib_dep}),
-                               additional_inames=frozenset(visitor.inames),
-                               stage=3,
-                               preferred_interleaving_position=pref_pos,
-                               input_temporary=temp,
-                               )
+        result, insn_dep = sum_factorization_kernel(a_matrices,
+                                                    buffer,
+                                                    3,
+                                                    insn_dep=frozenset({contrib_dep}),
+                                                    additional_inames=frozenset(visitor.inames),
+                                                    preferred_position=pref_pos,
+                                                    )
 
         inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices)
 
@@ -404,13 +396,14 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                     expression=expr,
                     forced_iname_deps=frozenset(inames + visitor.inames),
                     forced_iname_deps_is_final=True,
+                    depends_on=insn_dep,
                     )
 
         # Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
         transform(nest_quadrature_loops, visitor.inames)
 
 
-def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional_inames=frozenset({}), add_vec_tag=False):
+def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), additional_inames=frozenset({}), add_vec_tag=False, preferred_position=None):
     """
     Calculate a sum factorization matrix product.
 
@@ -430,6 +423,9 @@ def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional
         should depend upon. All following ones will depend on each
         other.
     """
+    if get_global_context_value("dry_run", False):
+        return SumfactKernel(a_matrices, buf, stage, preferred_position), frozenset()
+
     ftags = "f,f" + (",vec" if add_vec_tag else "")
     ctags = "c,c" + (",vec" if add_vec_tag else "")
     vec_shape = (4,) if add_vec_tag else ()
@@ -508,4 +504,4 @@ def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional
                                dim_tags=dim_tags)
     silenced_warning('read_no_write({})'.format(out))
 
-    return out, insn_dep
+    return prim.Variable(out), insn_dep
diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
new file mode 100644
index 00000000..9e4f2eb8
--- /dev/null
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -0,0 +1,126 @@
+""" Sum factorization vectorization """
+
+from dune.perftool.generation import (generator_factory,
+                                      get_counted_variable,
+                                      )
+from dune.perftool.error import PerftoolError
+from dune.perftool.options import get_option
+
+import loopy as lp
+
+
+@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda a, *args: a)
+def vectorization_info(a_matrices, buffer, input, index):
+    """ Store the vectorization data of one sum factorization kernel
+
+    The cache key is a_matrices alone (see cache_key_generator). Returns
+    the tuple (a_matrices, buffer, input, index).
+    """
+    return (a_matrices, buffer, input, index)
+
+
+def get_vectorization_info(a_matrices):
+    """ Return the tuple (a_matrices, buffer, input, index) for a kernel
+
+    During the dry run, freshly counted dummy names and the index None are
+    returned. Otherwise the data stored for a_matrices is looked up and a
+    PerftoolError is raised if there is none.
+    """
+    from dune.perftool.generation import get_global_context_value
+    # Default to False like sum_factorization_kernel does: the real visit
+    # of the integrals happens outside of the dry run context.
+    if get_global_context_value("dry_run", False):
+        # Return dummy data
+        return (a_matrices, get_counted_variable("buffer"), get_counted_variable("input"), None)
+    try:
+        info = vectorization_info(a_matrices, None, None, None)
+    except TypeError:
+        raise PerftoolError("Sumfact Vectorization data should have been collected in dry run, but wasnt")
+    # Stored data always names a buffer, None means the lookup missed
+    if info[1] is None:
+        raise PerftoolError("Sumfact Vectorization data should have been collected in dry run, but wasnt")
+    return info
+
+
+def no_vectorization(sumfacts):
+    """ Store non-vectorized data (own buffer and input, no index) for each kernel """
+    for sumf in sumfacts:
+        vectorization_info(sumf.a_matrices, get_counted_variable("buffer"), get_counted_variable("input"), None)
+
+
+def decide_stage_vectorization_strategy(sumfacts, stage):
+    """ Decide on the vectorization of the kernels of one stage
+
+    If the stage has 3 or 4 kernels, they share one joined buffer and input
+    and each of them gets its own position in range(4). Otherwise all
+    kernels of the stage get non-vectorized data.
+    """
+    stage_sumfacts = frozenset(sf for sf in sumfacts if sf.stage == stage)
+    if len(stage_sumfacts) not in (3, 4):
+        # Disable vectorization strategy
+        no_vectorization(stage_sumfacts)
+        return
+
+    # Map the sum factorization to their position in the joint kernel
+    available = set(range(4))
+    for sf in stage_sumfacts:
+        if sf.preferred_position is not None:
+            # No two kernels may take the same position. This is raised instead
+            # of asserted, so that the check is not stripped by python -O.
+            # Later on, more complicated stuff might be necessary here.
+            if sf.preferred_position not in available:
+                raise PerftoolError("Sumfact vectorization position {} is taken or invalid".format(sf.preferred_position))
+            available.discard(sf.preferred_position)
+
+    # Enable vectorization strategy:
+    joined_input = get_counted_variable("joined_input")
+    joined_buffer = get_counted_variable("joined_buffer")
+
+    for sumf in stage_sumfacts:
+        pref_pos = sumf.preferred_position
+        if pref_pos is None:
+            pref_pos = available.pop()
+        vectorization_info(sumf.a_matrices, joined_buffer, joined_input, pref_pos)
+
+
+def decide_vectorization_strategy():
+    """ Decide how to vectorize!
+    Note that the vectorization of the quadrature loop is independent of this,
+    as it is implemented through a post-processing (== loopy transformation) step.
+    """
+    from dune.perftool.generation import retrieve_cache_items
+
+    # Find all sum factorization kernels in the cached instructions
+    sumfacts = frozenset()
+    for insn in retrieve_cache_items("kernel_default and instruction"):
+        if isinstance(insn, (lp.Assignment, lp.CallInstruction)):
+            sumfacts = sumfacts.union(find_sumfact(insn.expression))
+
+    if not get_option("vectorize_grads"):
+        no_vectorization(sumfacts)
+    else:
+        decide_stage_vectorization_strategy(sumfacts, 1)
+        decide_stage_vectorization_strategy(sumfacts, 3)
+
+
+class HasSumfactMapper(lp.symbolic.CombineMapper):
+    """ Collect the SumfactKernel nodes of an expression into a frozenset """
+    def combine(self, *args):
+        return frozenset().union(*tuple(*args))
+
+    def map_constant(self, expr):
+        return frozenset()
+
+    def map_algebraic_leaf(self, expr):
+        return frozenset()
+
+    def map_loopy_function_identifier(self, expr):
+        return frozenset()
+
+    def map_sumfact_kernel(self, expr):
+        return frozenset({expr})
+
+
+def find_sumfact(expr):
+    """ Return the frozenset of all SumfactKernel nodes found in expr """
+    return HasSumfactMapper()(expr)
-- 
GitLab