From e005da8abda3f3914ed628c183b80d73823106ef Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 7 Apr 2016 17:15:11 +0200
Subject: [PATCH] A bit work on arguments

laplace is again proceeding to code generation
---
 python/dune/perftool/generation/__init__.py   |  1 +
 python/dune/perftool/generation/loopy.py      |  3 +-
 python/dune/perftool/loopy/transformer.py     | 16 ++++-----
 python/dune/perftool/pdelab/argument.py       | 35 ++++++++++++++++++-
 .../dune/perftool/ufl/modified_terminals.py   |  2 +-
 5 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index 54a6bf47..81463557 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -19,6 +19,7 @@ from dune.perftool.generation.loopy import (c_instruction,
                                             expr_instruction,
                                             globalarg,
                                             iname,
+                                            pymbolic_expr,
                                             temporary_variable,
                                             valuearg,
                                             )
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index 55268c4f..78a82205 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -11,7 +11,8 @@ expr_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction"
 temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True)
 c_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "cinstruction"), no_deco=True)
 valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True)
-
+pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic"))
+constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n:loopy.ConstantArg(n))
 
 @generator_factory(item_tags=("loopy", "kernel", "argument", "globalarg"))
 def globalarg(name, shape=loopy.auto):
diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py
index 15abdee6..68d2a703 100644
--- a/python/dune/perftool/loopy/transformer.py
+++ b/python/dune/perftool/loopy/transformer.py
@@ -104,11 +104,11 @@ def transform_accumulation_term(term):
 
     rmap = {}
     for ma in test_ma:
-        from dune.perftool.pdelab.argument import name_testfunction
-        rmap[ma.expr] = Variable(name_testfunction(ma))
+        from dune.perftool.pdelab.argument import pymbolic_testfunction
+        rmap[ma.expr] = pymbolic_testfunction(ma)
     for ma in trial_ma:
-        from dune.perftool.pdelab.argument import name_trialfunction
-        rmap[ma.expr] = Variable(name_trialfunction(ma))
+        from dune.perftool.pdelab.argument import pymbolic_trialfunction
+        rmap[ma.expr] = pymbolic_trialfunction(ma)
 
     # Get the transformer!
     ufl2l_mf = UFL2LoopyVisitor()
@@ -119,8 +119,8 @@ def transform_accumulation_term(term):
 
     # Now simplify the expression
     # TODO: Add a switch to disable/configure this.
-    from dune.perftool.pymbolic.simplify import simplify_pymbolic_expression
-    pymbolic_expr = simplify_pymbolic_expression(pymbolic_expr)
+#     from dune.perftool.pymbolic.simplify import simplify_pymbolic_expression
+#     pymbolic_expr = simplify_pymbolic_expression(pymbolic_expr)
 
     # Define a temporary variable for this expression
     expr_tv_name = "expr_" + str(get_count()).zfill(4)
@@ -133,11 +133,9 @@ def transform_accumulation_term(term):
 
     # Generate the code for the modified arguments:
     for arg in test_ma:
-        from dune.perftool.pdelab.argument import name_argumentspace, name_argument
+        from dune.perftool.pdelab.argument import name_argumentspace, pymbolic_argument
         accumargs.append(name_argumentspace(arg))
         accumargs.append(argument_iname(arg))
-        # TODO is this global
-        #globalarg(argument_iname(arg)+"_n")
 
     from dune.perftool.pdelab.argument import name_residual
     residual = name_residual()
diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py
index 318c0040..720108b3 100644
--- a/python/dune/perftool/pdelab/argument.py
+++ b/python/dune/perftool/pdelab/argument.py
@@ -1,6 +1,6 @@
 """ Generator functions related to trial and test functions and the accumulation loop"""
 
-from dune.perftool.generation import symbol
+from dune.perftool.generation import pymbolic_expr, symbol, globalarg
 from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
 
 
@@ -11,11 +11,37 @@ def name_testfunction(ma):
     return "{}a{}".format("grad_" if ma.grad else "", ma.argexpr.number())
 
 
+@pymbolic_expr
+def pymbolic_testfunction(ma):
+    assert bool(ma.index) == ma.grad
+    from pymbolic.primitives import Subscript, Variable
+    v = Variable(name_testfunction(ma))
+    globalarg(name_testfunction(ma))
+    if ma.grad:
+        from dune.perftool.pdelab import name_index
+        return Subscript(v, Variable(name_index(ma.index)))
+    else:
+        return v
+
+
 @symbol
 def name_trialfunction(ma):
     return "{}c{}".format("grad_" if ma.grad else "", ma.argexpr.count())
 
 
+@pymbolic_expr
+def pymbolic_trialfunction(ma):
+    assert bool(ma.index) == ma.grad
+    from pymbolic.primitives import Subscript, Variable
+    v = Variable(name_trialfunction(ma))
+    globalarg(name_trialfunction(ma))
+    if ma.grad:
+        from dune.perftool.pdelab import name_index
+        return Subscript(v, Variable(name_index(ma.index)))
+    else:
+        return v
+
+
 @symbol
 def name_testfunctionspace(*a):
     # TODO
@@ -46,6 +72,13 @@ def name_argument(ma):
     assert False
 
 
+def pymbolic_argument(ma):
+    if ma.argexpr.number() == 0:
+        return pymbolic_testfunction(ma)
+    if ma.argexpr.number() == 1:
+        return pymbolic_trialfunction(ma)
+    assert False
+
 @symbol
 def name_residual():
     return "r"
diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py
index 8bf5173a..6beed959 100644
--- a/python/dune/perftool/ufl/modified_terminals.py
+++ b/python/dune/perftool/ufl/modified_terminals.py
@@ -77,7 +77,7 @@ class ModifiedArgumentDescriptor(MultiFunction):
         self(o.ufl_operands[0])
 
     def indexed(self, o):
-        indexed = o.ufl_operands[1]
+        self.index = o.ufl_operands[1]
         self(o.ufl_operands[0])
 
     def argument(self, o):
-- 
GitLab