From 913afbcd4c266330cc775bef3ce69dd0f6aa893b Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 25 Aug 2017 16:24:58 +0200
Subject: [PATCH] More stuff that got lost in transition

---
 .../dune/perftool/pdelab/driver/__init__.py   | 42 +++++++++++++------
 .../perftool/pdelab/driver/constraints.py     |  5 ++-
 .../pdelab/driver/gridfunctionspace.py        |  5 ++-
 .../perftool/pdelab/driver/interpolate.py     | 11 ++---
 4 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/python/dune/perftool/pdelab/driver/__init__.py b/python/dune/perftool/pdelab/driver/__init__.py
index 8021fa41..a7860faf 100644
--- a/python/dune/perftool/pdelab/driver/__init__.py
+++ b/python/dune/perftool/pdelab/driver/__init__.py
@@ -87,9 +87,10 @@ def mass_form_index(formdatas, data):
             continue
 
 
-def is_linear():
+def is_linear(form=None):
     '''Test if form is linear in trial function'''
-    form = get_formdata().original_form
+    if form is None:
+        form = get_formdata().original_form
     from ufl import derivative
     from ufl.algorithms import expand_derivatives
     jacform = expand_derivatives(derivative(form, form.coefficients()[0]))
@@ -163,22 +164,37 @@ def _flatten_list(l):
         yield l
 
 
+def _unroll_list_tensors(expr):
+    from ufl.classes import ListTensor
+    if isinstance(expr, ListTensor):
+        for op in expr.ufl_operands:
+            yield op
+    else:
+        yield expr
+
+
+def unroll_list_tensors(data):
+    for expr in data:
+        for e in _unroll_list_tensors(expr):
+            yield e
+
+
 def preprocess_leaf_data(element, data):
     data = get_object(data)
     from ufl import MixedElement
     if isinstance(element, MixedElement):
-        # Dirichlet is None -> no dirichlet boundaries
+        # data is None -> use 0 default
         if data is None:
-            return (0,) * element.value_size()
-        # Dirichlet for MixedElement is not iterable -> Same
-        # constraint on all the leafs.
-        elif not isinstance(data, (tuple, list)):
-            return (data,) * element.value_size()
-        # List sizes do not match -> flatten list
-        elif len(data) != element.value_size():
-            flattened = [i for i in _flatten_list(data)]
-            assert len(flattened) == element.value_size()
-            return flattened
+            data = (0,) * element.value_size()
+
+        # Flatten nested lists
+        data = tuple(i for i in _flatten_list(data))
+
+        # Expand any list tensors
+        data = tuple(i for i in unroll_list_tensors(data))
+
+        assert len(data) == element.value_size()
+        return data
     else:
         # Do not return lists for non-MixedElement
         if not isinstance(data, (tuple, list)):
diff --git a/python/dune/perftool/pdelab/driver/constraints.py b/python/dune/perftool/pdelab/driver/constraints.py
index d472882c..bc862b70 100644
--- a/python/dune/perftool/pdelab/driver/constraints.py
+++ b/python/dune/perftool/pdelab/driver/constraints.py
@@ -5,7 +5,8 @@ from dune.perftool.pdelab.driver import (FEM_name_mangling,
                                          get_formdata,
                                          get_trial_element,
                                          )
-from dune.perftool.pdelab.driver.gridfunctionspace import (name_leafview,
+from dune.perftool.pdelab.driver.gridfunctionspace import (name_gfs,
+                                                           name_leafview,
                                                            name_trial_gfs,
                                                            type_range,
                                                            type_trial_gfs,
@@ -48,7 +49,7 @@ def name_bctype_function(element, is_dirichlet):
             subgfs.append(name_gfs(subel, is_dirichlet[k:k + subel.value_size()]))
             k = k + subel.value_size()
         name = "_".join(subgfs)
-        define_composite_bctype_function(element, is_dirichlet, name, subgfs)
+        define_composite_bctype_function(element, is_dirichlet, name, tuple(subgfs))
         return name
     else:
         assert isinstance(element, FiniteElement)
diff --git a/python/dune/perftool/pdelab/driver/gridfunctionspace.py b/python/dune/perftool/pdelab/driver/gridfunctionspace.py
index f2031274..67b7e64e 100644
--- a/python/dune/perftool/pdelab/driver/gridfunctionspace.py
+++ b/python/dune/perftool/pdelab/driver/gridfunctionspace.py
@@ -249,7 +249,7 @@ def define_gfs(element, is_dirichlet, name):
     gv = name_leafview()
     fem = name_fem(element)
     return ["{} {}({}, {});".format(gfstype, name, gv, fem),
-            "{}.name(\"{}\");".format(name, name)]
+            "{}.name(\"{}\");".format(name, name),]
 
 
 @preamble
@@ -264,7 +264,8 @@ def define_power_gfs(element, is_dirichlet, name, subgfs):
 @preamble
 def define_composite_gfs(element, is_dirichlet, name, subgfs):
     gfstype = type_gfs(element, is_dirichlet)
-    return "{} {}({});".format(gfstype, name, ", ".join(subgfs))
+    return ["{} {}({});".format(gfstype, name, ", ".join(subgfs)),
+            "{}.update();".format(name)]
 
 
 @preamble
diff --git a/python/dune/perftool/pdelab/driver/interpolate.py b/python/dune/perftool/pdelab/driver/interpolate.py
index 3e92f4b9..cc046ae2 100644
--- a/python/dune/perftool/pdelab/driver/interpolate.py
+++ b/python/dune/perftool/pdelab/driver/interpolate.py
@@ -43,20 +43,21 @@ def interpolate_vector(func, gfs, name):
                                                            )
 
 
-def name_boundary_function(element, dirichlet):
+@cached
+def name_boundary_function(element, func):
     if isinstance(element, MixedElement):
         k = 0
         childs = []
         for subel in element.sub_elements():
-            childs.append(name_boundary_function(subel, dirichlet[k:k + subel.value_size()]))
+            childs.append(name_boundary_function(subel, func[k:k + subel.value_size()]))
             k = k + subel.value_size()
         name = "_".join(childs)
-        define_composite_boundary_function(name, childs)
+        define_composite_boundary_function(name, tuple(childs))
         return name
     else:
         assert isinstance(element, FiniteElement)
-        name = "{}_boundary".format(FEM_name_mangling(element).lower())
-        define_boundary_function(name, dirichlet[0])
+        name = get_counted_variable("func")
+        define_boundary_function(name, func[0])
         return name
 
 
-- 
GitLab