Skip to content
Snippets Groups Projects
Commit 42c9cd29 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Rewrite Expressions such that they may be defined over mixed elements

parent 04dda8f4
No related branches found
No related tags found
No related merge requests found
Showing
with 156 additions and 85 deletions
......@@ -237,8 +237,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
# Check if this is a parameter function
else:
# We expect all coefficients to be of type Expression or VectorExpression!
from dune.perftool.ufl.execution import Expression, VectorExpression
assert isinstance(o, (Expression, VectorExpression))
from dune.perftool.ufl.execution import Expression
assert isinstance(o, Expression)
# Determine the name of the parameter function
from dune.perftool.generation import get_global_context_value
......
......@@ -423,9 +423,9 @@ def define_intersection_lambda(expression, name):
if expression is None:
return "auto {} = [&](const auto& x){{ return 0; }};".format(name)
if expression.is_global:
return "auto {} = [&](const auto& x){{ {} }};".format(name, expression.c_expr)
return "auto {} = [&](const auto& x){{ {} }};".format(name, expression.c_expr[0])
else:
return "auto {} = [&](const auto& is, const auto& x){{ {} }};".format(name, expression.c_expr)
return "auto {} = [&](const auto& is, const auto& x){{ {} }};".format(name, expression.c_expr[0])
@symbol
......@@ -741,9 +741,9 @@ def define_boundary_lambda(boundary, name):
if boundary is None:
return "auto {} = [&](const auto& x){{ return 0.0; }};".format(name)
if boundary.is_global:
return "auto {} = [&](const auto& x){{ {} }};".format(name, boundary.c_expr)
return "auto {} = [&](const auto& x){{ {} }};".format(name, boundary.c_expr[0])
else:
return "auto {} = [&](const auto& e, const auto& x){{ {} }};".format(name, boundary.c_expr)
return "auto {} = [&](const auto& e, const auto& x){{ {} }};".format(name, boundary.c_expr[0])
@symbol
......
......@@ -69,8 +69,22 @@ def define_set_time_method():
return result
def component_to_tree_path(element, component):
subel = element.extract_subelement_component(component)
def _flatten(x):
if isinstance(x, tuple):
return '_'.join(_flatten(i) for i in x if i != ())
else:
return str(x)
return _flatten(subel)
@class_member("parameterclass", access=AccessModifier.PUBLIC)
def define_parameter_function_class_member(name, expr, t, cell):
def define_parameter_function_class_member(name, expr, baset, shape, cell):
t = construct_nested_fieldvector(baset, shape)
geot = "E" if cell else "I"
geo = geot.lower()
result = ["template<typename {}, typename X>".format(geot),
......@@ -78,12 +92,32 @@ def define_parameter_function_class_member(name, expr, t, cell):
"{",
]
if expr.is_global:
result.append(" auto x = {}.geometry().global(local);".format(geo))
# In the case of a non-scalar parameter function, recurse into leafs
if expr.element.value_shape():
# Check that this is a VectorElement, as I have no idea how a parameter function
# over a non-vector mixed element should be well-defined in PDELab.
from ufl import VectorElement
assert isinstance(expr.element, VectorElement)
result.append(" {} result(0.0);".format(t))
from dune.perftool.ufl.execution import split_expression
for i, subexpr in enumerate(split_expression(expr)):
child_name = "{}_{}".format(name, component_to_tree_path(expr.element, i))
result.append(" result[{}] = {}({}, local);".format(i, child_name, geo))
define_parameter_function_class_member(child_name, subexpr, baset, shape[1:], cell)
result.append(" return result;")
else:
result.append(" auto x = local;")
# Evaluate a scalar parameter function
if expr.is_global:
result.append(" auto x = {}.geometry().global(local);".format(geo))
else:
result.append(" auto x = local;")
result.append(" " + expr.c_expr[0])
result.append(" " + expr.c_expr)
result.append("}")
return result
......@@ -175,8 +209,7 @@ def construct_nested_fieldvector(t, shape):
def cell_parameter_function(name, expr, restriction, cellwise_constant, t='double'):
shape = expr.ufl_element().value_shape()
shape_impl = ('fv',) * len(shape)
t = construct_nested_fieldvector(t, shape)
define_parameter_function_class_member(name, expr, t, True)
define_parameter_function_class_member(name, expr, t, shape, True)
if cellwise_constant:
from dune.perftool.generation.loopy import default_declaration
default_declaration(name, shape, shape_impl)
......@@ -190,8 +223,7 @@ def cell_parameter_function(name, expr, restriction, cellwise_constant, t='doubl
def intersection_parameter_function(name, expr, cellwise_constant, t='double'):
shape = expr.ufl_element().value_shape()
shape_impl = ('fv',) * len(shape)
t = construct_nested_fieldvector(t, shape)
define_parameter_function_class_member(name, expr, t, False)
define_parameter_function_class_member(name, expr, t, shape, False)
if cellwise_constant:
from dune.perftool.generation.loopy import default_declaration
default_declaration(name, shape, shape_impl)
......
......@@ -38,7 +38,10 @@ class Coefficient(ufl.Coefficient):
ufl.Coefficient.__init__(self, element, count)
split = ufl.split_functions.split2
def split(obj):
if isinstance(obj, Expression):
return split_expression(obj)
return ufl.split_functions.split2(obj)
def Coefficients(element):
......@@ -54,42 +57,60 @@ def TrialFunctions(element):
class Expression(Coefficient):
def __init__(self, expr, is_global=True, on_intersection=False, cell_type="triangle", cellwise_constant=False):
assert isinstance(expr, str)
self.c_expr = expr
self.is_global = is_global
self.on_intersection = on_intersection
# Avoid ufl complaining about not matching dimension/cells
if cellwise_constant:
_dg = FiniteElement("DG", cell_type, 0)
else:
_dg = FiniteElement("DG", cell_type, 1)
def __init__(self, cppcode=None, element=None, cell="triangle", degree=None, is_global=True, on_intersection=False, cellwise_constant=False):
assert cppcode
# Initialize a coefficient with a dummy finite element map.
Coefficient.__init__(self, _dg)
if isinstance(cppcode, str):
cppcode = (cppcode,)
# TODO the subdomain_data code needs a uflid, not idea how to get it here
# The standard way through class decorator fails here...
def ufl_id(self):
return 0
def wrap_return(e):
if "return" not in e:
return "return {};".format(e)
else:
return e
cppcode = tuple(wrap_return(e) for e in cppcode)
class VectorExpression(Coefficient):
def __init__(self, expr, is_global=True, on_intersection=False, cell_type="triangle", cellwise_constant=False):
assert isinstance(expr, str)
self.c_expr = expr
if cellwise_constant:
if element:
assert element.degree() == 0
elif degree is not None:
assert degree == 0
else:
element = FiniteElement("DG", cell, 0)
if degree is None:
degree = 1
if element is None:
if len(cppcode) == 1:
element = FiniteElement("DG", cell, degree)
else:
element = VectorElement("DG", cell, degree, len(cppcode))
def element_length(elem):
if isinstance(elem, ufl.FiniteElement):
return 1
else:
return elem.value_shape()[0]
assert element_length(element) == len(cppcode)
self.c_expr = cppcode
self.is_global = is_global
self.on_intersection = on_intersection
# Avoid ufl complaining about not matching dimension/cells
if cellwise_constant:
_dgvec = VectorElement("DG", cell_type, 0)
else:
_dgvec = VectorElement("DG", cell_type, 1)
self.element = element
# Initialize a coefficient with a dummy finite element map.
Coefficient.__init__(self, _dgvec)
Coefficient.__init__(self, element)
def __mul__(self, other):
# Allow the combination of Expressions
if isinstance(other, Expression):
from ufl import MixedElement
return Expression(cppcode=self.c_expr + other.c_expr, element=self.element * other.element)
else:
return Coefficient.__mul__(self, other)
# TODO the subdomain_data code needs a uflid, not idea how to get it here
# The standard way through class decorator fails here...
......@@ -97,6 +118,19 @@ class VectorExpression(Coefficient):
return 0
def split_expression(expr):
assert isinstance(expr, Expression)
def element_slice(expression, sub):
offset = sum(subel.value_size() for subel in expression.element.sub_elements()[:sub])
return expression.c_expr[offset:offset + expr.element.sub_elements()[sub].value_size()]
return tuple(Expression(cppcode=element_slice(expr, i),
element=expr.element.sub_elements()[i])
for i in range(expr.element.num_sub_elements())
)
class FiniteElement(ufl.FiniteElement):
def __init__(self, *args, **kwargs):
if ('dirichlet_constraints' in kwargs) or ('dirichlet_expression' in kwargs):
......
cell_type = "interval"
cell = "interval"
f = Expression("return -2.0*x.size();", cell_type=cell_type)
g = Expression("return x.two_norm2();", cell_type=cell_type)
f = Expression("return -2.0*x.size();", cell=cell)
g = Expression("return x.two_norm2();", cell=cell)
V = FiniteElement("CG", cell_type, 1, dirichlet_expression=g)
V = FiniteElement("CG", cell, 1, dirichlet_expression=g)
u = TrialFunction(V)
v = TestFunction(V)
......
cell_type = "interval"
cell = "interval"
f = Expression("return -2.0*x.size();", cell_type=cell_type)
g = Expression("return x.two_norm2();", on_intersection=True, cell_type=cell_type)
f = Expression("return -2.0*x.size();", cell=cell)
g = Expression("return x.two_norm2();", on_intersection=True, cell=cell)
V = FiniteElement("DG", cell_type, 1)
V = FiniteElement("DG", cell, 1)
u = TrialFunction(V)
v = TestFunction(V)
n = FacetNormal(cell_type)('+')
n = FacetNormal(cell)('+')
gamma = 1.0
theta = 1.0
......
cell_type = "quadrilateral"
cell = "quadrilateral"
f = Expression("return -2.0*x.size();", cell_type=cell_type)
g = Expression("return x.two_norm2();", cell_type=cell_type)
f = Expression("return -2.0*x.size();", cell=cell)
g = Expression("return x.two_norm2();", cell=cell)
V = FiniteElement("CG", cell_type, 1, dirichlet_expression=g)
V = FiniteElement("CG", cell, 1, dirichlet_expression=g)
u = TrialFunction(V)
v = TestFunction(V)
......
cell_type = "quadrilateral"
cell = "quadrilateral"
f = Expression("return -2.0*x.size();", cell_type=cell_type)
g = Expression("return x.two_norm2();", on_intersection=True, cell_type=cell_type)
f = Expression("return -2.0*x.size();", cell=cell)
g = Expression("return x.two_norm2();", on_intersection=True, cell=cell)
V = FiniteElement("DG", cell_type, 1)
V = FiniteElement("DG", cell, 1)
u = TrialFunction(V)
v = TestFunction(V)
n = FacetNormal(cell_type)('+')
n = FacetNormal(cell)('+')
gamma = 1.0
theta = 1.0
......
cell_type = "hexahedron"
cell = "hexahedron"
f = Expression("return -2.0*x.size();", cell_type=cell_type)
g = Expression("return x.two_norm2();", cell_type=cell_type)
f = Expression("return -2.0*x.size();", cell=cell)
g = Expression("return x.two_norm2();", cell=cell)
V = FiniteElement("CG", cell_type, 1, dirichlet_expression=g)
V = FiniteElement("CG", cell, 1, dirichlet_expression=g)
u = TrialFunction(V)
v = TestFunction(V)
......
cell_type = "tetrahedron"
cell = "tetrahedron"
f = Expression("return -2.0*x.size();", cell_type=cell_type)
g = Expression("return x.two_norm2();", cell_type=cell_type)
f = Expression("return -2.0*x.size();", cell=cell)
g = Expression("return x.two_norm2();", cell=cell)
V = FiniteElement("CG", "tetrahedron", 1, dirichlet_expression=g)
u = TrialFunction(V)
......
cell_type = "hexahedron"
cell = "hexahedron"
f = Expression("return -2.0*x.size();", cell_type=cell_type)
g = Expression("return x.two_norm2();", on_intersection=True, cell_type=cell_type)
f = Expression("return -2.0*x.size();", cell=cell)
g = Expression("return x.two_norm2();", on_intersection=True, cell=cell)
V = FiniteElement("DG", cell_type, 1)
V = FiniteElement("DG", cell, 1)
u = TrialFunction(V)
v = TestFunction(V)
n = FacetNormal(cell_type)('+')
n = FacetNormal(cell)('+')
gamma = 1.0
theta = 1.0
......
cell_type = "tetrahedron"
cell = "tetrahedron"
f = Expression("return -2.0*x.size();", cell_type=cell_type)
g = Expression("return x.two_norm2();", on_intersection=True, cell_type=cell_type)
f = Expression("return -2.0*x.size();", cell=cell)
g = Expression("return x.two_norm2();", on_intersection=True, cell=cell)
V = FiniteElement("DG", cell_type, 1)
V = FiniteElement("DG", cell, 1)
u = TrialFunction(V)
v = TestFunction(V)
n = FacetNormal(cell_type)('+')
n = FacetNormal(cell)('+')
gamma = 1.0
theta = 1.0
......
v_bctype = Expression("if (x[0] < 1. - 1e-8) return 1; else return 0;", on_intersection=True)
v_dirichlet = Expression("Dune::FieldVector<double, 2> y(0.0); y[0] = 4*x[1]*(1.-x[1]); return y;")
g_v = Expression(("4*x[1]*(1.-x[1])", "0.0"))
g_p = Expression("8*x[0]")
g = g_v * g_p
cell = triangle
P2 = VectorElement("Lagrange", cell, 2, dirichlet_constraints=v_bctype, dirichlet_expression=v_dirichlet)
P2 = VectorElement("Lagrange", cell, 2, dirichlet_constraints=v_bctype, dirichlet_expression=g_v)
P1 = FiniteElement("Lagrange", cell, 1)
TH = P2 * P1
......
g = VectorExpression("Dune::FieldVector<double, 2> y(0.0); y[0]=4*x[1]*(1.-x[1]); return y;", on_intersection=True)
g_v = Expression(("4*x[1]*(1.-x[1])", "0.0"), on_intersection=True)
g_p = Expression("8*(1.-x[0])")
g = g_v * g_p
bctype = Expression("if (x[0] < 1. - 1e-8) return 1; else return 0;", on_intersection=True)
cell = triangle
......@@ -19,15 +21,15 @@ r = inner(grad(u), grad(v))*dx \
+ inner(avg(grad(u))*n, jump(v))*dS \
- eps * inner(avg(grad(v))*n, jump(u))*dS \
- inner(grad(u)*n, v)*ds \
+ eps * inner(grad(v)*n, u-g)*ds \
+ eps * inner(grad(v)*n, u-g_v)*ds \
+ sigma * inner(jump(u), jump(v))*dS \
+ sigma * inner(u-g, v)*ds \
+ sigma * inner(u-g_v, v)*ds \
- p*div(v)*dx \
- avg(p)*inner(jump(v), n)*dS \
+ p*inner(v, n)*ds \
- q*div(u)*dx \
- avg(q)*inner(jump(u), n)*dS \
+ q*inner(u, n)*ds \
- q*inner(g, n)*ds
- q*inner(g_v, n)*ds
forms = [r]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment