From cc8841be75082aa69e8d6906535e6f46504795da Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Wed, 6 Sep 2017 09:24:43 +0200 Subject: [PATCH] Avoid testing with num_sub_elements() We recently learned that usage of sub_elements in UFL is ambiguous... --- python/dune/perftool/blockstructured/basis.py | 6 ++++-- python/dune/perftool/pdelab/basis.py | 12 ++++++++---- python/dune/perftool/sumfact/basis.py | 6 +++--- python/dune/perftool/sumfact/realization.py | 4 +++- python/dune/perftool/sumfact/symbolic.py | 10 ++++++---- 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/python/dune/perftool/blockstructured/basis.py b/python/dune/perftool/blockstructured/basis.py index 59d55174..77b266d7 100644 --- a/python/dune/perftool/blockstructured/basis.py +++ b/python/dune/perftool/blockstructured/basis.py @@ -21,6 +21,8 @@ from dune.perftool.pdelab.spaces import type_leaf_gfs from dune.perftool.pdelab.restriction import restricted_name from dune.perftool.blockstructured.spaces import lfs_inames +from ufl import MixedElement + import pymbolic.primitives as prim @@ -79,7 +81,7 @@ def evaluate_basis(leaf_element, name, restriction): def pymbolic_basis(leaf_element, restriction, number, context=''): - assert leaf_element.num_sub_elements() == 0 + assert not isinstance(leaf_element, MixedElement) name = "phi_{}".format(FEM_name_mangling(leaf_element)) name = restricted_name(name, restriction) evaluate_basis(leaf_element, name, restriction) @@ -102,7 +104,7 @@ def evaluate_reference_gradient(leaf_element, name, restriction): def pymbolic_reference_gradient(leaf_element, restriction, number, context=''): - assert leaf_element.num_sub_elements() == 0 + assert not isinstance(leaf_element, MixedElement) name = "js_{}".format(FEM_name_mangling(leaf_element)) name = restricted_name(name, restriction) evaluate_reference_gradient(leaf_element, name, restriction) diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py index 2bda1822..0c7ab18a 100644 --- a/python/dune/perftool/pdelab/basis.py +++ b/python/dune/perftool/pdelab/basis.py @@ -35,7 +35,11 @@ from dune.perftool.pdelab.restriction import restricted_name from dune.perftool.pdelab.driver import (isPk, isQk, isDG) + from pymbolic.primitives import Product, Subscript, Variable + +from ufl import MixedElement + from loopy import Reduction @@ -107,7 +111,7 @@ def evaluate_basis(leaf_element, name, restriction): def pymbolic_basis(leaf_element, restriction, number, context=''): - assert leaf_element.num_sub_elements() == 0 + assert not isinstance(leaf_element, MixedElement) name = "phi_{}".format(FEM_name_mangling(leaf_element)) name = restricted_name(name, restriction) evaluate_basis(leaf_element, name, restriction) @@ -135,7 +139,7 @@ def evaluate_reference_gradient(leaf_element, name, restriction): def pymbolic_reference_gradient(leaf_element, restriction, number, context=''): - assert leaf_element.num_sub_elements() == 0 + assert not isinstance(leaf_element, MixedElement) name = "js_{}".format(FEM_name_mangling(leaf_element)) name = restricted_name(name, restriction) evaluate_reference_gradient(leaf_element, name, restriction) @@ -156,7 +160,7 @@ def shape_as_pymbolic(shape): @kernel_cached def evaluate_coefficient(visitor, element, name, container, restriction, index): sub_element = element - if element.num_sub_elements() > 0: + if isinstance(element, MixedElement): sub_element = element.extract_component(index)[1] from ufl import FiniteElement @@ -189,7 +193,7 @@ def evaluate_coefficient(visitor, element, name, container, restriction, index): @kernel_cached def evaluate_coefficient_gradient(visitor, element, name, container, restriction, index): sub_element = element - if element.num_sub_elements() > 0: + if isinstance(element, MixedElement): sub_element = element.extract_component(index)[1] from ufl import FiniteElement assert isinstance(sub_element, FiniteElement) diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 8e745c7d..fc2b5a57 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -155,10 +155,10 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit @kernel_cached def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indices): sub_element = element - if element.num_sub_elements() > 0: + if isinstance(element, MixedElement): sub_element = element.extract_component(index)[1] from ufl import FiniteElement - assert isinstance(sub_element, FiniteElement) + assert isinstance(sub_element, (FiniteElement, TensorProductElement)) # Basis functions per direction basis_size = _basis_functions_per_direction(sub_element) @@ -252,7 +252,7 @@ def pymbolic_basis(element, restriction, number): if number == 0: return 1 - assert element.num_sub_elements() == 0 + assert not isinstance(element, MixedElement) name = "phi_{}".format(FEM_name_mangling(element)) name = restricted_name(name, restriction) diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 98276bbc..91200069 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -30,6 +30,8 @@ from dune.perftool.sumfact.vectorization import attach_vectorization_info from dune.perftool.sumfact.accumulation import sumfact_iname from dune.perftool.loopy.vcl import ExplicitVCLCast +from ufl import MixedElement + import loopy as lp import numpy as np import pymbolic.primitives as prim @@ -229,7 +231,7 @@ def _realize_sum_factorization_kernel(sf): direct_output = "{}x{}".format(direct_output, sf.trial_element_index) rowsize = sum(tuple(s for s in _local_sizes(sf.trial_element))) element = sf.trial_element - if element.num_sub_elements() > 0: + if isinstance(element, MixedElement): element = element.extract_component(sf.trial_element_index)[1] other_shape = tuple(element.degree() + 1 for e in range(sf.length)) from pytools import product diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index a2fe394e..7aa9c327 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -8,6 +8,8 @@ from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTab from pytools import ImmutableRecord, product +from ufl import MixedElement + import pymbolic.primitives as prim import loopy as lp import frozendict @@ -192,7 +194,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): else: from dune.perftool.sumfact.basis import lfs_inames element = self.trial_element - if element.num_sub_elements() > 0: + if isinstance(element, MixedElement): element = element.extract_component(self.trial_element_index)[1] return lfs_inames(element, self.restriction) @@ -212,7 +214,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def quadrature_index(self, _): element = self.trial_element - if element is not None and element.num_sub_elements() > 0: + if element is not None and isinstance(element, MixedElement): element = element.extract_component(self.trial_element_index)[1] quad_inames = quadrature_inames(element) if len(self.matrix_sequence) == local_dimension(): @@ -466,7 +468,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def _quadrature_index(self, sf): element = self.trial_element - if element is not None and element.num_sub_elements() > 0: + if element is not None and isinstance(element, MixedElement): element = element.extract_component(self.trial_element_index)[1] quad_inames = quadrature_inames(element) index = [] @@ -498,7 +500,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def vec_index(self, sf): element = self.trial_element - if element is not None and element.num_sub_elements() > 0: + if element is not None and isinstance(element, MixedElement): element = element.extract_component(self.trial_element_index)[1] quad_inames = quadrature_inames(element) sliced = 0 -- GitLab