From cc8841be75082aa69e8d6906535e6f46504795da Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 6 Sep 2017 09:24:43 +0200
Subject: [PATCH] Avoid testing with num_sub_elements()

We recently learned that usage of sub_elements in UFL is ambiguous...
---
 python/dune/perftool/blockstructured/basis.py |  6 ++++--
 python/dune/perftool/pdelab/basis.py          | 12 ++++++++----
 python/dune/perftool/sumfact/basis.py         |  6 +++---
 python/dune/perftool/sumfact/realization.py   |  4 +++-
 python/dune/perftool/sumfact/symbolic.py      | 10 ++++++----
 5 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/python/dune/perftool/blockstructured/basis.py b/python/dune/perftool/blockstructured/basis.py
index 59d55174..77b266d7 100644
--- a/python/dune/perftool/blockstructured/basis.py
+++ b/python/dune/perftool/blockstructured/basis.py
@@ -21,6 +21,8 @@ from dune.perftool.pdelab.spaces import type_leaf_gfs
 from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.blockstructured.spaces import lfs_inames
 
+from ufl import MixedElement
+
 import pymbolic.primitives as prim
 
 
@@ -79,7 +81,7 @@ def evaluate_basis(leaf_element, name, restriction):
 
 
 def pymbolic_basis(leaf_element, restriction, number, context=''):
-    assert leaf_element.num_sub_elements() == 0
+    assert not isinstance(leaf_element, MixedElement)
     name = "phi_{}".format(FEM_name_mangling(leaf_element))
     name = restricted_name(name, restriction)
     evaluate_basis(leaf_element, name, restriction)
@@ -102,7 +104,7 @@ def evaluate_reference_gradient(leaf_element, name, restriction):
 
 
 def pymbolic_reference_gradient(leaf_element, restriction, number, context=''):
-    assert leaf_element.num_sub_elements() == 0
+    assert not isinstance(leaf_element, MixedElement)
     name = "js_{}".format(FEM_name_mangling(leaf_element))
     name = restricted_name(name, restriction)
     evaluate_reference_gradient(leaf_element, name, restriction)
diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index 2bda1822..0c7ab18a 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -35,7 +35,11 @@ from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.pdelab.driver import (isPk,
                                          isQk,
                                          isDG)
+
 from pymbolic.primitives import Product, Subscript, Variable
+
+from ufl import MixedElement
+
 from loopy import Reduction
 
 
@@ -107,7 +111,7 @@ def evaluate_basis(leaf_element, name, restriction):
 
 
 def pymbolic_basis(leaf_element, restriction, number, context=''):
-    assert leaf_element.num_sub_elements() == 0
+    assert not isinstance(leaf_element, MixedElement)
     name = "phi_{}".format(FEM_name_mangling(leaf_element))
     name = restricted_name(name, restriction)
     evaluate_basis(leaf_element, name, restriction)
@@ -135,7 +139,7 @@ def evaluate_reference_gradient(leaf_element, name, restriction):
 
 
 def pymbolic_reference_gradient(leaf_element, restriction, number, context=''):
-    assert leaf_element.num_sub_elements() == 0
+    assert not isinstance(leaf_element, MixedElement)
     name = "js_{}".format(FEM_name_mangling(leaf_element))
     name = restricted_name(name, restriction)
     evaluate_reference_gradient(leaf_element, name, restriction)
@@ -156,7 +160,7 @@ def shape_as_pymbolic(shape):
 @kernel_cached
 def evaluate_coefficient(visitor, element, name, container, restriction, index):
     sub_element = element
-    if element.num_sub_elements() > 0:
+    if isinstance(element, MixedElement):
         sub_element = element.extract_component(index)[1]
 
     from ufl import FiniteElement
@@ -189,7 +193,7 @@ def evaluate_coefficient(visitor, element, name, container, restriction, index):
 @kernel_cached
 def evaluate_coefficient_gradient(visitor, element, name, container, restriction, index):
     sub_element = element
-    if element.num_sub_elements() > 0:
+    if isinstance(element, MixedElement):
         sub_element = element.extract_component(index)[1]
     from ufl import FiniteElement
     assert isinstance(sub_element, FiniteElement)
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 8e745c7d..fc2b5a57 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -155,10 +155,10 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
 @kernel_cached
 def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indices):
     sub_element = element
-    if element.num_sub_elements() > 0:
+    if isinstance(element, MixedElement):
         sub_element = element.extract_component(index)[1]
     from ufl import FiniteElement
-    assert isinstance(sub_element, FiniteElement)
+    assert isinstance(sub_element, (FiniteElement, TensorProductElement))
 
     # Basis functions per direction
     basis_size = _basis_functions_per_direction(sub_element)
@@ -252,7 +252,7 @@ def pymbolic_basis(element, restriction, number):
     if number == 0:
         return 1
 
-    assert element.num_sub_elements() == 0
+    assert not isinstance(element, MixedElement)
 
     name = "phi_{}".format(FEM_name_mangling(element))
     name = restricted_name(name, restriction)
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 98276bbc..91200069 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -30,6 +30,8 @@ from dune.perftool.sumfact.vectorization import attach_vectorization_info
 from dune.perftool.sumfact.accumulation import sumfact_iname
 from dune.perftool.loopy.vcl import ExplicitVCLCast
 
+from ufl import MixedElement
+
 import loopy as lp
 import numpy as np
 import pymbolic.primitives as prim
@@ -229,7 +231,7 @@ def _realize_sum_factorization_kernel(sf):
                 direct_output = "{}x{}".format(direct_output, sf.trial_element_index)
                 rowsize = sum(tuple(s for s in _local_sizes(sf.trial_element)))
                 element = sf.trial_element
-                if element.num_sub_elements() > 0:
+                if isinstance(element, MixedElement):
                     element = element.extract_component(sf.trial_element_index)[1]
                 other_shape = tuple(element.degree() + 1 for e in range(sf.length))
                 from pytools import product
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index a2fe394e..7aa9c327 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -8,6 +8,8 @@ from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTab
 
 from pytools import ImmutableRecord, product
 
+from ufl import MixedElement
+
 import pymbolic.primitives as prim
 import loopy as lp
 import frozendict
@@ -192,7 +194,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         else:
             from dune.perftool.sumfact.basis import lfs_inames
             element = self.trial_element
-            if element.num_sub_elements() > 0:
+            if isinstance(element, MixedElement):
                 element = element.extract_component(self.trial_element_index)[1]
             return lfs_inames(element, self.restriction)
 
@@ -212,7 +214,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
 
     def quadrature_index(self, _):
         element = self.trial_element
-        if element is not None and element.num_sub_elements() > 0:
+        if element is not None and isinstance(element, MixedElement):
             element = element.extract_component(self.trial_element_index)[1]
         quad_inames = quadrature_inames(element)
         if len(self.matrix_sequence) == local_dimension():
@@ -466,7 +468,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
     def _quadrature_index(self, sf):
         element = self.trial_element
-        if element is not None and element.num_sub_elements() > 0:
+        if element is not None and isinstance(element, MixedElement):
             element = element.extract_component(self.trial_element_index)[1]
         quad_inames = quadrature_inames(element)
         index = []
@@ -498,7 +500,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
     def vec_index(self, sf):
         element = self.trial_element
-        if element is not None and element.num_sub_elements() > 0:
+        if element is not None and isinstance(element, MixedElement):
             element = element.extract_component(self.trial_element_index)[1]
         quad_inames = quadrature_inames(element)
         sliced = 0
-- 
GitLab