From 09cc5d4c7dd64545c9e4d7f1308f3100c3e0ec24 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 21 Jul 2016 17:43:01 +0200
Subject: [PATCH] Use a proper function for lfs children extraction

---
 python/dune/perftool/loopy/functions.py      | 16 ++++++++++++++++
 python/dune/perftool/loopy/transformer.py    |  7 ++++++-
 python/dune/perftool/pdelab/argument.py      | 11 +++++++++--
 python/dune/perftool/pdelab/basis.py         |  5 ++++-
 python/dune/perftool/pdelab/localoperator.py |  4 ++--
 5 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/python/dune/perftool/loopy/functions.py b/python/dune/perftool/loopy/functions.py
index ec06459e..7f6c519e 100644
--- a/python/dune/perftool/loopy/functions.py
+++ b/python/dune/perftool/loopy/functions.py
@@ -4,6 +4,22 @@ from loopy.types import NumpyType
 
 import numpy
 
+class LFSChild(FunctionIdentifier):
+    def __init__(self, lfs):
+        self.lfs = lfs
+
+    def __getinitargs__(self):
+        return (self.lfs,)
+
+    @property
+    def name(self):
+        return '{}.child'.format(self.lfs)
+
+
+def lfs_child_mangler(target, func, dtypes):
+    if isinstance(func, LFSChild):
+        return CallMangleInfo(func.name, (NumpyType(str),), (NumpyType(numpy.int32),))
+
 
 class CoefficientAccess(FunctionIdentifier):
     def __init__(self, restriction):
diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py
index 843e73e4..90049a40 100644
--- a/python/dune/perftool/loopy/transformer.py
+++ b/python/dune/perftool/loopy/transformer.py
@@ -115,7 +115,12 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
 
             lfsi = lfs_iname(subel, ma.restriction, count=count)
 
-            accumargs[2 * icount] = Variable(lfs)
+            # If the LFS is not yet a pymbolic expression, make it one
+            from pymbolic.primitives import Expression
+            if not isinstance(lfs, Expression):
+                lfs = Variable(lfs)
+
+            accumargs[2 * icount] = lfs
             accumargs[2 * icount + 1] = Variable(lfsi)
 
             arg_restr[icount] = ma.restriction
diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py
index c96199fb..6ff8444d 100644
--- a/python/dune/perftool/pdelab/argument.py
+++ b/python/dune/perftool/pdelab/argument.py
@@ -76,9 +76,16 @@ def name_coefficientcontainer(restriction):
 @pymbolic_expr
 def pymbolic_coefficient(lfs, index, restriction):
     # TODO introduce a proper type for local function spaces!
-    valuearg(lfs, dtype=loopy.types.NumpyType("str"))
+    if isinstance(lfs, str):
+        valuearg(lfs, dtype=loopy.types.NumpyType("str"))
+
+    # If the LFS is not yet a pymbolic expression, make it one
+    from pymbolic.primitives import Expression
+    if not isinstance(lfs, Expression):
+        lfs = Variable(lfs)
+
     from dune.perftool.loopy.functions import CoefficientAccess
-    return Call(CoefficientAccess(restriction), (Variable(lfs), Variable(index),))
+    return Call(CoefficientAccess(restriction), (lfs, Variable(index),))
 
 
 @symbol
diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index 0031f468..9a45e0d2 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -76,7 +76,10 @@ def define_lfs(name, father, child):
 
 
 def lfs_child(lfs, child):
-    return "{}.child({})".format(lfs, child)
+    from pymbolic.primitives import Call
+    from dune.perftool.loopy.functions import LFSChild
+    return Call(LFSChild(lfs), (Variable(child),))
+#     return "{}.child({})".format(lfs, child)
 
 
 @generator_factory(cache_key_generator=lambda e, r, **kw: (e, r))
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index ef2165e7..310dc616 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -184,7 +184,7 @@ def generate_kernel(integrals):
     arguments = [i for i in retrieve_cache_items("argument")]
 
     # Get the function manglers
-    from dune.perftool.loopy.functions import accumulation_mangler, coefficient_mangler
+    from dune.perftool.loopy.functions import accumulation_mangler, coefficient_mangler, lfs_child_mangler
 
     # Create the kernel
     from loopy import make_kernel, preprocess_kernel
@@ -192,7 +192,7 @@ def generate_kernel(integrals):
                          instructions + subst_rules,
                          arguments,
                          temporary_variables=temporaries,
-                         function_manglers=[accumulation_mangler, coefficient_mangler],
+                         function_manglers=[accumulation_mangler, coefficient_mangler, lfs_child_mangler],
                          target=DuneTarget()
                          )
 
-- 
GitLab