From 5e072d6c6142f0ea7e553120f5d200f9081bd050 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 7 Apr 2016 14:56:41 +0200
Subject: [PATCH] Various foxes to the generation machinery

---
 python/dune/perftool/generation/cache.py       | 10 ++++++++--
 python/dune/perftool/generation/cpp.py         |  2 +-
 python/dune/perftool/generation/loopy.py       |  3 ++-
 python/dune/perftool/loopy/transformer.py      |  8 ++++----
 python/dune/perftool/pdelab/__init__.py        |  7 ++++++-
 python/dune/perftool/pdelab/argument.py        | 18 +++++++++---------
 python/dune/perftool/pdelab/localoperator.py   |  7 ++++---
 python/dune/perftool/pdelab/quadrature.py      |  2 +-
 .../extract_accumulation_terms.py              |  4 ++--
 9 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py
index 0780cd14..82217280 100644
--- a/python/dune/perftool/generation/cache.py
+++ b/python/dune/perftool/generation/cache.py
@@ -65,11 +65,17 @@ class _CacheItemMeta(type):
 
             def add_count(x):
                 rettype._count = rettype._count + 1
-                return (rettype._count, original_on_store(x))
+                if isinstance(x, tuple):
+                    return (rettype._count, original_on_store(*x))
+                else:
+                    return (rettype._count, original_on_store(x))
             on_store = add_count
 
         def _init(s, x):
-            s.content = on_store(x)
+            if isinstance(x, tuple) and not counted:
+                s.content = on_store(*x)
+            else:
+                s.content = on_store(x)
             s.tags = item_tags
             s.counted = counted
 
diff --git a/python/dune/perftool/generation/cpp.py b/python/dune/perftool/generation/cpp.py
index 652f7f71..32efb653 100644
--- a/python/dune/perftool/generation/cpp.py
+++ b/python/dune/perftool/generation/cpp.py
@@ -43,7 +43,7 @@ def class_member(classtag=None, access=AccessModifier.PRIVATE):
     from cgen import Value
     from dune.perftool.cgen.clazz import ClassMember
 
-    return generator_factory(item_tags=(classtag, "member"), on_store=lambda t, n: ClassMember(Value(_type, name), access=access), counted=True, cache_key_generator=lambda t, n: n)
+    return generator_factory(item_tags=(classtag, "member"), on_store=lambda t, n: ClassMember(Value(t, n), access=access), counted=True)
 
 
 def constructor_parameter(_type, name, classtag=None, constructortag=None):
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index d4098dec..55268c4f 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -22,5 +22,6 @@ def globalarg(name, shape=loopy.auto):
 
 @generator_factory(item_tags=("loopy", "kernel", "domain"))
 def domain(iname, shape):
-    valuearg(shape)
+    if isinstance(shape, str):
+        valuearg(shape)
     return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)
diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py
index e5f26008..3ad304f1 100644
--- a/python/dune/perftool/loopy/transformer.py
+++ b/python/dune/perftool/loopy/transformer.py
@@ -37,9 +37,9 @@ def dimension_iname(index):
 def argument_iname(arg):
     # TODO extract the {iname}_n thing by a preamble
     from dune.perftool.ufl.modified_terminals import modified_argument_number
-    iname = "arg{}".format(chr(ord("i") + modified_argument_number()(arg)))
-    domain(iname, iname + "_n")
-    return iname
+    ainame = "arg{}".format(chr(ord("i") + arg.argexpr.number()))
+    domain(ainame, ainame + "_n")
+    return ainame
 
 
 @iname
@@ -76,7 +76,7 @@ def get_pymbolic_expr(expr):
 
     trial_ma = extract_modified_arguments(expr, trialfunction=True)
     # OLD CODE had: globalarg(name)
-    rmap = {ma: Variable(name_trialfunction(ma)) for ma in trial_ma}
+    rmap = {ma.expr: Variable(name_trialfunction(ma)) for ma in trial_ma}
     ufl2l_mf = UFL2LoopyVisitor()
     re_mf = ReplaceExpression(replacemap=rmap, otherwise=ufl2l_mf)
     ufl2l_mf.call = re_mf.__call__
diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py
index 981c8b71..5743fbab 100644
--- a/python/dune/perftool/pdelab/__init__.py
+++ b/python/dune/perftool/pdelab/__init__.py
@@ -14,4 +14,9 @@ def quadrature_preamble(assignees=[]):
 # Now define some commonly used generators that do not fall into a specific category
 @symbol
 def name_index(index):
-    return str(index._indices[0])
+    from ufl.classes import MultiIndex, Index
+    if isinstance(index, Index):
+        return str(index)
+    if isinstance(index, MultiIndex):
+        assert len(index) == 1
+        return str(index._indices[0])
diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py
index bc7e6078..318c0040 100644
--- a/python/dune/perftool/pdelab/argument.py
+++ b/python/dune/perftool/pdelab/argument.py
@@ -6,7 +6,7 @@ from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
 
 @symbol
 def name_testfunction(ma):
-    if len(ma.expr.element().sub_elements()) > 0:
+    if len(ma.argexpr.element().sub_elements()) > 0:
         pass
     return "{}a{}".format("grad_" if ma.grad else "", ma.argexpr.number())
 
@@ -29,19 +29,19 @@ def name_trialfunctionspace(*a):
 
 
 def name_argumentspace(ma):
-    if ma.expr.number() == 0:
-        return name_testfunctionspace(modarg)
-    if ma.expr.number() == 1:
-        return name_trialfunctionspace(modarg)
+    if ma.argexpr.number() == 0:
+        return name_testfunctionspace(ma)
+    if ma.argexpr.number() == 1:
+        return name_trialfunctionspace(ma)
     # We should never encounter an argument other than 0 or 1
     assert False
 
 
 def name_argument(ma):
-    if ma.expr.number() == 0:
-        return name_testfunction(modarg)
-    if ma.expr.number() == 1:
-        return name_trialfunction(modarg)
+    if ma.argexpr.number() == 0:
+        return name_testfunction(ma)
+    if ma.argexpr.number() == 1:
+        return name_trialfunction(ma)
     # We should never encounter an argument other than 0 or 1
     assert False
 
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index e844820e..bb2cd648 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -11,9 +11,9 @@ from pytools import memoize
 def define_initree(name):
     include_file('dune/common/parametertree.hh', filetag="operatorfile")
     constructor_parameter("const Dune::ParameterTree&", "iniParams", classtag="operator", constructortag="iniconstructor")
-    initializer_list("_iniParams", ["iniParams"])
+    initializer_list("_iniParams", ["iniParams"], classtag="operator")
 
-    return "const Dune::ParameterTree&", "_iniParams"
+    return ("const Dune::ParameterTree&", "_iniParams")
 
 
 @symbol
@@ -43,7 +43,8 @@ def measure_specific_details(measure):
             # Add the initializer list for that base class
             ini = name_initree_member()
             initializer_list("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype),
-                             ["{}.get(\"numerical_epsilon.{}\", 1e-9)".format(ini, which.lower())])
+                             ["{}.get(\"numerical_epsilon.{}\", 1e-9)".format(ini, which.lower())],
+                             classtag="operator")
 
     if measure == "cell":
         base_class('Dune::PDELab::FullVolumePattern', classtag="operator")
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index a5a99681..962cb735 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -16,6 +16,6 @@ def define_quadrature_factor(fac):
 
 @symbol
 def name_factor():
-    loopy_temporary_variable("fac")
+    temporary_variable("fac")
     define_quadrature_factor("fac")
     return "fac"
diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
index 2f902aa2..d14383ab 100644
--- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
@@ -35,7 +35,7 @@ def split_into_accumulation_terms(expr):
     if len(filter(lambda ma: ma.argexpr.count() == 1, mod_args)) == 0:
         for arg in mod_args:
             # Do the replacement on the expression
-            accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg))
+            accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg.expr))
 
             # Store the found accumulation expression
             accumulation_terms.append((accum_expr, (arg,)))
@@ -44,7 +44,7 @@ def split_into_accumulation_terms(expr):
         for arg1, arg2 in itertools.product(filter(lambda ma: ma.argexpr.count() == 0, mod_args),
                                             filter(lambda ma: ma.argexpr.count() == 1, mod_args)
                                             ):
-            accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg1, arg2))
+            accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg1.expr, arg2.expr))
 
             accumulation_terms.append((accum_expr, (arg1, arg2)))
 
-- 
GitLab