From 1b3f71af7ae77a80d5fdd8460686f271339377da Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 4 Nov 2016 11:04:10 +0100
Subject: [PATCH] A 'transform' generator for delayed kernel transformation
 application

---
 python/dune/perftool/generation/__init__.py  |  1 +
 python/dune/perftool/generation/loopy.py     |  7 +++++++
 python/dune/perftool/pdelab/localoperator.py |  5 +++++
 python/dune/perftool/sumfact/quadrature.py   | 14 ++++++++++++++
 python/dune/perftool/sumfact/sumfact.py      |  6 ++++++
 5 files changed, 33 insertions(+)

diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index ea94e825..05bba36a 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -39,6 +39,7 @@ from dune.perftool.generation.loopy import (constantarg,
                                             noop_instruction,
                                             silenced_warning,
                                             temporary_variable,
+                                            transform,
                                             valuearg,
                                             )
 
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index b3fda210..c5932bfa 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -124,3 +124,10 @@ def instruction(code=None, expression=None, **kwargs):
 @generator_factory(item_tags=("instruction",), cache_key_generator=lambda **kw: kw['id'])
 def noop_instruction(**kwargs):
     return loopy.NoOpInstruction(**kwargs)
+
+
+@generator_factory(item_tags=("transformation",), cache_key_generator=no_caching)
+def transform(transform, *args):
+    """Bundle a kernel transformation callable with its extra positional arguments."""
+    delayed = transform, args
+    return delayed
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index fca6efac..dcf26e15 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -482,6 +482,7 @@ def generate_kernel(integrals):
     arguments = [i for i in retrieve_cache_items("argument")]
     manglers = retrieve_cache_functions("mangler")
     silenced = [l for l in retrieve_cache_items("silenced_warning")]
+    transformations = [t for t in retrieve_cache_items("transformation")]
 
     # Construct an options object
     from loopy import Options
@@ -502,6 +503,10 @@ def generate_kernel(integrals):
     from loopy import make_reduction_inames_unique
     kernel = make_reduction_inames_unique(kernel)
 
+    # Apply the transformations that were gathered during tree traversals
+    for trafo in transformations:
+        kernel = trafo[0](kernel, *trafo[1])
+
     kernel = preprocess_kernel(kernel)
 
     from dune.perftool.loopy.duplicate import heuristic_duplication
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index 408c9c0b..4949fac4 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -27,6 +27,18 @@ from pymbolic.primitives import (Call,
 import numpy
 
 
+def nest_quadrature_loops(kernel, inames):
+    """Return a copy of kernel whose "quad"-tagged instructions are nested in inames."""
+    # Rebuild the full instruction list: only tagged instructions are changed
+    # (and lose the tag); untagged ones must be carried over, not dropped.
+    extra = frozenset(inames)
+    insns = [insn.copy(within_inames=insn.within_inames.union(extra),
+                       tags=insn.tags - frozenset({"quad"}))
+             if "quad" in insn.tags else insn
+             for insn in kernel.instructions]
+    return kernel.copy(instructions=insns)
+
+
 class BaseWeight(FunctionIdentifier):
     def __init__(self, accumobj):
         self.accumobj = accumobj
@@ -70,6 +82,7 @@ def define_recursive_quadrature_weight(name, dir):
                 assignee=Variable(name),
                 forced_iname_deps=frozenset(quadrature_inames()[dir:]),
                 forced_iname_deps_is_final=True,
+                tags=frozenset({"quad"}),
                 )
 
 
@@ -96,6 +109,7 @@ def define_quadrature_position(name):
                     assignee=Subscript(Variable(name), (i,)),
                     forced_iname_deps=frozenset(quadrature_inames()),
                     forced_iname_deps_is_final=True,
+                    tags=frozenset({"quad"}),
                     )
 
 
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 3d52a879..62326d19 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -6,16 +6,19 @@ from dune.perftool.generation import (backend,
                                       domain,
                                       function_mangler,
                                       get_counter,
+                                      get_global_context_value,
                                       globalarg,
                                       iname,
                                       instruction,
                                       silenced_warning,
                                       temporary_variable,
+                                      transform,
                                       )
 from dune.perftool.loopy.buffer import (get_buffer_temporary,
                                         initialize_buffer,
                                         switch_base_storage,
                                         )
+from dune.perftool.sumfact.quadrature import nest_quadrature_loops
 from dune.perftool.pdelab.spaces import name_lfs
 from dune.perftool.sumfact.amatrix import (AMatrix,
                                            quadrature_points_per_direction,
@@ -127,6 +130,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                 depends_on=frozenset({stage_insn(3)}),
                 )
 
+    # Register the transformation that moves the quadrature loops inside the trial function loops for delayed application
+    transform(nest_quadrature_loops, visitor.inames)
+
 
 def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({}), additional_inames=frozenset({})):
     """
-- 
GitLab