From dc3ddb7f834f8ad0d129a2ff6f4582b7f293f0d1 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 3 Aug 2017 17:44:55 +0200
Subject: [PATCH] Fix TensorElement tests

---
 python/dune/perftool/pdelab/driver/error.py       | 11 +++++++----
 python/dune/perftool/pdelab/driver/interpolate.py |  7 +++++--
 test/stokes/stokes_stress.ufl                     |  2 +-
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/python/dune/perftool/pdelab/driver/error.py b/python/dune/perftool/pdelab/driver/error.py
index c598f2df..2a8f35b4 100644
--- a/python/dune/perftool/pdelab/driver/error.py
+++ b/python/dune/perftool/pdelab/driver/error.py
@@ -21,7 +21,7 @@ from dune.perftool.pdelab.driver.solve import (define_vector,
                                                dune_solve,
                                                name_vector,
                                                )
-from ufl import MixedElement
+from ufl import MixedElement, TensorElement, VectorElement
 
 
 @preamble
@@ -118,9 +118,12 @@ def _accumulate_L2_squared(treepath):
 
 
 def get_treepath(element, index):
+    if isinstance(element, (VectorElement, TensorElement)):
+        return (index,)
     if isinstance(element, MixedElement):
-        i, rest = element.extract_subelement_component(index)
-        return (i,) + rest
+        pos, rest = element.extract_subelement_component(index)
+        offset = sum(element.sub_elements()[i].value_size() for i in range(pos))
+        return (pos,) + get_treepath(element.sub_elements()[pos], index - offset)
     else:
         return ()
 
@@ -130,7 +133,7 @@ def treepath_to_index(element, treepath, offset=0):
         return offset
     index = treepath[0]
     offset = offset + sum(element.sub_elements()[i].value_size() for i in range(index))
-    subel = element.sub_elements()[treepath[0]]
+    subel = element.sub_elements()[index]
     return treepath_to_index(subel, treepath[1:], offset)
 
 
diff --git a/python/dune/perftool/pdelab/driver/interpolate.py b/python/dune/perftool/pdelab/driver/interpolate.py
index 4278294a..763283c1 100644
--- a/python/dune/perftool/pdelab/driver/interpolate.py
+++ b/python/dune/perftool/pdelab/driver/interpolate.py
@@ -97,8 +97,11 @@ def name_boundary_lambda(boundary):
 def define_boundary_lambda(name, boundary):
     from ufl.classes import Expr
     if boundary is None:
-        return "auto {} = [&](const auto& x){{ return 0.0; }};".format(name)
-    elif isinstance(boundary, Expr):
+        boundary = 0.0
+    if isinstance(boundary, (int, float)):
+        return "auto {} = [&](const auto& x){{ return {}; }};".format(name, boundary)
+    else:
+        assert isinstance(boundary, Expr)
         # Set up a visitor
         with global_context(integral_type="exterior_facet", formdata=get_formdata(), driver=True):
             from dune.perftool.ufl.visitor import UFL2LoopyVisitor
diff --git a/test/stokes/stokes_stress.ufl b/test/stokes/stokes_stress.ufl
index 186d614f..a25a73ad 100644
--- a/test/stokes/stokes_stress.ufl
+++ b/test/stokes/stokes_stress.ufl
@@ -3,7 +3,7 @@ cell = triangle
 x = SpatialCoordinate(cell)
 v_bctype = conditional(x[0] < 1. - 1e-8, 1, 0)
 
-P2 = VectorElement("Lagrange", cell)
+P2 = VectorElement("Lagrange", cell, 2, 2)
 P1 = FiniteElement("Lagrange", cell, 1)
 P2_stress = TensorElement("DG", cell, 1)
 
-- 
GitLab