Skip to content
Snippets Groups Projects
Commit 46ab4811 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Make all lambdas local and introduce a preamble concept on the lambda visitor

parent dfdbf11b
No related branches found
No related tags found
No related merge requests found
......@@ -115,7 +115,7 @@ def define_intersection_lambda(name, func):
return "auto {} = [&](const auto& x){{ return {}; }};".format(name, float(func))
elif isinstance(func, Expr):
from dune.perftool.pdelab.driver.visitor import ufl_to_code
return "auto {} = [&](const auto& x){{ return {}; }};".format(name, ufl_to_code(func))
return "auto {} = [&](const auto& is, const auto& xl){{ {} }};".format(name, ufl_to_code(func))
raise ValueError("Expression not understood")
......
......@@ -95,10 +95,7 @@ def define_boundary_lambda(name, boundary):
if isinstance(boundary, (int, float)):
return "auto {} = [&](const auto& x){{ return {}; }};".format(name, float(boundary))
elif isinstance(boundary, Expr):
from dune.perftool.loopy.target import type_floatingpoint
from dune.perftool.pdelab.driver.visitor import ufl_to_code
return "auto {} = [&](const auto& x){{ return ({}){}; }};".format(name,
type_floatingpoint(),
ufl_to_code(boundary))
return "auto {} = [&](const auto& is, const auto& xl){{ {}; }};".format(name, ufl_to_code(boundary))
else:
raise NotImplementedError("What is this?")
......@@ -4,11 +4,6 @@ from dune.perftool.ufl.visitor import UFL2LoopyVisitor
import pymbolic.primitives as prim
@preamble(section="init")
def driver_using_statement(what):
return "using {};".format(what)
@preamble(section="gridoperator")
def set_lop_to_starting_time():
from dune.perftool.pdelab.driver import get_form_ident
......@@ -23,17 +18,20 @@ class DriverUFL2PymbolicVisitor(UFL2LoopyVisitor):
UFL2LoopyVisitor.__init__(self, PDELabInterface(), "exterior_facet", {})
def __call__(self, expr):
return self._call(expr, False)
self.preambles = []
ret = self._call(expr, False)
return set(self.preambles), ret
def spatial_coordinate(self, o):
self.preambles.append("auto x=is.geometry().global(xl);")
return prim.Variable("x")
def max_value(self, o):
driver_using_statement("std::max")
self.preambles.append("using std::max;")
return UFL2LoopyVisitor.max_value(self, o)
def min_value(self, o):
driver_using_statement("std::min")
self.preambles.append("using std::min;")
return UFL2LoopyVisitor.min_value(self, o)
def coefficient(self, o):
......@@ -55,5 +53,5 @@ def ufl_to_code(expr, boundary=True):
visitor = DriverUFL2PymbolicVisitor()
from pymbolic.mapper.c_code import CCodeMapper
ccm = CCodeMapper()
vis = visitor(expr)
return ccm(vis)
preambles, vis_expr = visitor(expr)
return "{} return {};".format("".join(preambles), ccm(vis_expr))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment