From 31f426a76cb3c02bcd90f69614ba56570dda802e Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 8 Dec 2016 17:06:50 +0100
Subject: [PATCH] Properly differeniate between worlddim and dim of the
 integrated entity

---
 python/dune/perftool/pdelab/basis.py       |  4 +--
 python/dune/perftool/pdelab/geometry.py    | 29 +++++++++++-----------
 python/dune/perftool/pdelab/parameter.py   |  1 -
 python/dune/perftool/pdelab/quadrature.py  | 25 ++++++-------------
 python/dune/perftool/sumfact/amatrix.py    |  2 --
 python/dune/perftool/sumfact/basis.py      | 26 ++++++-------------
 python/dune/perftool/sumfact/quadrature.py | 20 ++++++---------
 7 files changed, 39 insertions(+), 68 deletions(-)

diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index 2364b2bc..251258e8 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -19,7 +19,7 @@ from dune.perftool.pdelab.spaces import (lfs_child,
                                          type_gfs,
                                          )
 from dune.perftool.pdelab.geometry import (dimension_iname,
-                                           name_dimension,
+                                           world_dimension,
                                            name_jacobian_inverse_transposed,
                                            to_cell_coordinates,
                                            )
@@ -96,7 +96,7 @@ def pymbolic_basis(leaf_element, restriction, number, context=''):
 @kernel_cached
 def evaluate_reference_gradient(leaf_element, name, restriction):
     lfs = name_leaf_lfs(leaf_element, restriction)
-    temporary_variable(name, shape=(name_lfs_bound(lfs), 1, name_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian'))
+    temporary_variable(name, shape=(name_lfs_bound(lfs), 1, world_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian'))
     cache = name_localbasis_cache(leaf_element)
     qp = get_backend("qp_in_cell")(restriction)
     instruction(inames=get_backend("quad_inames")(),
diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py
index 3f8392af..973104df 100644
--- a/python/dune/perftool/pdelab/geometry.py
+++ b/python/dune/perftool/pdelab/geometry.py
@@ -13,9 +13,7 @@ from dune.perftool.generation import (backend,
                                       valuearg,
                                       )
 from dune.perftool.options import option_switch
-from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_position_in_cell,
-                                             quadrature_preamble,
-                                             )
+from dune.perftool.pdelab.quadrature import quadrature_preamble
 from dune.perftool.tools import get_pymbolic_basename
 from ufl.algorithms import MultiFunction
 from pymbolic.primitives import Variable
@@ -214,7 +212,7 @@ def apply_in_cell_transformation(name, local, restriction):
 def pymbolic_in_cell_coordinates(local, restriction):
     basename = get_pymbolic_basename(local)
     name = "{}_in_{}side".format(basename, "in" if restriction is Restriction.NEGATIVE else "out")
-    temporary_variable(name, shape=(name_dimension(),), shape_impl=("fv",))
+    temporary_variable(name, shape=(world_dimension(),), shape_impl=("fv",))
     apply_in_cell_transformation(name, local, restriction)
     return Variable(name)
 
@@ -227,18 +225,21 @@ def to_cell_coordinates(local, restriction):
         return pymbolic_in_cell_coordinates(local, restriction)
 
 
-def name_dimension():
+def world_dimension():
     formdata = get_global_context_value('formdata')
     return formdata.geometric_dimension
 
 
-def world_dimension():
-    return name_dimension()
+def intersection_dimension():
+    return world_dimension() - 1
 
 
-def name_intersection_dimension():
-    formdata = get_global_context_value('formdata')
-    return formdata.geometric_dimension - 1
+def local_dimension():
+    it = get_global_context_value('integral_type')
+    if it == "cell":
+        return world_dimension()
+    else:
+        return intersection_dimension()
 
 
 def evaluate_unit_outer_normal(name):
@@ -259,7 +260,7 @@ def declare_normal(name, shape, shape_impl):
 
 def name_unit_outer_normal():
     name = "outer_normal"
-    temporary_variable(name, shape=(name_dimension(),), decl_method=declare_normal)
+    temporary_variable(name, shape=(world_dimension(),), decl_method=declare_normal)
     evaluate_unit_outer_normal(name)
     return "outer_normal"
 
@@ -274,7 +275,7 @@ def evaluate_unit_inner_normal(name):
 
 def name_unit_inner_normal():
     name = "inner_normal"
-    temporary_variable(name, shape=(name_dimension(),), decl_method=declare_normal)
+    temporary_variable(name, shape=(world_dimension(),), decl_method=declare_normal)
     evaluate_unit_inner_normal(name)
     return "inner_normal"
 
@@ -300,7 +301,7 @@ def define_jacobian_inverse_transposed_temporary(restriction):
 def define_constant_jacobian_inveser_transposed(name, restriction):
     geo = name_cell_geometry(restriction)
     pos = name_localcenter()
-    dim = name_dimension()
+    dim = world_dimension()
 
     if restriction:
         geo_in = name_in_cell_geometry(restriction)
@@ -316,7 +317,7 @@ def define_constant_jacobian_inveser_transposed(name, restriction):
 
 @backend(interface="define_jit", name="default")
 def define_jacobian_inverse_transposed(name, restriction):
-    dim = name_dimension()
+    dim = world_dimension()
     temporary_variable(name, decl_method=define_jacobian_inverse_transposed_temporary(restriction), shape=(dim, dim))
     geo = name_cell_geometry(restriction)
     pos = get_backend("qp_in_cell")(restriction)
diff --git a/python/dune/perftool/pdelab/parameter.py b/python/dune/perftool/pdelab/parameter.py
index 0e7ffe69..2d816249 100644
--- a/python/dune/perftool/pdelab/parameter.py
+++ b/python/dune/perftool/pdelab/parameter.py
@@ -11,7 +11,6 @@ from dune.perftool.generation import (class_basename,
                                       temporary_variable
                                       )
 from dune.perftool.pdelab.geometry import (name_cell,
-                                           name_dimension,
                                            name_intersection,
                                            )
 from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_position,
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index f6be15fd..78f3e131 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -72,19 +72,6 @@ def name_quadrature_point():
     return "qp"
 
 
-def _local_dim():
-    # To determine the shape, I do query global information here for lack of good alternatives
-    from dune.perftool.generation import get_global_context_value
-    it = get_global_context_value("integral_type")
-    from dune.perftool.pdelab.geometry import name_dimension, name_intersection_dimension
-    if it == 'cell':
-        dim = name_dimension()
-    else:
-        dim = name_intersection_dimension()
-
-    return dim
-
-
 @preamble
 def fill_quadrature_points_cache(name):
     from dune.perftool.pdelab.geometry import name_geometry
@@ -97,7 +84,8 @@ def fill_quadrature_points_cache(name):
 @class_member(classtag="operator")
 def typedef_quadrature_points(name):
     range_field = lop_template_range_field()
-    dim = _local_dim()
+    from dune.perftool.pdelab.geometry import local_dimension
+    dim = local_dimension()
     return "using {} = typename Dune::QuadraturePoint<{}, {}>::Vector;".format(name, range_field, dim)
 
 
@@ -115,7 +103,8 @@ def define_quadrature_points(name):
 
 def name_quadrature_points():
     """Name of vector storing quadrature points as class member"""
-    dim = _local_dim()
+    from dune.perftool.pdelab.geometry import local_dimension
+    dim = local_dimension()
     name = "qp_order" + str(dim)
     shape = (name_quadrature_bound(), dim)
     globalarg(name, shape=shape, dtype=numpy.float64, managed=False)
@@ -150,7 +139,8 @@ def fill_quadrature_weights_cache(name):
 @class_member(classtag="operator")
 def typedef_quadrature_weights(name):
     range_field = lop_template_range_field()
-    dim = _local_dim()
+    from dune.perftool.pdelab.geometry import local_dimension
+    dim = local_dimension()
     return "using {} = typename Dune::QuadraturePoint<{}, {}>::Field;".format(name, range_field, dim)
 
 
@@ -174,7 +164,8 @@ def define_quadrature_weights(name):
 
 def name_quadrature_weights():
     """"Name of vector storing quadrature weights as class member"""
-    dim = _local_dim()
+    from dune.perftool.pdelab.geometry import local_dimension
+    dim = local_dimension()
     name = "qw_order" + str(dim)
     define_quadrature_weights(name)
     fill_quadrature_weights_cache(name)
diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py
index 93b1f5fd..029ca0e3 100644
--- a/python/dune/perftool/sumfact/amatrix.py
+++ b/python/dune/perftool/sumfact/amatrix.py
@@ -235,8 +235,6 @@ def define_theta(name, shape, transpose, derivative, additional_indices=()):
                        potentially_vectorized=True,
                        )
 
-    # TODO Enforce the alignment here!
-
     i = theta_iname("i", shape[0])
     j = theta_iname("j", shape[1])
 
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 89a72af1..b0f95ec1 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -25,6 +25,7 @@ from dune.perftool.sumfact.sumfact import (setup_theta,
                                            sum_factorization_kernel,
                                            )
 from dune.perftool.sumfact.quadrature import quadrature_inames
+from dune.perftool.pdelab.geometry import world_dimension
 from dune.perftool.loopy.buffer import initialize_buffer
 from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.restriction import restricted_name
@@ -52,16 +53,11 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
     from ufl.functionview import select_subelement
     sub_element = select_subelement(element, component)
     rank = len(sub_element.value_shape()) + 1
-    shape = sub_element.value_shape() + (element.cell().geometric_dimension(),)
+    shape = sub_element.value_shape() + (world_dimension(),)
     shape_impl = ('arr',) * rank
     temporary_variable(name, shape=shape, shape_impl=shape_impl)
 
-    # TODO:
-    # - This only covers rank 1
-    # - Avoid setting up whole gradient if only one component is needed?
-    # Get geometric dimension
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
+    dim = world_dimension()
     buffers = []
     insn_dep = None
     for i in range(dim):
@@ -127,8 +123,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
 @kernel_cached
 def pymbolic_trialfunction(element, restriction, component, visitor):
     # Get geometric dimension
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
+    dim = world_dimension()
 
     # Construct the matrix sequence for this sum factorization
     a_matrices = construct_amatrix_sequence()
@@ -180,8 +175,7 @@ def sumfact_lfs_iname(bound, dim):
 @backend(interface="lfs_inames", name="sumfact")
 def lfs_inames(element, restriction, number=1, context=''):
     assert number == 1
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
+    dim = world_dimension()
     return tuple(sumfact_lfs_iname(basis_functions_per_direction(), d) for d in range(dim))
 
 
@@ -220,10 +214,8 @@ def pymbolic_basis(element, restriction, number):
 @backend(interface="evaluate_grad")
 @kernel_cached
 def evaluate_reference_gradient(element, name, restriction):
-    from dune.perftool.pdelab.geometry import name_dimension
-    temporary_variable(
-        name,
-        shape=(name_dimension(),))
+    dim = world_dimension()
+    temporary_variable(name, shape=(dim,))
     quad_inames = quadrature_inames()
     inames = lfs_inames(element, restriction)
     assert(len(quad_inames) == len(inames))
@@ -232,10 +224,6 @@ def evaluate_reference_gradient(element, name, restriction):
     theta = name_theta()
     dtheta = name_theta(derivative=True)
 
-    # Get geometric dimension
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
-
     for i in range(dim):
         calls = [prim.Subscript(prim.Variable(theta), (prim.Variable(m), prim.Variable(n)))
                  for (m, n) in zip(quad_inames, inames)]
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index 83a5ebb5..dbc870e1 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -12,7 +12,9 @@ from dune.perftool.sumfact.amatrix import (quadrature_points_per_direction,
                                            name_oned_quadrature_weights,
                                            )
 from dune.perftool.pdelab.argument import name_accumulation_variable
-from dune.perftool.pdelab.geometry import dimension_iname
+from dune.perftool.pdelab.geometry import (dimension_iname,
+                                           local_dimension,
+                                           )
 
 from loopy import CallMangleInfo
 from loopy.symbolic import FunctionIdentifier
@@ -77,9 +79,7 @@ def sumfact_quad_iname(d, context):
 
 @backend(interface="quad_inames", name="sumfact")
 def quadrature_inames(context=''):
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
-    return tuple(sumfact_quad_iname(d, context) for d in range(dim))
+    return tuple(sumfact_quad_iname(d, context) for d in range(local_dimension()))
 
 
 def define_recursive_quadrature_weight(name, dir):
@@ -98,9 +98,7 @@ def define_recursive_quadrature_weight(name, dir):
 
 
 def recursive_quadrature_weight(dir=0):
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
-    if dir == dim:
+    if dir == local_dimension():
         return pymbolic_base_weight()
     else:
         name = 'weight_{}'.format(dir)
@@ -113,9 +111,7 @@ def quadrature_weight():
 
 
 def define_quadrature_position(name):
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
-    for i in range(dim):
+    for i in range(local_dimension()):
         instruction(expression=Subscript(Variable(name_oned_quadrature_points()), (Variable(quadrature_inames()[i]),)),
                     assignee=Subscript(Variable(name), (i,)),
                     forced_iname_deps=frozenset(quadrature_inames()),
@@ -126,10 +122,8 @@ def define_quadrature_position(name):
 
 @backend(interface="quad_pos", name="sumfact")
 def pymbolic_quadrature_position():
-    formdata = get_global_context_value('formdata')
-    dim = formdata.geometric_dimension
     name = 'pos'
-    temporary_variable(name, shape=(dim,), shape_impl=("fv",))
+    temporary_variable(name, shape=(local_dimension(),), shape_impl=("fv",))
     define_quadrature_position(name)
     return Variable(name)
 
-- 
GitLab