Skip to content
Snippets Groups Projects
Commit 63c0ce6b authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Also handle boundary type lambdas

Need to use handler from pymbolic as loopy codegen handlers
are too involved for standalone usage.
parent ce87e420
No related branches found
No related tags found
No related merge requests found
...@@ -430,12 +430,27 @@ def name_constraintscontainer(expr): ...@@ -430,12 +430,27 @@ def name_constraintscontainer(expr):
@preamble @preamble
def define_intersection_lambda(expression, name): def define_intersection_lambda(expression, name):
from dune.perftool.ufl.execution import Expression
from ufl.classes import Expr
if expression is None: if expression is None:
return "auto {} = [&](const auto& x){{ return 0; }};".format(name) return "auto {} = [&](const auto& x){{ return 0; }};".format(name)
if expression.is_global: elif isinstance(expression, Expression):
return "auto {} = [&](const auto& x){{ {} }};".format(name, expression.c_expr[0]) if expression.is_global:
else: return "auto {} = [&](const auto& x){{ {} }};".format(name, expression.c_expr[0])
return "auto {} = [&](const auto& is, const auto& x){{ {} }};".format(name, expression.c_expr[0]) else:
return "auto {} = [&](const auto& is, const auto& x){{ {} }};".format(name, expression.c_expr[0])
elif isinstance(expression, Expr):
# Set up a visitor
with global_context(integral_type="exterior_facet", formdata=_driver_data["formdata"], driver=True):
from dune.perftool.ufl.visitor import UFL2LoopyVisitor
from dune.perftool.pdelab import PDELabInterface
visitor = UFL2LoopyVisitor(PDELabInterface(), "exterior_facet", {})
from pymbolic.mapper.c_code import CCodeMapper
ccm = CCodeMapper()
expr = visitor(expression)
return "auto {} = [&](const auto& x){{ return {}; }};".format(name, ccm(expr))
raise ValueError("Expression not understood")
def name_bctype_lambda(name, dirichlet): def name_bctype_lambda(name, dirichlet):
......
...@@ -30,6 +30,7 @@ from dune.perftool.ufl.modified_terminals import Restriction ...@@ -30,6 +30,7 @@ from dune.perftool.ufl.modified_terminals import Restriction
import dune.perftool.loopy.mangler import dune.perftool.loopy.mangler
from pymbolic.primitives import Variable from pymbolic.primitives import Variable
import pymbolic.primitives as prim
from pytools import Record from pytools import Record
import loopy as lp import loopy as lp
...@@ -250,7 +251,7 @@ def determine_accumulation_space(expr, number, measure): ...@@ -250,7 +251,7 @@ def determine_accumulation_space(expr, number, measure):
) )
def boundary_predicates(expr, measure, subdomain_id): def boundary_predicates(expr, measure, subdomain_id, visitor):
predicates = frozenset([]) predicates = frozenset([])
if subdomain_id not in ['everywhere', 'otherwise']: if subdomain_id not in ['everywhere', 'otherwise']:
...@@ -276,16 +277,22 @@ def boundary_predicates(expr, measure, subdomain_id): ...@@ -276,16 +277,22 @@ def boundary_predicates(expr, measure, subdomain_id):
assert measure in subdomains assert measure in subdomains
subdomain_data = subdomains[measure] subdomain_data = subdomains[measure]
# Determine the name of the parameter function from ufl.classes import Expr
name = get_global_context_value("data").object_names[id(subdomain_data)] if isinstance(subdomain_data, Expr):
cond = visitor(subdomain_data)
else:
# Determine the name of the parameter function
cond = get_global_context_value("data").object_names[id(subdomain_data)]
# Trigger the generation of code for this thing in the parameter class
from ufl.checks import is_cellwise_constant
cellwise_constant = is_cellwise_constant(expr)
from dune.perftool.pdelab.parameter import intersection_parameter_function
intersection_parameter_function(cond, subdomain_data, cellwise_constant, t='int32')
# Trigger the generation of code for this thing in the parameter class cond = prim.Variable(cond)
from ufl.checks import is_cellwise_constant
cellwise_constant = is_cellwise_constant(expr)
from dune.perftool.pdelab.parameter import intersection_parameter_function
intersection_parameter_function(name, subdomain_data, cellwise_constant, t='int32')
predicates = predicates.union(['{} == {}'.format(name, subdomain_id)]) predicates = predicates.union([prim.Comparison(cond, '==', subdomain_id)])
return predicates return predicates
...@@ -331,7 +338,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -331,7 +338,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
from dune.perftool.pdelab.argument import name_accumulation_variable from dune.perftool.pdelab.argument import name_accumulation_variable
accumvar = name_accumulation_variable((ansatz_lfs.get_restriction() + test_lfs.get_restriction())) accumvar = name_accumulation_variable((ansatz_lfs.get_restriction() + test_lfs.get_restriction()))
predicates = boundary_predicates(accterm.term, measure, subdomain_id) predicates = boundary_predicates(accterm.term, measure, subdomain_id, visitor)
rank = 1 if ansatz_lfs.lfs is None else 2 rank = 1 if ansatz_lfs.lfs is None else 2
......
cell = triangle cell = triangle
x = SpatialCoordinate(cell)
g = Expression("Dune::FieldVector<double,2> c(0.5); c-= x; return std::exp(-1.*c.two_norm2());") c = (0.5-x[0])**2 + (0.5-x[1])**2
#j = Expression("Dune::FieldVector<double,2> c(0.5); c-= x; double s; if (x[1]>0.5) s=1.; else s=-1.; return -2.*s*(x[1]-0.5)*std::exp(-1.*c.two_norm2());", on_intersection=True) g = exp(-1.*c)
bctype = Expression("if ((x[1]<1e-8) || (x[1]>1.-1e-8)) return 0; else return 1;", on_intersection=True) f = 4*(1.-c)*g
sgn = conditional(x[1] > 0.5, 1., -1.)
j = -2.*sgn*(x[1]-0.5)*g
bctype = conditional(Or(x[1]<1e-8, x[1]>1.-1e-8), 0, 1)
V = FiniteElement("CG", "triangle", 1, dirichlet_expression=g, dirichlet_constraints=bctype) V = FiniteElement("CG", "triangle", 1, dirichlet_expression=g, dirichlet_constraints=bctype)
u = TrialFunction(V) u = TrialFunction(V)
v = TestFunction(V) v = TestFunction(V)
x = SpatialCoordinate(cell)
c = (0.5-x[0])**2 + (0.5-x[1])**2
f = 4*(1.-c)*exp(-1.*c)
sgn = conditional(x[1] > 0.5, 1., -1.)
j = -2.*sgn*(x[1]-0.5)*exp(-1.*c)
# Define the boundary measure that knows where we are... # Define the boundary measure that knows where we are...
ds = ds(subdomain_data=bctype) ds = ds(subdomain_data=bctype)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment