Skip to content
Snippets Groups Projects
Commit cc8841be authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Avoid testing with num_sub_elements()

We recently learned that usage of sub_elements in UFL is ambiguous...
parent 6dc08920
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,8 @@ from dune.perftool.pdelab.spaces import type_leaf_gfs
from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.blockstructured.spaces import lfs_inames
from ufl import MixedElement
import pymbolic.primitives as prim
......@@ -79,7 +81,7 @@ def evaluate_basis(leaf_element, name, restriction):
def pymbolic_basis(leaf_element, restriction, number, context=''):
assert leaf_element.num_sub_elements() == 0
assert not isinstance(leaf_element, MixedElement)
name = "phi_{}".format(FEM_name_mangling(leaf_element))
name = restricted_name(name, restriction)
evaluate_basis(leaf_element, name, restriction)
......@@ -102,7 +104,7 @@ def evaluate_reference_gradient(leaf_element, name, restriction):
def pymbolic_reference_gradient(leaf_element, restriction, number, context=''):
assert leaf_element.num_sub_elements() == 0
assert not isinstance(leaf_element, MixedElement)
name = "js_{}".format(FEM_name_mangling(leaf_element))
name = restricted_name(name, restriction)
evaluate_reference_gradient(leaf_element, name, restriction)
......
......@@ -35,7 +35,11 @@ from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.pdelab.driver import (isPk,
isQk,
isDG)
from pymbolic.primitives import Product, Subscript, Variable
from ufl import MixedElement
from loopy import Reduction
......@@ -107,7 +111,7 @@ def evaluate_basis(leaf_element, name, restriction):
def pymbolic_basis(leaf_element, restriction, number, context=''):
assert leaf_element.num_sub_elements() == 0
assert not isinstance(leaf_element, MixedElement)
name = "phi_{}".format(FEM_name_mangling(leaf_element))
name = restricted_name(name, restriction)
evaluate_basis(leaf_element, name, restriction)
......@@ -135,7 +139,7 @@ def evaluate_reference_gradient(leaf_element, name, restriction):
def pymbolic_reference_gradient(leaf_element, restriction, number, context=''):
assert leaf_element.num_sub_elements() == 0
assert not isinstance(leaf_element, MixedElement)
name = "js_{}".format(FEM_name_mangling(leaf_element))
name = restricted_name(name, restriction)
evaluate_reference_gradient(leaf_element, name, restriction)
......@@ -156,7 +160,7 @@ def shape_as_pymbolic(shape):
@kernel_cached
def evaluate_coefficient(visitor, element, name, container, restriction, index):
sub_element = element
if element.num_sub_elements() > 0:
if isinstance(element, MixedElement):
sub_element = element.extract_component(index)[1]
from ufl import FiniteElement
......@@ -189,7 +193,7 @@ def evaluate_coefficient(visitor, element, name, container, restriction, index):
@kernel_cached
def evaluate_coefficient_gradient(visitor, element, name, container, restriction, index):
sub_element = element
if element.num_sub_elements() > 0:
if isinstance(element, MixedElement):
sub_element = element.extract_component(index)[1]
from ufl import FiniteElement
assert isinstance(sub_element, FiniteElement)
......
......@@ -155,10 +155,10 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
@kernel_cached
def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indices):
sub_element = element
if element.num_sub_elements() > 0:
if isinstance(element, MixedElement):
sub_element = element.extract_component(index)[1]
from ufl import FiniteElement
assert isinstance(sub_element, FiniteElement)
assert isinstance(sub_element, (FiniteElement, TensorProductElement))
# Basis functions per direction
basis_size = _basis_functions_per_direction(sub_element)
......@@ -252,7 +252,7 @@ def pymbolic_basis(element, restriction, number):
if number == 0:
return 1
assert element.num_sub_elements() == 0
assert not isinstance(element, MixedElement)
name = "phi_{}".format(FEM_name_mangling(element))
name = restricted_name(name, restriction)
......
......@@ -30,6 +30,8 @@ from dune.perftool.sumfact.vectorization import attach_vectorization_info
from dune.perftool.sumfact.accumulation import sumfact_iname
from dune.perftool.loopy.vcl import ExplicitVCLCast
from ufl import MixedElement
import loopy as lp
import numpy as np
import pymbolic.primitives as prim
......@@ -229,7 +231,7 @@ def _realize_sum_factorization_kernel(sf):
direct_output = "{}x{}".format(direct_output, sf.trial_element_index)
rowsize = sum(tuple(s for s in _local_sizes(sf.trial_element)))
element = sf.trial_element
if element.num_sub_elements() > 0:
if isinstance(element, MixedElement):
element = element.extract_component(sf.trial_element_index)[1]
other_shape = tuple(element.degree() + 1 for e in range(sf.length))
from pytools import product
......
......@@ -8,6 +8,8 @@ from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTab
from pytools import ImmutableRecord, product
from ufl import MixedElement
import pymbolic.primitives as prim
import loopy as lp
import frozendict
......@@ -192,7 +194,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
else:
from dune.perftool.sumfact.basis import lfs_inames
element = self.trial_element
if element.num_sub_elements() > 0:
if isinstance(element, MixedElement):
element = element.extract_component(self.trial_element_index)[1]
return lfs_inames(element, self.restriction)
......@@ -212,7 +214,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
def quadrature_index(self, _):
element = self.trial_element
if element is not None and element.num_sub_elements() > 0:
if element is not None and isinstance(element, MixedElement):
element = element.extract_component(self.trial_element_index)[1]
quad_inames = quadrature_inames(element)
if len(self.matrix_sequence) == local_dimension():
......@@ -466,7 +468,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def _quadrature_index(self, sf):
element = self.trial_element
if element is not None and element.num_sub_elements() > 0:
if element is not None and isinstance(element, MixedElement):
element = element.extract_component(self.trial_element_index)[1]
quad_inames = quadrature_inames(element)
index = []
......@@ -498,7 +500,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def vec_index(self, sf):
element = self.trial_element
if element is not None and element.num_sub_elements() > 0:
if element is not None and isinstance(element, MixedElement):
element = element.extract_component(self.trial_element_index)[1]
quad_inames = quadrature_inames(element)
sliced = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment