From 34010e9ea5423e75b855bd69443b61313346ba06 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Tue, 29 Aug 2017 15:44:22 +0200
Subject: [PATCH] TensorProductElements (with same degree in all directions)

- Use TensorProductElement in one example.

- We still use the sumfact option to swicth to the sum factorization
  code branch. This way it is still possible to easily swicth between
  sumfact and non sumfact code.

- Only TensorProductElements with the same degree in all directions
  will work. Anisotropie and adaption of quadrature rule will happen
  in the next commits.
---
 .../dune/perftool/pdelab/driver/__init__.py   | 10 ++-
 .../perftool/pdelab/driver/constraints.py     |  4 +-
 .../pdelab/driver/gridfunctionspace.py        | 65 ++++++++++++++-----
 .../perftool/pdelab/driver/interpolate.py     |  4 +-
 python/dune/perftool/pdelab/driver/vtk.py     |  5 +-
 python/dune/perftool/pdelab/quadrature.py     |  5 +-
 python/dune/perftool/sumfact/accumulation.py  | 12 ++--
 python/dune/perftool/sumfact/basis.py         | 19 +++---
 python/dune/perftool/ufl/visitor.py           |  5 +-
 test/sumfact/poisson/poisson_2d.ufl           |  5 +-
 10 files changed, 96 insertions(+), 38 deletions(-)

diff --git a/python/dune/perftool/pdelab/driver/__init__.py b/python/dune/perftool/pdelab/driver/__init__.py
index 81ebce1e..2c38881f 100644
--- a/python/dune/perftool/pdelab/driver/__init__.py
+++ b/python/dune/perftool/pdelab/driver/__init__.py
@@ -136,7 +136,7 @@ def isDG(fem):
 
 
 def FEM_name_mangling(fem):
-    from ufl import MixedElement, VectorElement, FiniteElement, TensorElement
+    from ufl import MixedElement, VectorElement, FiniteElement, TensorElement, TensorProductElement
     if isinstance(fem, VectorElement):
         return FEM_name_mangling(fem.sub_elements()[0]) + "_" + str(fem.num_sub_elements())
     if isinstance(fem, TensorElement):
@@ -155,6 +155,14 @@ def FEM_name_mangling(fem):
             return "Q" + str(fem._degree)
         if isDG(fem):
             return "DG" + str(fem._degree)
+    if isinstance(fem, TensorProductElement):
+        assert(len(set(subel._short_name for subel in fem.sub_elements())) == 1)
+
+        if isLagrange(fem.sub_elements()[0]):
+            return "TensorQ" + '_'.join(map(str, fem._degree))
+        if isDG(fem.sub_elements()[0]):
+            return "TensorDG" + '_'.join(map(str, fem._degree))
+        raise NotImplementedError("fem name mangling")
 
     raise NotImplementedError("FEM NAME MANGLING")
 
diff --git a/python/dune/perftool/pdelab/driver/constraints.py b/python/dune/perftool/pdelab/driver/constraints.py
index 20525328..9158fc6b 100644
--- a/python/dune/perftool/pdelab/driver/constraints.py
+++ b/python/dune/perftool/pdelab/driver/constraints.py
@@ -14,7 +14,7 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_gfs,
                                                            preprocess_leaf_data,
                                                            )
 
-from ufl import FiniteElement, MixedElement, TensorElement, VectorElement
+from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement
 
 
 def name_assembled_constraints():
@@ -53,7 +53,7 @@ def name_bctype_function(element, is_dirichlet):
         define_composite_bctype_function(element, is_dirichlet, name, tuple(childs))
         return name
     else:
-        assert isinstance(element, FiniteElement)
+        assert isinstance(element, (FiniteElement, TensorProductElement))
         name = "{}_bctype".format(FEM_name_mangling(element).lower())
         define_bctype_function(element, is_dirichlet[0], name)
         return name
diff --git a/python/dune/perftool/pdelab/driver/gridfunctionspace.py b/python/dune/perftool/pdelab/driver/gridfunctionspace.py
index e442c08e..2c042e09 100644
--- a/python/dune/perftool/pdelab/driver/gridfunctionspace.py
+++ b/python/dune/perftool/pdelab/driver/gridfunctionspace.py
@@ -7,6 +7,7 @@ from dune.perftool.pdelab.driver import (FEM_name_mangling,
                                          get_dimension,
                                          get_test_element,
                                          get_trial_element,
+                                         isLagrange,
                                          isDG,
                                          isPk,
                                          isQk,
@@ -16,7 +17,7 @@ from dune.perftool.pdelab.driver import (FEM_name_mangling,
                                          preprocess_leaf_data,
                                          )
 
-from ufl import FiniteElement, MixedElement, TensorElement, VectorElement
+from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement
 
 
 @preamble
@@ -124,24 +125,47 @@ def typedef_fem(element, name):
     df = type_domainfield()
     r = type_range()
     dim = get_dimension()
+
     if get_option("blockstructured"):
         include_file("dune/perftool/blockstructured/blockstructuredqkfem.hh", filetag="driver")
         degree = element.degree() * get_option("number_of_blocks")
-        return "using {} = Dune::PDELab::BlockstructuredQkLocalFiniteElementMap<{}, {}, {}, {}>;".format(name, gv, df, r, degree)
-    if isQk(element):
+        return "using {} = Dune::PDELab::BlockstructuredQkLocalFiniteElementMap<{}, {}, {}, {}>;" \
+            .format(name, gv, df, r, degree)
+
+    if isinstance(element, TensorProductElement):
+        # Only allow TensorProductElements where all subelements are
+        # of the same type ('CG' or 'DG')
+        assert(len(set(subel._short_name for subel in element.sub_elements())) == 1)
+
+        # TensorProductElements have Qk structure -> no Pk
+        if isLagrange(element.sub_elements()[0]):
+            include_file("dune/pdelab/finiteelementmap/qkfem.hh", filetag="driver")
+            return "using {} = Dune::PDELab::QkLocalFiniteElementMap<{}, {}, {}, {}>;" \
+                .format(name, gv, df, r, max(element.degree()))
+        elif isDG(element.sub_elements()[0]):
+            include_file("dune/pdelab/finiteelementmap/qkdg.hh", filetag="driver")
+            # TODO allow switching the basis here!
+            return "using {} = Dune::PDELab::QkDGLocalFiniteElementMap<{}, {}, {}, {}>;" \
+                .format(name, df, r, max(element.degree()), dim)
+        raise NotImplementedError("FEM not implemented in dune-perftool")
+    elif isQk(element):
         include_file("dune/pdelab/finiteelementmap/qkfem.hh", filetag="driver")
-        return "using {} = Dune::PDELab::QkLocalFiniteElementMap<{}, {}, {}, {}>;".format(name, gv, df, r, element.degree())
-    if isPk(element):
+        return "using {} = Dune::PDELab::QkLocalFiniteElementMap<{}, {}, {}, {}>;" \
+            .format(name, gv, df, r, element.degree())
+    elif isPk(element):
         include_file("dune/pdelab/finiteelementmap/pkfem.hh", filetag="driver")
-        return "using {} = Dune::PDELab::PkLocalFiniteElementMap<{}, {}, {}, {}>;".format(name, gv, df, r, element.degree())
-    if isDG(element):
+        return "using {} = Dune::PDELab::PkLocalFiniteElementMap<{}, {}, {}, {}>;" \
+            .format(name, gv, df, r, element.degree())
+    elif isDG(element):
         if isQuadrilateral(element.cell()):
             include_file("dune/pdelab/finiteelementmap/qkdg.hh", filetag="driver")
             # TODO allow switching the basis here!
-            return "using {} = Dune::PDELab::QkDGLocalFiniteElementMap<{}, {}, {}, {}>;".format(name, df, r, element.degree(), dim)
+            return "using {} = Dune::PDELab::QkDGLocalFiniteElementMap<{}, {}, {}, {}>;" \
+                .format(name, df, r, element.degree(), dim)
         if isSimplical(element.cell()):
             include_file("dune/pdelab/finiteelementmap/opbfem.hh", filetag="driver")
-            return "using {} = Dune::PDELab::OPBLocalFiniteElementMap<{}, {}, {}, {}, Dune::GeometryType::simplex>;".format(name, df, r, element.degree(), dim)
+            return "using {} = Dune::PDELab::OPBLocalFiniteElementMap<{}, {}, {}, {}, Dune::GeometryType::simplex>;" \
+                .format(name, df, r, element.degree(), dim)
         raise NotImplementedError("Geometry type not known in code generation")
 
     raise NotImplementedError("FEM not implemented in dune-perftool")
@@ -157,15 +181,26 @@ def type_fem(element):
 def define_fem(element, name):
     femtype = type_fem(element)
     from dune.perftool.pdelab.driver import isDG
-    if isDG(element):
+    if isinstance(element, TensorProductElement):
+        # Only allow TensorProductElements where all subelements are
+        # of the same type ('CG' or 'DG')
+        assert(len(set(subel._short_name for subel in element.sub_elements())) == 1)
+        if isDG(element.sub_elements()[0]):
+            return "{} {};".format(femtype, name)
+        else:
+            assert(isLagrange(element.sub_elements()[0]))
+            gv = name_leafview()
+            return "{} {}({});".format(femtype, name, gv)
+    elif isDG(element):
         return "{} {};".format(femtype, name)
     else:
+        assert(isLagrange(element))
         gv = name_leafview()
         return "{} {}({});".format(femtype, name, gv)
 
 
 def name_fem(element):
-    assert isinstance(element, FiniteElement)
+    assert isinstance(element, (FiniteElement, TensorProductElement))
     name = "{}_fem".format(FEM_name_mangling(element).lower())
     define_fem(element, name)
     return name
@@ -192,7 +227,7 @@ def name_gfs(element, is_dirichlet, treepath=()):
                                        "_".join(str(t) for t in treepath))
         define_power_gfs(element, is_dirichlet, name, subgfs)
         return name
-    if isinstance(element, MixedElement):
+    elif isinstance(element, MixedElement):
         k = 0
         subgfs = []
         for i, subel in enumerate(element.sub_elements()):
@@ -203,7 +238,7 @@ def name_gfs(element, is_dirichlet, treepath=()):
         define_composite_gfs(element, is_dirichlet, name, tuple(subgfs))
         return name
     else:
-        assert isinstance(element, FiniteElement)
+        assert isinstance(element, (FiniteElement, TensorProductElement))
         name = "{}{}_gfs_{}".format(FEM_name_mangling(element).lower(),
                                     "_dirichlet" if is_dirichlet[0] else "",
                                     "_".join(str(t) for t in treepath))
@@ -230,7 +265,7 @@ def type_gfs(element, is_dirichlet):
         name = "{}_POW{}GFS".format(subgfs, element.num_sub_elements())
         typedef_power_gfs(element, is_dirichlet, name, subgfs)
         return name
-    if isinstance(element, MixedElement):
+    elif isinstance(element, MixedElement):
         k = 0
         subgfs = []
         for subel in element.sub_elements():
@@ -240,7 +275,7 @@ def type_gfs(element, is_dirichlet):
         typedef_composite_gfs(element, name, tuple(subgfs))
         return name
     else:
-        assert isinstance(element, FiniteElement)
+        assert isinstance(element, (FiniteElement, TensorProductElement))
         name = "{}{}_GFS".format(FEM_name_mangling(element).upper(),
                                  "_dirichlet" if is_dirichlet[0] else "",
                                  )
diff --git a/python/dune/perftool/pdelab/driver/interpolate.py b/python/dune/perftool/pdelab/driver/interpolate.py
index cacb75aa..f3cc2afb 100644
--- a/python/dune/perftool/pdelab/driver/interpolate.py
+++ b/python/dune/perftool/pdelab/driver/interpolate.py
@@ -15,7 +15,7 @@ from dune.perftool.pdelab.driver.gridfunctionspace import (name_trial_gfs,
                                                            )
 from dune.perftool.pdelab.driver.gridoperator import (name_parameters,)
 
-from ufl import FiniteElement, MixedElement, TensorElement, VectorElement
+from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement
 
 
 def _do_interpolate(dirichlet):
@@ -55,7 +55,7 @@ def name_boundary_function(element, func):
         define_compositegfs_parameterfunction(name, tuple(childs))
         return name
     else:
-        assert isinstance(element, FiniteElement)
+        assert isinstance(element, (FiniteElement, TensorProductElement))
         name = get_counted_variable("func")
         define_boundary_function(name, func[0])
         return name
diff --git a/python/dune/perftool/pdelab/driver/vtk.py b/python/dune/perftool/pdelab/driver/vtk.py
index 7edf92f3..824f24de 100644
--- a/python/dune/perftool/pdelab/driver/vtk.py
+++ b/python/dune/perftool/pdelab/driver/vtk.py
@@ -41,7 +41,10 @@ def type_vtkwriter():
 @preamble
 def define_subsamplinglevel(name):
     ini = name_initree()
-    return "int {} = {}.get<int>(\"vtk.subsamplinglevel\", {});".format(name, ini, max(get_trial_element().degree() - 1, 0))
+    degree = get_trial_element().degree()
+    if isinstance(degree, tuple):
+        degree = max(degree)
+    return "int {} = {}.get<int>(\"vtk.subsamplinglevel\", {});".format(name, ini, max(degree - 1, 0))
 
 
 def name_subsamplinglevel():
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index 8356ecd2..313b6113 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -191,7 +191,10 @@ def _estimate_quadrature_order():
     integrals = form.integrals_by_type(integral_type)
     polynomial_degree = 0
     for i in integrals:
-        polynomial_degree = max(polynomial_degree, i.metadata()['estimated_polynomial_degree'])
+        degree = i.metadata()['estimated_polynomial_degree']
+        if isinstance(degree, tuple):
+            degree = max(degree)
+        polynomial_degree = max(polynomial_degree, degree)
 
     return polynomial_degree
 
diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index fa48689f..3559bdfc 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -41,7 +41,7 @@ import loopy as lp
 import numpy as np
 import pymbolic.primitives as prim
 import ufl.classes as uc
-
+from ufl import FiniteElement, TensorProductElement
 
 @iname
 def _sumfact_iname(bound, _type, count):
@@ -138,8 +138,7 @@ def get_accumulation_info(expr, visitor):
 
 
 def _get_childs(element):
-    from ufl import FiniteElement
-    if isinstance(element, FiniteElement):
+    if isinstance(element, (FiniteElement, TensorProductElement)):
         yield (0, element)
     else:
         for i in range(element.value_size()):
@@ -200,9 +199,12 @@ def generate_accumulation_instruction(expr, visitor):
     trial_info = visitor.trial_info
 
     leaf_element = test_info.element
-    if leaf_element.num_sub_elements() > 0:
+    if leaf_element.num_sub_elements() > 0 and not isinstance(leaf_element, TensorProductElement):
         leaf_element = leaf_element.extract_component(test_info.element_index)[1]
-    basis_size = leaf_element.degree() + 1
+    degree = leaf_element._degree
+    if isinstance(degree, tuple):
+        degree = max(degree)
+    basis_size = degree + 1
 
     from dune.perftool.pdelab.localoperator import boundary_predicates
     predicates = boundary_predicates(expr,
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index a58a7bb7..d117dc5b 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -42,7 +42,7 @@ from dune.perftool.tools import maybe_wrap_subscript
 from dune.perftool.pdelab.basis import shape_as_pymbolic
 from dune.perftool.sumfact.accumulation import sumfact_iname
 
-from ufl import VectorElement, TensorElement
+from ufl import VectorElement, TensorElement, TensorProductElement
 
 from pytools import product, ImmutableRecord
 
@@ -96,20 +96,23 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
 
 def _basis_functions_per_direction(element):
     """Number of basis functions per direction """
-    from ufl import FiniteElement
-    assert isinstance(element, FiniteElement)
-    return element.degree() + 1
+    from ufl import FiniteElement, TensorProductElement
+    assert isinstance(element, (FiniteElement, TensorProductElement))
+    degree = element.degree()
+    if isinstance(degree, tuple):
+        degree = max(degree)
+    return degree + 1
 
 
 @kernel_cached
 def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visitor_indices):
     sub_element = element
     grad_index = visitor_indices[0]
-    if element.num_sub_elements() > 0:
+    if element.num_sub_elements() > 0 and not isinstance(element, TensorProductElement):
         sub_element = element.extract_component(index)[1]
 
     from ufl import FiniteElement
-    assert isinstance(sub_element, FiniteElement)
+    assert isinstance(sub_element, (FiniteElement, TensorProductElement))
 
     # Number of basis functions per direction
     basis_size = _basis_functions_per_direction(sub_element)
@@ -190,8 +193,8 @@ def sumfact_lfs_iname(element, bound, dim):
 
 @backend(interface="lfs_inames", name="sumfact")
 def lfs_inames(element, restriction, number=1, context=''):
-    from ufl import FiniteElement
-    assert isinstance(element, FiniteElement)
+    from ufl import FiniteElement, TensorProductElement
+    assert isinstance(element, (FiniteElement, TensorProductElement))
     if number == 0:
         return ()
     else:
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index e50818de..7086be2d 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -24,6 +24,7 @@ from ufl.algorithms import MultiFunction
 from ufl.checks import is_cellwise_constant
 from ufl import (VectorElement,
                  TensorElement,
+                 TensorProductElement,
                  )
 from ufl.classes import (FixedIndex,
                          IndexSum,
@@ -98,7 +99,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         leaf_element = o.ufl_element()
 
         # Select the correct leaf element in the case of this being a mixed finite element
-        if o.ufl_element().num_sub_elements() > 0:
+        if o.ufl_element().num_sub_elements() > 0 and not isinstance(o.ufl_element(),TensorProductElement):
             index = self.indices[0]
             assert isinstance(index, int)
             self.indices = self.indices[1:]
@@ -127,7 +128,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
             self.interface.initialize_function_spaces(o, self)
 
             index = None
-            if o.ufl_element().num_sub_elements() > 0:
+            if o.ufl_element().num_sub_elements() > 0 and not isinstance(o.ufl_element(), TensorProductElement):
                 index = self.indices[0]
                 assert isinstance(index, int)
                 self.indices = self.indices[1:]
diff --git a/test/sumfact/poisson/poisson_2d.ufl b/test/sumfact/poisson/poisson_2d.ufl
index 0d494b70..d2c78a8d 100644
--- a/test/sumfact/poisson/poisson_2d.ufl
+++ b/test/sumfact/poisson/poisson_2d.ufl
@@ -5,7 +5,10 @@ c = (0.5-x[0])**2 + (0.5-x[1])**2
 g = exp(-1.*c)
 f = 2*(2.-2*c)*g
 
-V = FiniteElement("CG", cell, degree)
+V_0 = FiniteElement("CG", interval, degree)
+V_1 = FiniteElement("CG", interval, degree)
+V = TensorProductElement(V_0, V_1, cell=cell)
+
 u = TrialFunction(V)
 v = TestFunction(V)
 
-- 
GitLab