From cb3390bd52c5884149bee1c8e7c6fa66a0d24b3f Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 7 Apr 2016 15:58:12 +0200
Subject: [PATCH] Do not remove mod args from the expression anymore

UFL did not like  my way of replacing some args with 1 as it
destroyed the shape information. It now remains in the expression,
but triggers code generation separately.
---
 python/dune/perftool/loopy/transformer.py     | 53 +++++++++----------
 python/dune/perftool/pdelab/quadrature.py     |  2 +-
 .../dune/perftool/ufl/modified_terminals.py   |  5 +-
 .../perftool/ufl/transformations/__init__.py  |  6 +--
 .../extract_accumulation_terms.py             | 19 +++----
 5 files changed, 36 insertions(+), 49 deletions(-)

diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py
index 71681e95..15abdee6 100644
--- a/python/dune/perftool/loopy/transformer.py
+++ b/python/dune/perftool/loopy/transformer.py
@@ -79,25 +79,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper):
         return Variable(name_facetarea())
 
 
-def get_pymbolic_expr(expr):
-    """ Transform the given UFL expression into a pymbolic expression
-    and have all sorts of side effects on the generation cache. """
-    # We do need to manually handle modified terminals related to trial functions
-    from dune.perftool.ufl.modified_terminals import extract_modified_arguments
-    from dune.perftool.ufl.transformations.replace import ReplaceExpression
-    from dune.perftool.pdelab.argument import name_trialfunction
-    from pymbolic.primitives import Variable
-
-    trial_ma = extract_modified_arguments(expr, trialfunction=True)
-    # OLD CODE had: globalarg(name)
-    rmap = {ma.expr: Variable(name_trialfunction(ma)) for ma in trial_ma}
-    ufl2l_mf = UFL2LoopyVisitor()
-    re_mf = ReplaceExpression(replacemap=rmap, otherwise=ufl2l_mf)
-    ufl2l_mf.call = re_mf.__call__
-
-    return re_mf(expr)
-
-
 class _Counter:
     counter = 0
 
@@ -109,15 +90,32 @@ def get_count():
 
 
 def transform_accumulation_term(term):
-    # Get the accumulation expression and the modified arguments
-    expr, args = term
+    from dune.perftool.ufl.transformations.replace import ReplaceExpression
+    from pymbolic.primitives import Variable
 
     # We always have a quadrature loop
     quadrature_iname()
 
     # Get the pymbolic expression needed for this accumulation term.
     # This includes filling the cache with all sorts of necessary preambles!
-    pymbolic_expr = get_pymbolic_expr(expr)
+    from dune.perftool.ufl.modified_terminals import extract_modified_arguments
+    test_ma = extract_modified_arguments(term, trialfunction=False, testfunction=True)
+    trial_ma = extract_modified_arguments(term, trialfunction=True, testfunction=False)
+
+    rmap = {}
+    for ma in test_ma:
+        from dune.perftool.pdelab.argument import name_testfunction
+        rmap[ma.expr] = Variable(name_testfunction(ma))
+    for ma in trial_ma:
+        from dune.perftool.pdelab.argument import name_trialfunction
+        rmap[ma.expr] = Variable(name_trialfunction(ma))
+
+    # Get the transformer!
+    ufl2l_mf = UFL2LoopyVisitor()
+    re_mf = ReplaceExpression(replacemap=rmap, otherwise=ufl2l_mf)
+    ufl2l_mf.call = re_mf.__call__
+
+    pymbolic_expr = re_mf(term)
 
     # Now simplify the expression
     # TODO: Add a switch to disable/configure this.
@@ -132,16 +130,14 @@ def transform_accumulation_term(term):
 
     # The data that is used to collect the arguments for the accumulate function
     accumargs = []
-    argument_code = []
 
     # Generate the code for the modified arguments:
-    for arg in args:
+    for arg in test_ma:
         from dune.perftool.pdelab.argument import name_argumentspace, name_argument
         accumargs.append(name_argumentspace(arg))
         accumargs.append(argument_iname(arg))
-        name = name_argument(arg)
-        argument_code.append(name)
-        globalarg(name)
+        # TODO is this global
+        #globalarg(argument_iname(arg)+"_n")
 
     from dune.perftool.pdelab.argument import name_residual
     residual = name_residual()
@@ -151,10 +147,9 @@ def transform_accumulation_term(term):
 
     from dune.perftool.pdelab.quadrature import name_factor
     c_instruction(loopy.CInstruction(inames,
-                                     "{}.accumulate({}, {}*{}*{})".format(residual,
+                                     "{}.accumulate({}, {}*{})".format(residual,
                                                                           ", ".join(accumargs),
                                                                           expr_tv_name,
-                                                                          "*".join(argument_code),
                                                                           name_factor()
                                                                           )
                                      )
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index 962cb735..c8346972 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -8,7 +8,7 @@ def quadrature_rule():
     return "rule"
 
 
-@quadrature_preamble(assignees="fac")
+@quadrature_preamble()
 def define_quadrature_factor(fac):
     rule = quadrature_rule()
     return "auto {} = {}->weight();".format(fac, rule)
diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py
index 9bc23612..8bf5173a 100644
--- a/python/dune/perftool/ufl/modified_terminals.py
+++ b/python/dune/perftool/ufl/modified_terminals.py
@@ -90,9 +90,10 @@ class ModifiedArgumentDescriptor(MultiFunction):
 class _ModifiedArgumentExtractor(MultiFunction):
     """ A multifunction that extracts and returns the set of modified arguments """
 
-    def __call__(self, o, argnumber=None, trialfunction=False):
+    def __call__(self, o, argnumber=None, testfunction=True, trialfunction=False):
         self.argnumber = argnumber
         self.trialfunction = trialfunction
+        self.testfunction = testfunction
         self.modified_arguments = set()
         ret = self.call(o)
         if ret:
@@ -127,7 +128,7 @@ class _ModifiedArgumentExtractor(MultiFunction):
             return o
 
     def argument(self, o):
-        if not self.trialfunction:
+        if self.testfunction:
             if self.argnumber is None or o.number() == self.argnumber:
                 return o
 
diff --git a/python/dune/perftool/ufl/transformations/__init__.py b/python/dune/perftool/ufl/transformations/__init__.py
index 81b3ff6e..19577bb3 100644
--- a/python/dune/perftool/ufl/transformations/__init__.py
+++ b/python/dune/perftool/ufl/transformations/__init__.py
@@ -48,11 +48,7 @@ class UFLTransformationWrapper(object):
         # We do also assume that the transformation returns an ufl expression or a list there of
         ret_for_print = self.extractExpressionListFromResult(ret)
 
-        try:
-            assert isinstance(ret_for_print, list) and all(isinstance(e, Expr) for e in ret_for_print)
-        except AssertionError:
-            from IPython import embed
-            embed()
+        assert isinstance(ret_for_print, list) and all(isinstance(e, Expr) for e in ret_for_print)
 
         # Maybe output the returned expression
         self.write_trafo(ret_for_print, False)
diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
index d14383ab..3dcb8b42 100644
--- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
@@ -10,22 +10,22 @@ from dune.perftool.ufl.transformations import ufl_transformation
 from dune.perftool.ufl.transformations.replace import replace_expression
 
 from ufl.algorithms import MultiFunction
-from ufl.classes import Zero, IntValue
+from ufl.classes import Zero
 
 import itertools
 
 
 class _ReplacementDict(dict):
-    def __init__(self, *mod_args):
+    def __init__(self, *args):
         dict.__init__(self)
-        for ma in mod_args:
-            self[ma] = IntValue(1)
+        for a in args:
+            self[a] = a
 
     def __getitem__(self, key):
         return self.get(key, Zero())
 
 
-@ufl_transformation(name="accterms2", extraction_lambda=lambda l: [i[0] for i in l])
+@ufl_transformation(name="accterms2", extraction_lambda=lambda l: l)
 def split_into_accumulation_terms(expr):
     mod_args = extract_modified_arguments(expr)
 
@@ -35,18 +35,13 @@ def split_into_accumulation_terms(expr):
     if len(filter(lambda ma: ma.argexpr.count() == 1, mod_args)) == 0:
         for arg in mod_args:
             # Do the replacement on the expression
-            accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg.expr))
-
-            # Store the found accumulation expression
-            accumulation_terms.append((accum_expr, (arg,)))
+            accumulation_terms.append(replace_expression(expr, replacemap=_ReplacementDict(arg.expr)))
     # and now the case of a rank 2 form:
     else:
         for arg1, arg2 in itertools.product(filter(lambda ma: ma.argexpr.count() == 0, mod_args),
                                             filter(lambda ma: ma.argexpr.count() == 1, mod_args)
                                             ):
-            accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg1.expr, arg2.expr))
-
-            accumulation_terms.append((accum_expr, (arg1, arg2)))
+            accumulation_terms.append(replace_expression(expr, replacemap=_ReplacementDict(arg1.expr, arg2.expr)))
 
     # and return the result
     return accumulation_terms
-- 
GitLab