From ce87e42034c23aaf17c9952b3a87214a0ef243e1 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 14 Dec 2016 21:46:19 +0100
Subject: [PATCH] Make boundary lambda generation work for simple poisson

---
 python/dune/perftool/compile.py         |  2 +-
 python/dune/perftool/pdelab/driver.py   | 22 +++++++++++++++++-----
 python/dune/perftool/pdelab/geometry.py |  2 +-
 python/dune/perftool/ufl/execution.py   |  5 +++--
 python/dune/perftool/ufl/visitor.py     |  6 +++++-
 test/poisson/poisson.ufl                | 11 ++++++-----
 6 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py
index ec456ec9..cfb4b9f3 100644
--- a/python/dune/perftool/compile.py
+++ b/python/dune/perftool/compile.py
@@ -20,7 +20,7 @@ from dune.perftool.pdelab.localoperator import (generate_localoperator_basefile,
                                                 name_localoperator_file)
 from dune.perftool.ufl.preprocess import preprocess_form
 
-import os.path
+from os.path import splitext, basename
 
 
 # Disable loopy caching before we do anything else!
diff --git a/python/dune/perftool/pdelab/driver.py b/python/dune/perftool/pdelab/driver.py
index 39ef2e24..8303cf8a 100644
--- a/python/dune/perftool/pdelab/driver.py
+++ b/python/dune/perftool/pdelab/driver.py
@@ -8,6 +8,7 @@ gained there.
 """
 from dune.perftool.error import PerftoolCodegenError
 from dune.perftool.generation import (generator_factory,
+                                      global_context,
                                       include_file,
                                       cached,
                                       preamble,
@@ -748,13 +749,24 @@ def define_vector(name, formdata):
 
 @preamble
 def define_boundary_lambda(boundary, name):
+    from dune.perftool.ufl.execution import Expression
+    from ufl.classes import Expr
     if boundary is None:
         return "auto {} = [&](const auto& x){{ return 0.0; }};".format(name)
-    if boundary.is_global:
-        return "auto {} = [&](const auto& x){{ {} }};".format(name, boundary.c_expr[0])
-    else:
-        return "auto {} = [&](const auto& e, const auto& x){{ {} }};".format(name, boundary.c_expr[0])
-
+    elif isinstance(boundary, Expression):
+        if boundary.is_global:
+            return "auto {} = [&](const auto& x){{ {} }};".format(name, boundary.c_expr[0])
+        else:
+            return "auto {} = [&](const auto& e, const auto& x){{ {} }};".format(name, boundary.c_expr[0])
+    elif isinstance(boundary, Expr):
+        # Set up a visitor
+        with global_context(integral_type="exterior_facet", formdata=_driver_data["formdata"], driver=True):
+            from dune.perftool.ufl.visitor import UFL2LoopyVisitor
+            from dune.perftool.pdelab import PDELabInterface
+            visitor = UFL2LoopyVisitor(PDELabInterface(), "exterior_facet", {})
+            return "auto {} = [&](const auto& x){{ return {}; }};".format(name, visitor(boundary))
+
+    raise ValueError("Expression not understood")
 
 def name_boundary_lambda(boundary, name):
     define_boundary_lambda(boundary, name + "lambda")
diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py
index 1be633ff..bcb17d07 100644
--- a/python/dune/perftool/pdelab/geometry.py
+++ b/python/dune/perftool/pdelab/geometry.py
@@ -387,7 +387,7 @@ def name_facet_jacobian_determinant():
 
 
 def apply_to_global_transformation(name, local):
-    temporary_variable(name, shape=(name_dimension(),), shape_impl=("fv",))
+    temporary_variable(name, shape=(world_dimension(),), shape_impl=("fv",))
     geo = name_geometry()
     code = "{} = {}.global({});".format(name,
                                         geo,
diff --git a/python/dune/perftool/ufl/execution.py b/python/dune/perftool/ufl/execution.py
index 57d337fb..112aa5ec 100644
--- a/python/dune/perftool/ufl/execution.py
+++ b/python/dune/perftool/ufl/execution.py
@@ -145,13 +145,14 @@ class FiniteElement(ufl.FiniteElement):
             self.dirichlet_constraints = kwargs.pop('dirichlet_constraints', 'return true;')
             if isinstance(self.dirichlet_constraints, str):
                 self.dirichlet_constraints = Expression(self.dirichlet_constraints)
-            assert isinstance(self.dirichlet_constraints, Expression)
+            from ufl.classes import Expr
+            assert isinstance(self.dirichlet_constraints, (Expression, Expr))
 
             # Get dirichlet_constraints and convert it to Expression if necessary!
             self.dirichlet_expression = kwargs.pop('dirichlet_expression', 'return 0.0;')
             if isinstance(self.dirichlet_expression, str):
                 self.dirichlet_expression = Expression(self.dirichlet_expression)
-            assert isinstance(self.dirichlet_expression, Expression)
+            assert isinstance(self.dirichlet_expression, (Expression, Expr))
 
         # Initialize the original finite element from ufl
         ufl.FiniteElement.__init__(self, *args, **kwargs)
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index a476668f..0ad6a9d3 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -305,7 +305,11 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
     #
 
     def spatial_coordinate(self, o):
-        return self.interface.pymbolic_spatial_coordinate(self.restriction)
+        # If this is called from the driver, we just want to return x
+        if get_global_context_value("driver", False):
+            return prim.Variable("x")
+        else:
+            return self.interface.pymbolic_spatial_coordinate(self.restriction)
 
     def facet_normal(self, o):
         # The normal must be restricted to be well-defined
diff --git a/test/poisson/poisson.ufl b/test/poisson/poisson.ufl
index 27395230..73515675 100644
--- a/test/poisson/poisson.ufl
+++ b/test/poisson/poisson.ufl
@@ -1,13 +1,14 @@
-g = Expression("Dune::FieldVector<double,2> c(0.5); c-= x; return std::exp(-1.*c.two_norm2());")
 cell = triangle
 
+x = SpatialCoordinate(cell)
+
+c = (0.5-x[0])**2 + (0.5-x[1])**2
+g = exp(-1.*c)
+f = 4*(1.-c)*g
+
 V = FiniteElement("CG", cell, 1, dirichlet_expression=g)
 u = TrialFunction(V)
 v = TestFunction(V)
 
-x = SpatialCoordinate(cell)
-
-c = (0.5-x[0])**2 + (0.5-x[1])**2
-f = 4*(1.-c)*exp(-1.*c)
 
 forms = [(inner(grad(u), grad(v)) - f*v)*dx]
-- 
GitLab