From c2c6100a8c3e292e5bc577bf01a4c3153acb56b3 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 8 Apr 2016 15:45:32 +0200
Subject: [PATCH] Major overhaul of loopy instruction generation mechanisms

---
 python/dune/perftool/generation/__init__.py |  5 +-
 python/dune/perftool/generation/cache.py    | 13 +++--
 python/dune/perftool/generation/loopy.py    | 54 ++++++++++++++++++++-
 python/dune/perftool/loopy/transformer.py   | 33 +++++++------
 python/dune/perftool/pdelab/__init__.py     |  8 ++-
 python/dune/perftool/pdelab/quadrature.py   |  4 +-
 6 files changed, 87 insertions(+), 30 deletions(-)

diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index 81463557..d1e9fc44 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -14,11 +14,10 @@ from dune.perftool.generation.cpp import (base_class,
                                           symbol,
                                           )
 
-from dune.perftool.generation.loopy import (c_instruction,
-                                            domain,
-                                            expr_instruction,
+from dune.perftool.generation.loopy import (domain,
                                             globalarg,
                                             iname,
+                                            instruction,
                                             pymbolic_expr,
                                             temporary_variable,
                                             valuearg,
diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py
index 82217280..5a0e026a 100644
--- a/python/dune/perftool/generation/cache.py
+++ b/python/dune/perftool/generation/cache.py
@@ -25,6 +25,10 @@ def _freeze(data):
     if isinstance(data, ufl.classes.Expr):
         return data
 
+    from pymbolic.primitives import Expression
+    if isinstance(data, Expression):
+        return data
+
     # Check if the given data is already hashable
     if isinstance(data, Hashable):
         if isinstance(data, Iterable):
@@ -103,16 +107,19 @@ class _RegisteredFunction(object):
         self.cache_key_generator = cache_key_generator
         self.itemtype = _construct_cache_item_type("CacheItemType", **kwargs)
 
-    def __call__(self, *args):
+    def __call__(self, *args, **kwargs):
         # Get the cache key from the given arguments
-        cache_key = (self, _freeze(self.cache_key_generator(*args)))
+        cache_args = self.cache_key_generator(*args, **kwargs)
+        # Meant to ensure no keyword arguments survive in cache_args. NOTE(review): always true as written, k is never filled
+        assert (lambda *a, **k: len(k) == 0)(cache_args)
+        cache_key = (self, _freeze(cache_args))  # reuse the key computed above instead of generating it twice
         # check whether we have a cache hit
         if cache_key in _cache:
             # and return the result depending on the cache item type
             return _cache[cache_key].content
         else:
             # evaluate the original function and wrap it in a cache item
-            citem = self.itemtype(self.func(*args))
+            citem = self.itemtype(self.func(*args, **kwargs))
             _cache[cache_key] = citem
             return citem.content
 
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index 78a82205..38746ff9 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -7,9 +7,7 @@ import loopy
 import numpy
 
 iname = generator_factory(item_tags=("loopy", "kernel", "iname"))
-expr_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "exprinstruction"), no_deco=True)
 temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True)
-c_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "cinstruction"), no_deco=True)
 valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True)
 pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic"))
 constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n:loopy.ConstantArg(n))
@@ -26,3 +24,55 @@ def domain(iname, shape):
     if isinstance(shape, str):
         valuearg(shape)
     return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)
+
+
+# Now define generators for instructions. To ease dependency handling of instructions
+# these generators are a bit more involved... We apply the following procedure:
+# There is one generator that returns the unique id and forwards to a generator that
+# actually adds the instruction. Hashing is done based on the code snippet.
+
+@generator_factory(item_tags=("loopy", "kernel", "instruction", "cinstruction"),
+                   cache_key_generator=lambda *a, **kw: kw['code'],  # cached per code snippet
+                   )
+def c_instruction_impl(**kw):
+    kw['insn_deps'] = kw.pop('deps', None)  # 'deps' is handed to loopy.CInstruction under the name 'insn_deps'
+    kw.setdefault('assignees', [])  # no assignees unless the caller names some
+    inames = kw.pop('inames')  # passed positionally below; KeyError if the caller left it out
+    return loopy.CInstruction(inames, **kw)
+
+
+@generator_factory(item_tags=("loopy", "kernel", "instruction", "exprinstruction"),
+                   cache_key_generator=lambda *a, **kw: kw['expression'],  # cached per expression
+                   )
+def expr_instruction_impl(**kw):
+    return loopy.ExpressionInstruction(id=kw['id'], assignee=kw['assignee'], expression=kw['expression'])  # only these three keywords are forwarded; any others are ignored
+
+
+class _IDCounter:  # module-wide counter used by instruction() to number instruction ids
+    count = 0  # number used for the next id; instruction() increments it after each use
+
+
+def _insn_cache_key(inames=None, code=None, expr=None, deps=None, expression=None, **kwargs):
+    # Key on the snippet alone; accept every keyword instruction() is called with (expression, assignee, assignees, ...)
+    if code:
+        return code
+    return expr or expression or None
+
+
+@generator_factory(item_tags=("insn_id",), no_deco=True, cache_key_generator=_insn_cache_key)
+def instruction(code=None, expression=None, **kwargs):
+    """Register either a C code snippet or an expression as a loopy instruction and return the id chosen for it."""
+    assert bool(code) != bool(expression), "pass exactly one of 'code' and 'expression'"
+
+    # Get an ID for this instruction
+    insn_id = 'insn' + str(_IDCounter.count).zfill(4)
+    _IDCounter.count = _IDCounter.count + 1
+
+    # Now create the actual instruction
+    if code:
+        c_instruction_impl(id=insn_id, code=code, **kwargs)
+    if expression:
+        expr_instruction_impl(id=insn_id, expression=expression, **kwargs)
+
+    # return the ID, as it is the only useful information to the user
+    return insn_id
diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py
index 68d2a703..7d2ae27c 100644
--- a/python/dune/perftool/loopy/transformer.py
+++ b/python/dune/perftool/loopy/transformer.py
@@ -9,11 +9,10 @@ from dune.perftool import Restriction
 from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker
 from dune.perftool.pymbolic.uflmapper import UFL2PymbolicMapper
 
-from dune.perftool.generation import (c_instruction,
-                                      domain,
-                                      expr_instruction,
+from dune.perftool.generation import (domain,
                                       globalarg,
                                       iname,
+                                      instruction,
                                       temporary_variable,
                                       valuearg,
                                       )
@@ -48,6 +47,12 @@ def quadrature_iname():
     return "q"
 
 
+@iname
+def index_sum_iname(i):
+    from dune.perftool.pdelab import name_index  # local import, presumably because pdelab/__init__.py imports from this module
+    return name_index(i)  # the iname is whatever name name_index() gives for index i
+
+
 class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper):
     def __init__(self):
         super(UFL2LoopyVisitor, self).__init__()
@@ -64,10 +69,9 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper):
         # Define an iname for each of the indices in the multiindex
         for i in o.ufl_operands[1].indices():
             shape = determine_shape(o.ufl_operands[0], i)
+            index_sum_iname(i)
             from dune.perftool.pdelab import name_index
-            name = name_index(i)
-            iname(name)
-            domain(name, shape)
+            domain(name_index(i), shape)
 
         # Now continue processing the expression
         return self.call(o.ufl_operands[0])
@@ -126,7 +130,7 @@ def transform_accumulation_term(term):
     expr_tv_name = "expr_" + str(get_count()).zfill(4)
     expr_tv = temporary_variable(expr_tv_name)
     from pymbolic.primitives import Variable
-    expr_instruction(loopy.ExpressionInstruction(assignee=Variable(expr_tv_name), expression=pymbolic_expr))
+    instruction(assignee=Variable(expr_tv_name), expression=pymbolic_expr)
 
     # The data that is used to collect the arguments for the accumulate function
     accumargs = []
@@ -144,11 +148,10 @@ def transform_accumulation_term(term):
     inames = retrieve_cache_items("iname")
 
     from dune.perftool.pdelab.quadrature import name_factor
-    c_instruction(loopy.CInstruction(inames,
-                                     "{}.accumulate({}, {}*{})".format(residual,
-                                                                          ", ".join(accumargs),
-                                                                          expr_tv_name,
-                                                                          name_factor()
-                                                                          )
-                                     )
-                  )
+    instruction(inames=inames,
+                code="{}.accumulate({}, {}*{})".format(residual,
+                                                       ", ".join(accumargs),
+                                                       expr_tv_name,
+                                                       name_factor()
+                                                       )
+                )
\ No newline at end of file
diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py
index 5743fbab..c1f69266 100644
--- a/python/dune/perftool/pdelab/__init__.py
+++ b/python/dune/perftool/pdelab/__init__.py
@@ -1,14 +1,12 @@
 """ The pdelab specific parts of the code generation process """
 
 # Define the generators that are used throughout all pdelab specific code generations.
-from dune.perftool.generation import symbol, generator_factory
+from dune.perftool.generation import symbol, instruction
 from dune.perftool.loopy.transformer import quadrature_iname
-from loopy import CInstruction
 
 
-def quadrature_preamble(assignees=[]):
-    # TODO: How to enforce the order of quadrature preambles? Counted?
-    return generator_factory(item_tags=("instruction", "cinstruction"), on_store=lambda code: CInstruction(quadrature_iname(), code, assignees=assignees))
+def quadrature_preamble(code, **kw):
+    return instruction(inames=quadrature_iname(), code=code, **kw)  # code snippet on the quadrature iname; extra keywords (e.g. assignees) pass through
 
 
 # Now define some commonly used generators that do not fall into a specific category
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index c8346972..2719ed1e 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -8,10 +8,10 @@ def quadrature_rule():
     return "rule"
 
 
-@quadrature_preamble()
 def define_quadrature_factor(fac):
     rule = quadrature_rule()
-    return "auto {} = {}->weight();".format(fac, rule)
+    code = "auto {} = {}->weight();".format(fac, rule)  # C++ snippet that declares fac from the rule's weight
+    return quadrature_preamble(code, assignees=fac)  # fac is written by the snippet, so it is passed as the assignee
 
 
 @symbol
-- 
GitLab