From 7a9fbcda08f62c5e5b51d4e7e11fa74117da8c9b Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 13 Apr 2016 14:33:39 +0200
Subject: [PATCH] Steps towards correct local function space extraction

---
 python/dune/perftool/loopy/transformer.py | 23 ++++++-
 python/dune/perftool/pdelab/argument.py   |  2 +-
 python/dune/perftool/pdelab/basis.py      | 84 ++++++++++++++++++++---
 python/dune/perftool/pdelab/geometry.py   |  4 +-
 python/dune/perftool/pdelab/quadrature.py |  3 +-
 5 files changed, 101 insertions(+), 15 deletions(-)

diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py
index d225ea16..db0a0fbc 100644
--- a/python/dune/perftool/loopy/transformer.py
+++ b/python/dune/perftool/loopy/transformer.py
@@ -18,7 +18,11 @@ from dune.perftool.generation import (domain,
                                       valuearg,
                                       )
 
-from dune.perftool.pdelab.basis import lfs_iname, lfs_iname_bound
+from dune.perftool.pdelab.basis import (lfs_iname,
+                                        name_lfs,
+                                        name_lfs_bound,
+                                        traverse_lfs_tree,
+                                        )
 
 
 @iname
@@ -83,9 +87,17 @@ def transform_accumulation_term(term):
 
     rmap = {}
     for ma in test_ma:
+        # Set up the local function space structure
+        traverse_lfs_tree(ma)
+
+        # Get the expression for the modified argument representing the test function
         from dune.perftool.pdelab.argument import pymbolic_testfunction
         rmap[ma.expr] = pymbolic_testfunction(ma)
     for ma in trial_ma:
+        # Set up the local function space structure
+        traverse_lfs_tree(ma)
+
+        # Get the expression for the modified argument representing the trial function
         from dune.perftool.pdelab.argument import pymbolic_trialfunction
         rmap[ma.expr] = pymbolic_trialfunction(ma)
 
@@ -109,6 +121,7 @@ def transform_accumulation_term(term):
 
     # The data that is used to collect the arguments for the accumulate function
     accumargs = []
+    residual_shape = {}
 
     # Generate the code for the modified arguments:
     for arg in test_ma:
@@ -116,12 +129,16 @@ def transform_accumulation_term(term):
         accumargs.append(name_argumentspace(arg))
         accumargs.append(lfs_iname(arg.argexpr.element(), argcount=arg.argexpr.count()))
 
+        # Determine the shape
+        residual_shape[arg.argexpr.number()] = name_lfs_bound(name_lfs(arg.argexpr.element()))
+
     from dune.perftool.pdelab.argument import name_residual
     residual = name_residual()
 
     # The residual/the jacobian should be represented through a loopy global argument
-    from dune.perftool.ufl.rank import ufl_rank
-    globalarg(residual, shape=tuple(lfs_iname_bound(i) for i in range(ufl_rank(term))))
+    # TODO this seems still a bit hacky, esp. w.r.t. systems
+    shape = tuple(v for k, v in sorted(residual_shape.items(), key=lambda (k, v): k))
+    globalarg(residual, shape=shape)
 
     from dune.perftool.generation import retrieve_cache_items
     inames = retrieve_cache_items("iname")
diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py
index a62b0d49..bb06887b 100644
--- a/python/dune/perftool/pdelab/argument.py
+++ b/python/dune/perftool/pdelab/argument.py
@@ -6,9 +6,9 @@ from dune.perftool.pdelab import name_index
 from dune.perftool.pdelab.basis import (evaluate_trialfunction,
                                         evaluate_trialfunction_gradient,
                                         lfs_iname,
-                                        lfs_iname_bound,
                                         name_basis,
                                         name_basis_gradient,
+                                        name_lfs_bound,
                                         )
 
 from pymbolic.primitives import Subscript, Variable
diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index 79891864..ac1195a5 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -2,8 +2,10 @@
 
 from dune.perftool.generation import (cached,
                                       domain,
+                                      generator_factory,
                                       iname,
                                       instruction,
+                                      preamble,
                                       symbol,
                                       temporary_variable,
                                       )
@@ -14,19 +16,85 @@ from dune.perftool.pdelab.geometry import (name_jacobian_inverse_transposed,
                                            )
 
 
-# TODO having lfs_iname_bound get the number as parameter is completely
-# broken. It should accept an element instead, but I need to find out how
-# it should be done in the jacobian case first.
+@preamble('blubb')
+def define_lfs_bound(lfs, bound):
+    return 'auto {} = {}.size();'.format(bound, lfs)
+
+
 @symbol
-def lfs_iname_bound(number):
-    return "arg{}_n".format(number)
+def name_lfs_bound(lfs):
+    bound = '{}_size'.format(lfs)
+    define_lfs_bound(lfs, bound)
+
+    return bound
+
+
+@generator_factory(cache_key_generator=lambda e, **kw: e, item_tags=('blubb',))
+def name_lfs(element, prefix=None):
+    # Omitting the prefix is only valid upon a second call, which will
+    # result in a cache hit.
+    assert prefix
+
+    # Additionally, element is expected to be a ufl finite element
+    from ufl import FiniteElementBase
+    assert isinstance(element, FiniteElementBase)
+
+    # Recurse into the given element to define all other local function spaces!
+    from ufl import MixedElement
+    if isinstance(element, MixedElement):
+        for i, subelem in enumerate(element.sub_elements()):
+            # TODO in this case, we need to trigger the extraction mechanism
+            # as preambles in our code.
+            name_lfs(subelem, prefix + "_" + str(i))
+
+    # Now trigger the creation of all those other symbols/preables necessary for this lfs
+
+    # Now return the prefix!
+    return prefix
+
+
+def traverse_lfs_tree(arg):
+    from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
+    assert isinstance(arg, ModifiedArgumentDescriptor)
+
+    # First we need to determine the basename as given in the signature of
+    # this kernel method!
+    basename = None
+    from ufl.classes import Argument, Coefficient
+    if isinstance(arg.argexpr, Argument):
+        if arg.argexpr.count() == 0:
+            basename = 'lfsv'
+        if arg.argexpr.count() == 1:
+            basename = 'lfsu'
+        # TODO add restrictions here.
+
+    if isinstance(arg.argexpr, Coefficient):
+        # We should only ever call this for a trialfunction, which in our case
+        # is the coefficient of reserved index 0.
+        assert arg.argexpr.count() == 0
+
+        basename = 'lfsu'
+
+    assert basename
+
+    # Now start recursively extracting local function spaces and fill the cache with
+    # all those values. That way we can later get a correct local function space with
+    # just the ufl finite element.
+    name_lfs(arg.argexpr.element(), prefix=basename)
 
 
 @iname
 def _lfs_iname(element, argcount):
-    ainame = "arg{}".format(chr(ord("i") + argcount))
-    domain(ainame, lfs_iname_bound(argcount))
-    return ainame
+    name = name_lfs(element)
+    bound = name_lfs_bound(name)
+
+    if argcount != 0:
+        name = 'lfsu'
+
+    name = name + '_index'
+    domain(name, bound)
+
+    return name
 
 
 def lfs_iname(element, argcount=0):
diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py
index 9dfdeee7..15eeaa04 100644
--- a/python/dune/perftool/pdelab/geometry.py
+++ b/python/dune/perftool/pdelab/geometry.py
@@ -1,7 +1,7 @@
 from dune.perftool.generation import preamble, symbol
 
 
-@preamble
+@preamble('blubb')
 def define_geometry(name):
     return "auto {} = eg.geometry();".format(name)
 
@@ -18,7 +18,7 @@ def name_dimension():
     return "dim"
 
 
-@preamble
+@preamble('blubb')
 def define_jacobian_inverse_transposed(name):
     geo = name_geometry()
     return "auto {} = {}.jacobianInverseTransposed();".format(name,
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index a6febb7b..d43c73c7 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -1,4 +1,5 @@
-from dune.perftool.generation import (domain,
+from dune.perftool.generation import (cached,
+                                      domain,
                                       iname,
                                       instruction,
                                       symbol,
-- 
GitLab