From 4b137861b5e7562180cde04fa3ce3fb321563ef2 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 8 Dec 2016 18:20:03 +0100
Subject: [PATCH] Implement the evaluation of basis for jacobians

---
 python/dune/perftool/sumfact/basis.py   | 30 ++++++++++++++++++-------
 python/dune/perftool/sumfact/sumfact.py | 12 +---------
 python/dune/perftool/sumfact/switch.py  | 22 ++++++++++++++++++
 3 files changed, 45 insertions(+), 19 deletions(-)

diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 7274e58a..0965bdce 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -26,6 +26,9 @@ from dune.perftool.sumfact.sumfact import (get_facedir,
                                            sum_factorization_kernel,
                                            )
 from dune.perftool.sumfact.quadrature import quadrature_inames
+from dune.perftool.sumfact.switch import (get_facedir,
+                                          get_facemod,
+                                          )
 from dune.perftool.pdelab.geometry import world_dimension
 from dune.perftool.loopy.buffer import initialize_buffer
 from dune.perftool.pdelab.driver import FEM_name_mangling
@@ -188,14 +191,25 @@ def evaluate_basis(element, name, restriction):
     theta = name_theta()
     quad_inames = quadrature_inames()
     inames = lfs_inames(element, restriction)
-    assert(len(quad_inames) == len(inames))
-
-    instruction(expression=prim.Product(tuple(prim.Subscript(prim.Variable(theta),
-                                                             (prim.Variable(i), prim.Variable(j))
-                                                             )
-                                              for (i, j) in zip(quad_inames, inames)
-                                              )
-                                        ),
+    facedir = get_facedir(restriction)
+
+    # Collect the pairs of lfs/quad inames that are in use
+    # On facets, the normal direction of the facet is excluded
+    prod = tuple(prim.Subscript(prim.Variable(theta),
+                                (prim.Variable(i), prim.Variable(j))
+                                )
+                 for (i, j) in zip(quad_inames, tuple(iname for i, iname in enumerate(inames) if i != facedir))
+                 )
+
+    # Add the missing direction on facedirs by evaluating at either 0 or 1
+    if facedir:
+        facemod = get_facemod(restriction)
+        from dune.perftool.sumfact.amatrix import PolynomialLookup, name_polynomials
+        prod = prod + (prim.Call(PolynomialLookup(name_polynomials(), False),
+                                 (prim.Variable(inames[facedir]), facemod)),)
+
+    # Issue the product
+    instruction(expression=prim.Product(prod),
                 assignee=prim.Variable(name),
                 forced_iname_deps=frozenset(quad_inames + inames),
                 forced_iname_deps_is_final=True,
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index a4377817..ed219647 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -43,6 +43,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
                                            basis_functions_per_direction,
                                            construct_amatrix_sequence,
                                            )
+from dune.perftool.sumfact.switch import get_facedir
 from dune.perftool.loopy.symbolic import SumfactKernel
 from dune.perftool.tools import get_pymbolic_basename
 from dune.perftool.error import PerftoolError
@@ -62,17 +63,6 @@ import pymbolic.primitives as prim
 from pytools import product
 
 
-def get_facedir(restriction):
-    from dune.perftool.pdelab.restriction import Restriction
-    if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet":
-        return get_global_context_value("facedir_s")
-    if restriction == Restriction.POSITIVE:
-        return get_global_context_value("facedir_n")
-    if restriction == Restriction.NONE:
-        return None
-    assert False
-
-
 @iname
 def _sumfact_iname(bound, _type, count):
     name = "sf_{}_{}".format(_type, str(count))
diff --git a/python/dune/perftool/sumfact/switch.py b/python/dune/perftool/sumfact/switch.py
index 78b31590..d46d5d75 100644
--- a/python/dune/perftool/sumfact/switch.py
+++ b/python/dune/perftool/sumfact/switch.py
@@ -105,3 +105,25 @@ def generate_interior_facet_switch():
     block.append("}")
 
     return ClassMember(signature + block)
+
+
+def get_facedir(restriction):
+    from dune.perftool.pdelab.restriction import Restriction
+    if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet":
+        return get_global_context_value("facedir_s")
+    if restriction == Restriction.POSITIVE:
+        return get_global_context_value("facedir_n")
+    if restriction == Restriction.NONE:
+        return None
+    assert False
+
+
+def get_facemod(restriction):
+    from dune.perftool.pdelab.restriction import Restriction
+    if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet":
+        return get_global_context_value("facemod_s")
+    if restriction == Restriction.POSITIVE:
+        return get_global_context_value("facemod_n")
+    if restriction == Restriction.NONE:
+        return None
+    assert False
-- 
GitLab