From 63c0ce6b01964e553d6ab15036c4c68f50648f38 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 14 Dec 2016 22:19:14 +0100
Subject: [PATCH] Also handle boundary type lambdas

Need to use handler from pymbolic as loopy codegen handlers
are too involved for standalone usage.
---
 python/dune/perftool/pdelab/driver.py        | 23 ++++++++++++++---
 python/dune/perftool/pdelab/localoperator.py | 27 ++++++++++++--------
 test/poisson/poisson_neumann.ufl             | 18 ++++++-------
 3 files changed, 44 insertions(+), 24 deletions(-)

diff --git a/python/dune/perftool/pdelab/driver.py b/python/dune/perftool/pdelab/driver.py
index 8303cf8a..a27ac9d5 100644
--- a/python/dune/perftool/pdelab/driver.py
+++ b/python/dune/perftool/pdelab/driver.py
@@ -430,12 +430,27 @@ def name_constraintscontainer(expr):
 
 @preamble
 def define_intersection_lambda(expression, name):
+    from dune.perftool.ufl.execution import Expression
+    from ufl.classes import Expr
     if expression is None:
         return "auto {} = [&](const auto& x){{ return 0; }};".format(name)
-    if expression.is_global:
-        return "auto {} = [&](const auto& x){{ {} }};".format(name, expression.c_expr[0])
-    else:
-        return "auto {} = [&](const auto& is, const auto& x){{ {} }};".format(name, expression.c_expr[0])
+    elif isinstance(expression, Expression):
+        if expression.is_global:
+            return "auto {} = [&](const auto& x){{ {} }};".format(name, expression.c_expr[0])
+        else:
+            return "auto {} = [&](const auto& is, const auto& x){{ {} }};".format(name, expression.c_expr[0])
+    elif isinstance(expression, Expr):
+        # Set up a visitor
+        with global_context(integral_type="exterior_facet", formdata=_driver_data["formdata"], driver=True):
+            from dune.perftool.ufl.visitor import UFL2LoopyVisitor
+            from dune.perftool.pdelab import PDELabInterface
+            visitor = UFL2LoopyVisitor(PDELabInterface(), "exterior_facet", {})
+            from pymbolic.mapper.c_code import CCodeMapper
+            ccm = CCodeMapper()
+            expr = visitor(expression)
+            return "auto {} = [&](const auto& x){{ return {}; }};".format(name, ccm(expr))
+
+    raise ValueError("Expression not understood")
 
 
 def name_bctype_lambda(name, dirichlet):
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 4359f7dd..0686f44c 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -30,6 +30,7 @@ from dune.perftool.ufl.modified_terminals import Restriction
 import dune.perftool.loopy.mangler
 
 from pymbolic.primitives import Variable
+import pymbolic.primitives as prim
 from pytools import Record
 
 import loopy as lp
@@ -250,7 +251,7 @@ def determine_accumulation_space(expr, number, measure):
                              )
 
 
-def boundary_predicates(expr, measure, subdomain_id):
+def boundary_predicates(expr, measure, subdomain_id, visitor):
     predicates = frozenset([])
 
     if subdomain_id not in ['everywhere', 'otherwise']:
@@ -276,16 +277,22 @@ def boundary_predicates(expr, measure, subdomain_id):
         assert measure in subdomains
         subdomain_data = subdomains[measure]
 
-        # Determine the name of the parameter function
-        name = get_global_context_value("data").object_names[id(subdomain_data)]
+        from ufl.classes import Expr
+        if isinstance(subdomain_data, Expr):
+            cond = visitor(subdomain_data)
+        else:
+            # Determine the name of the parameter function
+            cond = get_global_context_value("data").object_names[id(subdomain_data)]
+
+            # Trigger the generation of code for this thing in the parameter class
+            from ufl.checks import is_cellwise_constant
+            cellwise_constant = is_cellwise_constant(expr)
+            from dune.perftool.pdelab.parameter import intersection_parameter_function
+            intersection_parameter_function(cond, subdomain_data, cellwise_constant, t='int32')
 
-        # Trigger the generation of code for this thing in the parameter class
-        from ufl.checks import is_cellwise_constant
-        cellwise_constant = is_cellwise_constant(expr)
-        from dune.perftool.pdelab.parameter import intersection_parameter_function
-        intersection_parameter_function(name, subdomain_data, cellwise_constant, t='int32')
+            cond = prim.Variable(cond)
 
-        predicates = predicates.union(['{} == {}'.format(name, subdomain_id)])
+        predicates = predicates.union([prim.Comparison(cond, '==', subdomain_id)])
 
     return predicates
 
@@ -331,7 +338,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     from dune.perftool.pdelab.argument import name_accumulation_variable
     accumvar = name_accumulation_variable((ansatz_lfs.get_restriction() + test_lfs.get_restriction()))
 
-    predicates = boundary_predicates(accterm.term, measure, subdomain_id)
+    predicates = boundary_predicates(accterm.term, measure, subdomain_id, visitor)
 
     rank = 1 if ansatz_lfs.lfs is None else 2
 
diff --git a/test/poisson/poisson_neumann.ufl b/test/poisson/poisson_neumann.ufl
index b47aacd0..f1a46662 100644
--- a/test/poisson/poisson_neumann.ufl
+++ b/test/poisson/poisson_neumann.ufl
@@ -1,20 +1,18 @@
 cell = triangle
+x = SpatialCoordinate(cell)
 
-g = Expression("Dune::FieldVector<double,2> c(0.5); c-= x; return std::exp(-1.*c.two_norm2());")
-#j = Expression("Dune::FieldVector<double,2> c(0.5); c-= x; double s; if (x[1]>0.5) s=1.; else s=-1.; return -2.*s*(x[1]-0.5)*std::exp(-1.*c.two_norm2());", on_intersection=True)
-bctype = Expression("if ((x[1]<1e-8) || (x[1]>1.-1e-8)) return 0; else return 1;", on_intersection=True)
+c = (0.5-x[0])**2 + (0.5-x[1])**2
+g = exp(-1.*c)
+f = 4*(1.-c)*g
+sgn = conditional(x[1] > 0.5, 1., -1.)
+j = -2.*sgn*(x[1]-0.5)*g
+
+bctype = conditional(Or(x[1]<1e-8, x[1]>1.-1e-8), 0, 1)
 
 V = FiniteElement("CG", "triangle", 1, dirichlet_expression=g, dirichlet_constraints=bctype)
 u = TrialFunction(V)
 v = TestFunction(V)
 
-x = SpatialCoordinate(cell)
-
-c = (0.5-x[0])**2 + (0.5-x[1])**2
-f = 4*(1.-c)*exp(-1.*c)
-sgn = conditional(x[1] > 0.5, 1., -1.)
-j = -2.*sgn*(x[1]-0.5)*exp(-1.*c)
-
 # Define the boundary measure that knows where we are...
 ds = ds(subdomain_data=bctype)
 
-- 
GitLab