From a6c4562090b314b194f0a67e0bc8a490c61e7f25 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 15 Jun 2016 17:21:25 +0200
Subject: [PATCH] Implement vector valued trial function evaluation

---
 python/dune/perftool/pdelab/basis.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index 3ea24b34..37ec4628 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -308,20 +308,32 @@ def evaluate_trialfunction(element, name, restriction, component):
     from ufl.functionview import select_subelement
     sub_element = select_subelement(element, component)
 
-    # Right now, no chance of getting velocity field...
-    assert len(sub_element.value_shape()) == 0
-#     element = sub_element
+    # Determine the rank of the trialfunction tensor
+    rank = len(sub_element.value_shape())
+    assert rank in (0, 1)
 
-    temporary_variable(name, shape=())
+    shape = (name_dimension(),) * rank
+    shape_impl = ('fv', ) * rank
+    idims = tuple(dimension_iname(count=i) for i in range(rank))
+    leaf_element = sub_element
+    from ufl import VectorElement
+    if isinstance(sub_element, VectorElement):
+        leaf_element = sub_element.sub_elements()[0]
+
+    temporary_variable(name, shape=shape, shape_impl=shape_impl)
     lfs = name_lfs(element, restriction, component)
-    index = lfs_iname(sub_element, restriction, context='trial')
-    basis = name_basis(sub_element, restriction)
+    index = lfs_iname(leaf_element, restriction, context='trial')
+    basis = name_basis(leaf_element, restriction)
+
+    if isinstance(sub_element, VectorElement):
+        lfs = lfs_child(lfs, idims[0])
+
     from dune.perftool.pdelab.argument import pymbolic_coefficient
     coeff = pymbolic_coefficient(lfs, index, restriction)
     reduction_expr = Product((coeff, Subscript(Variable(basis), Variable(index))))
     instruction(expression=Reduction("sum", index, reduction_expr, allow_simultaneous=True),
                 assignee=Variable(name),
-                forced_iname_deps=frozenset({quadrature_iname()}),
+                forced_iname_deps=frozenset({quadrature_iname()}).union(frozenset(idims)),
                 forced_iname_deps_is_final=True,
                 )
 
-- 
GitLab