From 733675a93953c8fa935c45ca12dcbc8a9cdece79 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 30 Mar 2017 15:11:58 +0200
Subject: [PATCH] WIP checkpoint 1

---
 python/dune/perftool/loopy/symbolic.py        | 64 +++++++++++++++----
 python/dune/perftool/sumfact/amatrix.py       | 36 ++++++-----
 python/dune/perftool/sumfact/basis.py         | 58 ++++++++++++++---
 python/dune/perftool/sumfact/sumfact.py       | 32 ++++++++--
 python/dune/perftool/sumfact/vectorization.py | 48 +++++++-------
 5 files changed, 172 insertions(+), 66 deletions(-)

diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py
index 7efae2d5..bfca31d6 100644
--- a/python/dune/perftool/loopy/symbolic.py
+++ b/python/dune/perftool/loopy/symbolic.py
@@ -7,6 +7,7 @@ from pymbolic.mapper.substitutor import make_subst_func
 
 import loopy as lp
 import pymbolic.primitives as prim
+from pytools import ImmutableRecord
 from six.moves import intern
 
 
@@ -15,35 +16,72 @@ from six.moves import intern
 #
 
 
-class SumfactKernel(prim.Variable):
+class SumfactKernel(ImmutableRecord, prim.Variable):
     def __init__(self,
-                 a_matrices,
-                 buffer,
-                 stage,
-                 preferred_position,
-                 restriction,
+                 a_matrices=None,
+                 buffer=None,
+                 stage=1,
+                 preferred_position=None,
+                 restriction=0,
+                 within_inames=frozenset(),
+                 input=None,
+                 padding=frozenset(),
+                 index=None,
                  ):
+        # Check the input and apply defaults where necessary
+        assert isinstance(a_matrices, tuple)
+        from dune.perftool.sumfact.amatrix import AMatrixBase
+        assert all(isinstance(m, AMatrixBase) for m in a_matrices)
+
+        assert stage in (1, 3)
+
+        if preferred_position is not None:
+            assert isinstance(preferred_position, int)
+
         if not isinstance(restriction, tuple):
             restriction = (restriction, 0)
 
-        self.a_matrices = a_matrices
-        self.buffer = buffer
-        self.stage = stage
-        self.preferred_position = preferred_position
-        self.restriction = restriction
+        assert isinstance(within_inames, frozenset)
+
+        ImmutableRecord.__init__(self,
+                                 a_matrices=a_matrices,
+                                 buffer=buffer,
+                                 stage=stage,
+                                 preferred_position=preferred_position,
+                                 restriction=restriction,
+                                 within_inames=within_inames,
+                                 input=input,
+                                 padding=padding,
+                                 index=index,
+                                 )
 
         prim.Variable.__init__(self, "SUMFACT")
 
+    #
+    # The methods/fields needed to get a well-formed pymbolic node
+    #
     def __getinitargs__(self):
-        return (self.a_matrices, self.buffer, self.stage, self.preferred_position, self.restriction)
+        return (self.a_matrices, self.buffer, self.stage, self.preferred_position, self.restriction, self.within_inames, self.input, self.padding, self.index)
 
     def stringifier(self):
         return lp.symbolic.StringifyMapper
 
-    init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position", "restriction")
+    init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position", "restriction", "within_inames", "input", "padding", "index")
 
     mapper_method = "map_sumfact_kernel"
 
+    #
+    # Some convenience methods to extract information about the sum factorization kernel
+    #
+
+    @property
+    def length(self):
+        return len(self.a_matrices)
+
+    @property
+    def vectorized(self):
+        return next(iter(a_matrices)).vectorized
+
 
 class FusedMultiplyAdd(prim.Expression):
     """ Represents an FMA operation """
diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py
index 16f8b528..5037c426 100644
--- a/python/dune/perftool/sumfact/amatrix.py
+++ b/python/dune/perftool/sumfact/amatrix.py
@@ -35,15 +35,19 @@ import loopy as lp
 import numpy
 
 
-class AMatrix(ImmutableRecord):
+class AMatrixBase(ImmutableRecord):
+    pass
+
+
+class AMatrix(AMatrixBase):
     def __init__(self, rows, cols, transpose=False, derivative=False, face=None):
-        ImmutableRecord.__init__(self,
-                                 rows=rows,
-                                 cols=cols,
-                                 transpose=transpose,
-                                 derivative=derivative,
-                                 face=face,
-                                 )
+        AMatrixBase.__init__(self,
+                             rows=rows,
+                             cols=cols,
+                             transpose=transpose,
+                             derivative=derivative,
+                             face=face,
+                             )
 
     @property
     def name(self):
@@ -57,16 +61,16 @@ class AMatrix(ImmutableRecord):
         return lp.TaggedVariable(name, "sumfac")
 
 
-class LargeAMatrix(ImmutableRecord):
+class LargeAMatrix(AMatrixBase):
     def __init__(self, rows, cols, transpose, derivative, face):
         assert isinstance(derivative, tuple)
-        ImmutableRecord.__init__(self,
-                                 rows=rows,
-                                 cols=cols,
-                                 transpose=transpose,
-                                 derivative=derivative,
-                                 face=face,
-                                 )
+        AMatrixBase.__init__(self,
+                             rows=rows,
+                             cols=cols,
+                             transpose=transpose,
+                             derivative=derivative,
+                             face=face,
+                             )
 
     @property
     def name(self):
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index e9b5a6ce..9f9e516c 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -6,6 +6,7 @@ multiplication with the test function is part of the sum factorization kernel.
 
 from dune.perftool.generation import (backend,
                                       domain,
+                                      get_counted_variable,
                                       get_counter,
                                       get_global_context_value,
                                       iname,
@@ -74,9 +75,26 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
                                                 facemod=get_facemod(restriction),
                                                 )
 
-        # Get the vectorization info. If this happens during the dry run, we get dummies
-        from dune.perftool.sumfact.vectorization import get_vectorization_info
-        a_matrices, buf, inp, index, padding = get_vectorization_info(a_matrices, restriction)
+        sf = SumfactKernel(a_matrices=a_matrices,
+                           restriction=restriction,
+                           preferred_position=i,
+                           )
+
+        from dune.perftool.sumfact.vectorization import attach_vectorization_info
+        sf = attach_vectorization_info(sf)
+
+        # Extract again, for compatibility
+        # TODO away!
+        a_matrices = sf.a_matrices
+        buf = sf.buffer
+        inp = sf.input
+        index = sf.index
+        padding = sf.padding
+
+        if buf is None:
+            buf = get_counted_variable("buffer")
+        if inp is None:
+            inp = get_counted_variable("input")
 
         # Initialize the buffer for the sum fact kernel
         shape = (product(mat.cols for mat in a_matrices),)
@@ -101,7 +119,8 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
         # Add a sum factorization kernel that implements the
         # evaluation of the gradients of basis functions at quadrature
         # points (stage 1)
-        var, insn_dep = sum_factorization_kernel(a_matrices,
+        if not get_global_context_value("dry_run", False):
+            var, insn_dep = sum_factorization_kernel(a_matrices,
                                                  buf,
                                                  1,
                                                  preferred_position=i,
@@ -110,6 +129,9 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
                                                  outshape=tuple(mat.rows for mat in a_matrices if mat.face is None),
                                                  direct_input=direct_input,
                                                  )
+        else:
+            var = sf
+
         buffers.append(var)
 
     # Check whether we want to return early with something that has the indexing
@@ -144,9 +166,26 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
     a_matrices = construct_amatrix_sequence(facedir=get_facedir(restriction),
                                             facemod=get_facemod(restriction),)
 
-    # Get the vectorization info. If this happens during the dry run, we get dummies
-    from dune.perftool.sumfact.vectorization import get_vectorization_info
-    a_matrices, buf, inp, index, padding = get_vectorization_info(a_matrices, restriction)
+    sf = SumfactKernel(a_matrices=a_matrices,
+                       restriction=restriction,
+                       )
+
+    # TODO: Move this away!
+    from dune.perftool.sumfact.vectorization import attach_vectorization_info
+    sf = attach_vectorization_info(sf)
+
+    # Extract again, for compatibility
+    # TODO away!
+    a_matrices = sf.a_matrices
+    buf = sf.buffer
+    inp = sf.input
+    index = sf.index
+    padding = sf.padding
+
+    if buf is None:
+        buf = get_counted_variable("buffer")
+    if inp is None:
+        inp = get_counted_variable("input")
 
     # Flip flop buffers for sumfactorization
     shape = (product(mat.cols for mat in a_matrices),)
@@ -169,7 +208,8 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
 
     # Add a sum factorization kernel that implements the evaluation of
     # the basis functions at quadrature points (stage 1)
-    var, _ = sum_factorization_kernel(a_matrices,
+    if not get_global_context_value("dry_run", False):
+        var, _ = sum_factorization_kernel(a_matrices,
                                       buf,
                                       1,
                                       preferred_position=None,
@@ -178,6 +218,8 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
                                       restriction=restriction,
                                       direct_input=direct_input,
                                       )
+    else:
+        var = sf
 
     if index:
         index = (index,)
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 0dabe93a..95fceb08 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -14,6 +14,7 @@ from dune.perftool.generation import (backend,
                                       dump_accumulate_timer,
                                       function_mangler,
                                       generator_factory,
+                                      get_counted_variable,
                                       get_counter,
                                       get_global_context_value,
                                       globalarg,
@@ -143,9 +144,28 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                                                 facemod=get_facemod(accterm.argument.restriction),
                                                 )
 
-        # Get the vectorization info. If this happens during the dry run, we get dummies
-        from dune.perftool.sumfact.vectorization import get_vectorization_info
-        a_matrices, buf, inp, index, padding = get_vectorization_info(a_matrices, (accterm.argument.restriction, restriction))
+        sf = SumfactKernel(a_matrices=a_matrices,
+                           restriction=(accterm.argument.restriction, restriction),
+                           stage=3,
+                           preferred_position=i if accterm.new_indices else None
+                           )
+
+        # TODO: Move this away!
+        from dune.perftool.sumfact.vectorization import attach_vectorization_info
+        sf = attach_vectorization_info(sf)
+
+        # Extract again, for compatibility
+        # TODO away!
+        a_matrices = sf.a_matrices
+        buf = sf.buffer
+        inp = sf.input
+        index = sf.index
+        padding = sf.padding
+
+        if buf is None:
+            buf = get_counted_variable("buffer")
+        if inp is None:
+            inp = get_counted_variable("input")
 
         # Initialize a base storage for this buffer and get a temporay pointing to it
         shape = tuple(mat.cols for mat in a_matrices if mat.face is None)
@@ -262,7 +282,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         # Add a sum factorization kernel that implements the multiplication
         # with the test function (stage 3)
         pref_pos = i if accterm.new_indices else None
-        result, insn_dep = sum_factorization_kernel(a_matrices,
+        if not get_global_context_value("dry_run", False):
+            result, insn_dep = sum_factorization_kernel(a_matrices,
                                                     buf,
                                                     3,
                                                     insn_dep=insn_dep,
@@ -273,7 +294,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                                                     direct_output=direct_output,
                                                     visitor=visitor
                                                     )
-
+        else:
+            result = sf
         # Determine the expression to accumulate with. This depends on the vectorization strategy!
         result = prim.Subscript(result, tuple(prim.Variable(i) for i in inames))
         vecinames = ()
diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
index 0adcc6ab..c7662135 100644
--- a/python/dune/perftool/sumfact/vectorization.py
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -1,7 +1,9 @@
 """ Sum factorization vectorization """
 
+from dune.perftool.loopy.symbolic import SumfactKernel
 from dune.perftool.generation import (generator_factory,
                                       get_counted_variable,
+                                      get_global_context_value,
                                       )
 from dune.perftool.pdelab.restriction import (Restriction,
                                               restricted_name,
@@ -12,35 +14,27 @@ from dune.perftool.options import get_option
 import loopy as lp
 
 
-@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda a, r, *args: (a, r))
-def vectorization_info(a_matrices, restriction, new_a_matrices, buf, inp, index, padding):
-    return (new_a_matrices, buf, inp, index, padding)
+@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o)
+def _cache_vectorization_info(old, new):
+    if new is None:
+        raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!")
+    else:
+        print("Registering old: {}".format(repr(old)))
+        print("Registering new: {}".format(repr(new)))
+    return new
 
 
-def get_vectorization_info(a_matrices, restriction):
-    if not isinstance(restriction, tuple):
-        restriction = (restriction, 0)
-    from dune.perftool.generation import get_global_context_value
+def attach_vectorization_info(sf):
+    assert isinstance(sf, SumfactKernel)
     if get_global_context_value("dry_run"):
-        # Return dummy data
-        return (a_matrices, get_counted_variable("buffer"), get_counted_variable("input"), None, frozenset())
-
-    # Try getting the vectorization info collected after dry run
-    vec = vectorization_info(a_matrices, restriction, None, None, None, None, None)
-    if vec[0] is None:
-        raise PerftoolError("Sumfact Vectorization data should have been collected in dry run, but wasnt")
-    return vec
+        return sf
+    else:
+        return _cache_vectorization_info(sf, None)
 
 
 def no_vectorization(sumfacts):
-    for sumf in sumfacts:
-        vectorization_info(sumf.a_matrices,
-                           sumf.restriction,
-                           sumf.a_matrices,
-                           get_counted_variable("buffer"),
-                           get_counted_variable("input"),
-                           None,
-                           frozenset())
+    for sf in sumfacts:
+        _cache_vectorization_info(sf, sf)
 
 
 def decide_stage_vectorization_strategy(sumfacts, stage, restriction):
@@ -88,7 +82,13 @@ def decide_stage_vectorization_strategy(sumfacts, stage, restriction):
             large_a_matrices.append(large)
 
         for sumf in stage_sumfacts:
-            vectorization_info(sumf.a_matrices, sumf.restriction, tuple(large_a_matrices), buf, inp, position_mapping[sumf], frozenset(available))
+            _cache_vectorization_info(sumf,
+                                      sumf.copy(a_matrices=tuple(large_a_matrices),
+                                                buffer=buf,
+                                                input=inp,
+                                                index=position_mapping[sumf],
+                                                padding=frozenset(available))
+                                      )
     else:
         # Disable vectorization strategy
         no_vectorization(stage_sumfacts)
-- 
GitLab