Skip to content
Snippets Groups Projects
Commit 7be932c5 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Some steps towards DG

parent 1046a2fe
No related branches found
No related tags found
No related merge requests found
Showing with 156 additions and 46 deletions
...@@ -11,3 +11,6 @@ ...@@ -11,3 +11,6 @@
path = python/ufl path = python/ufl
url = https://bitbucket.org/fenics-project/ufl.git url = https://bitbucket.org/fenics-project/ufl.git
ignore = untracked ignore = untracked
[submodule "python/pymbolic"]
path = python/pymbolic
url = https://github.com/inducer/pymbolic.git
class Restriction: class Restriction:
NONE = 0 NONE = 0
INSIDE = 1 NEGATIVE = 1
OUTSIDE = 2 POSITIVE = 2
...@@ -49,6 +49,8 @@ def read_ufl(uflfile): ...@@ -49,6 +49,8 @@ def read_ufl(uflfile):
form = transform_form(form, reindexing) form = transform_form(form, reindexing)
# form = transform_form(form, split_arguments) # form = transform_form(form, split_arguments)
formdata.preprocessed_form = form
return formdata, data.object_names return formdata, data.object_names
......
...@@ -73,9 +73,11 @@ def show_code(which, kernel): ...@@ -73,9 +73,11 @@ def show_code(which, kernel):
clear() clear()
print("Showing the generated dune-pdelab code for {}:\n".format(kernel_name(which))) print("Showing the generated dune-pdelab code for {}:\n".format(kernel_name(which)))
from dune.perftool.pdelab.localoperator import measure_specific_details, AssemblyMethod from dune.perftool.generation import global_context
signature = measure_specific_details(which[0])["{}_signature".format(which[1])] with global_context(integral_type=which[0], form_type=which[1]):
print("".join(AssemblyMethod(signature, kernel).generate())) from dune.perftool.pdelab.localoperator import assembly_routine_signature, AssemblyMethod
signature = assembly_routine_signature()
print("".join(AssemblyMethod(signature, kernel).generate()))
print("Press Return to return to the previous menu") print("Press Return to return to the previous menu")
input() input()
......
...@@ -8,8 +8,9 @@ from __future__ import absolute_import ...@@ -8,8 +8,9 @@ from __future__ import absolute_import
from dune.perftool import Restriction from dune.perftool import Restriction
from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker
from dune.perftool.pymbolic.uflmapper import UFL2PymbolicMapper from dune.perftool.pymbolic.uflmapper import UFL2PymbolicMapper
from dune.perftool.pdelab.geometry import GeometryMapper
from dune.perftool.generation import (domain, from dune.perftool.generation import (domain,
global_context,
globalarg, globalarg,
iname, iname,
instruction, instruction,
...@@ -31,7 +32,7 @@ def index_sum_iname(i): ...@@ -31,7 +32,7 @@ def index_sum_iname(i):
return name_index(i) return name_index(i)
class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper): class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapper):
def __init__(self): def __init__(self):
super(UFL2LoopyVisitor, self).__init__() super(UFL2LoopyVisitor, self).__init__()
...@@ -73,12 +74,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper): ...@@ -73,12 +74,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper):
# Now continue processing the expression # Now continue processing the expression
return self.call(o.ufl_operands[0]) return self.call(o.ufl_operands[0])
# TODO use multiple inheritance and have a geometry transformer in the pdelab subpackage
def facet_area(self, o):
from pymbolic.primitives import Variable
from dune.perftool.pdelab.geometry import name_facetarea
return Variable(name_facetarea())
class _Counter: class _Counter:
counter = 0 counter = 0
...@@ -106,19 +101,21 @@ def transform_accumulation_term(term, measure, subdomain_id): ...@@ -106,19 +101,21 @@ def transform_accumulation_term(term, measure, subdomain_id):
rmap = {} rmap = {}
for ma in test_ma: for ma in test_ma:
# Set up the local function space structure with global_context(restriction=ma.restriction):
traverse_lfs_tree(ma) # Set up the local function space structure
traverse_lfs_tree(ma)
# Get the expression for the modified argument representing the test function # Get the expression for the modified argument representing the test function
from dune.perftool.pdelab.argument import pymbolic_testfunction from dune.perftool.pdelab.argument import pymbolic_testfunction
rmap[ma.expr] = pymbolic_testfunction(ma) rmap[ma.expr] = pymbolic_testfunction(ma)
for ma in trial_ma: for ma in trial_ma:
# Set up the local function space structure with global_context(restriction=ma.restriction):
traverse_lfs_tree(ma) # Set up the local function space structure
traverse_lfs_tree(ma)
# Get the expression for the modified argument representing the trial function # Get the expression for the modified argument representing the trial function
from dune.perftool.pdelab.argument import pymbolic_trialfunction from dune.perftool.pdelab.argument import pymbolic_trialfunction
rmap[ma.expr] = pymbolic_trialfunction(ma) rmap[ma.expr] = pymbolic_trialfunction(ma)
# Get the transformer! # Get the transformer!
ufl2l_mf = UFL2LoopyVisitor() ufl2l_mf = UFL2LoopyVisitor()
......
...@@ -10,7 +10,10 @@ from dune.perftool.generation import (preamble, ...@@ -10,7 +10,10 @@ from dune.perftool.generation import (preamble,
def name_index(index): def name_index(index):
from ufl.classes import MultiIndex, Index from ufl.classes import MultiIndex, Index
if isinstance(index, Index): if isinstance(index, Index):
return str(index) # This failed for index > 9 because ufl placed curly brackets around
# return str(index)
return "i_{}".format(index.count())
if isinstance(index, MultiIndex): if isinstance(index, MultiIndex):
assert len(index) == 1 assert len(index) == 1
return str(index._indices[0]) # return str(index._indices[0])
return "i_{}".format(index._indices[0].count())
...@@ -269,7 +269,7 @@ def evaluate_trialfunction(element, name): ...@@ -269,7 +269,7 @@ def evaluate_trialfunction(element, name):
temporary_variable(name, shape=()) temporary_variable(name, shape=())
lfs = name_lfs(element) lfs = name_lfs(element)
index = lfs_iname(element) index = lfs_iname(element)
basis = name_basis() basis = name_basis(element)
instruction(inames=(quadrature_iname(), instruction(inames=(quadrature_iname(),
index, index,
), ),
......
...@@ -63,6 +63,10 @@ def isQk(fem): ...@@ -63,6 +63,10 @@ def isQk(fem):
return isLagrange(fem) and isQuadrilateral(fem) return isLagrange(fem) and isQuadrilateral(fem)
def isDG(fem):
return fem._short_name is 'DG'
def FEM_name_mangling(fem): def FEM_name_mangling(fem):
from ufl import MixedElement, VectorElement, FiniteElement from ufl import MixedElement, VectorElement, FiniteElement
if isinstance(fem, MixedElement): if isinstance(fem, MixedElement):
...@@ -79,6 +83,9 @@ def FEM_name_mangling(fem): ...@@ -79,6 +83,9 @@ def FEM_name_mangling(fem):
return "P" + str(fem._degree) return "P" + str(fem._degree)
if isQk(fem): if isQk(fem):
return "Q" + str(fem._degree) return "Q" + str(fem._degree)
if isDG(fem):
return "DG" + str(fem._degree)
raise NotImplementedError("FEM NAME MANGLING") raise NotImplementedError("FEM NAME MANGLING")
...@@ -242,12 +249,23 @@ def typedef_fem(expr, name): ...@@ -242,12 +249,23 @@ def typedef_fem(expr, name):
gv = type_leafview() gv = type_leafview()
df = type_domainfield() df = type_domainfield()
r = type_range() r = type_range()
dim = name_dimension()
if isPk(expr): if isPk(expr):
include_file("dune/pdelab/finiteelementmap/pkfem.hh", filetag="driver") include_file("dune/pdelab/finiteelementmap/pkfem.hh", filetag="driver")
return "typedef Dune::PDELab::PkLocalFiniteElementMap<{}, {}, {}, {}> {};".format(gv, df, r, expr._degree, name) return "typedef Dune::PDELab::PkLocalFiniteElementMap<{}, {}, {}, {}> {};".format(gv, df, r, expr._degree, name)
if isQk(generator._kwargs['expr']): if isQk(expr):
include_file("dune/pdelab/finiteelementmap/qkfem.hh", filetag="driver") include_file("dune/pdelab/finiteelementmap/qkfem.hh", filetag="driver")
return "typedef Dune::PDELab::QkLocalFiniteElementMap<{}, {}, {}, {}> {};".format(gv, df, r, expr._degree, name) return "typedef Dune::PDELab::QkLocalFiniteElementMap<{}, {}, {}, {}> {};".format(gv, df, r, expr._degree, name)
if isDG(expr):
if isQuadrilateral(expr):
include_file("dune/pdelab/finiteelementmap/qkdg.hh", filetag="driver")
# TODO allow switching the basis here!
return "typedef Dune::PDELab::QkDGLocalFiniteElementMap<{}, {}, {}, {}> {}".format(df, r, expr._degree, dim, name)
if isSimplical(expr):
include_file("dune/pdelab/finiteelementmap/opbfem.hh", filetag="driver")
return "typedef Dune::PDELab::OPBLocalFiniteElementMap<{}, {}, {}, {}, Dune::GeometryType::cube> {}".format(df, r, expr._degree, dim, name)
raise NotImplementedError("Geometry type not known in code generation")
raise NotImplementedError("FEM not implemented in dune-perftool") raise NotImplementedError("FEM not implemented in dune-perftool")
...@@ -261,8 +279,11 @@ def type_fem(expr): ...@@ -261,8 +279,11 @@ def type_fem(expr):
@preamble @preamble
def define_fem(expr, name): def define_fem(expr, name):
femtype = type_fem(expr) femtype = type_fem(expr)
gv = name_leafview() if isDG(expr):
return "{} {}({});".format(femtype, name, gv) return "{} {};".format(femtype, name)
else:
gv = name_leafview()
return "{} {}({});".format(femtype, name, gv)
@symbol @symbol
......
...@@ -6,6 +6,37 @@ from dune.perftool.generation import (preamble, ...@@ -6,6 +6,37 @@ from dune.perftool.generation import (preamble,
from dune.perftool.pdelab.quadrature import (name_quadrature_position, from dune.perftool.pdelab.quadrature import (name_quadrature_position,
quadrature_preamble, quadrature_preamble,
) )
from ufl.algorithms import MultiFunction
class GeometryMapper(MultiFunction):
"""
A collection of visitors for geometry related UFL nodes
NB: This is kind of 'abstract' as it needs to be combined
with a ModifiedTerminalTracker through multi inheritance.
"""
def __init__(self):
super(GeometryMapper, self).__init__()
def facet_normal(self, o):
# The normal must be restricted to be well-defined
assert self.restriction is not Restriction.NONE
from pymbolic.primitives import Variable
if self.restriction == Restriction.POSITIVE:
return Variable(name_unit_outer_normal())
if self.restriction == Restriction.NEGATIVE:
# It is highly unnatural to have this generator function,
# but I do run into subtle trouble with return -1*outer
# as the indexing into the normal happens only later.
# Not investing more time into this cornercase right now.
return Variable(name_unit_inner_normal())
# TODO This one was just copied over so, it might need some tweaking
def facet_area(self, o):
from pymbolic.primitives import Variable
from dune.perftool.pdelab.geometry import name_facetarea
return Variable(name_facetarea())
@symbol @symbol
...@@ -55,7 +86,7 @@ def type_geometry_wrapper(): ...@@ -55,7 +86,7 @@ def type_geometry_wrapper():
@preamble @preamble
def define_restricted_cell(name, restriction): def define_restricted_cell(name, restriction):
ig = name_intersection_geometry_wrapper() ig = name_intersection_geometry_wrapper()
which = "inside" if restriction == Restriction.INSIDE else "outside" which = "inside" if restriction == Restriction.NEGATIVE else "outside"
return "const auto& {} = {}.{}();".format(name, return "const auto& {} = {}.{}();".format(name,
ig, ig,
which, which,
...@@ -68,7 +99,7 @@ def _name_cell(restriction): ...@@ -68,7 +99,7 @@ def _name_cell(restriction):
eg = name_element_geometry_wrapper() eg = name_element_geometry_wrapper()
return "{}.entity()".format(eg) return "{}.entity()".format(eg)
which = "inside" if restriction == Restriction.INSIDE else "outside" which = "inside" if restriction == Restriction.NEGATIVE else "outside"
name = "cell_{}".format(which) name = "cell_{}".format(which)
define_restricted_cell(name, restriction) define_restricted_cell(name, restriction)
return name return name
...@@ -81,7 +112,7 @@ def name_cell(): ...@@ -81,7 +112,7 @@ def name_cell():
if it == 'cell': if it == 'cell':
r = Restriction.NONE r = Restriction.NONE
if it == 'exterior_facet': if it == 'exterior_facet':
r = Restriction.INSIDE r = Restriction.NEGATIVE
if it == 'interior_facet': if it == 'interior_facet':
raise NotImplementedError raise NotImplementedError
...@@ -138,7 +169,7 @@ def name_geometry(): ...@@ -138,7 +169,7 @@ def name_geometry():
def define_in_cell_geometry(restriction, name): def define_in_cell_geometry(restriction, name):
cell = _name_cell(restriction) cell = _name_cell(restriction)
ig = name_intersection_geometry_wrapper() ig = name_intersection_geometry_wrapper()
which = "In" if restriction == Restriction.INSIDE else "Out" which = "In" if restriction == Restriction.NEGATIVE else "Out"
return "auto {} = {}.geometryIn{}side();".format(name, return "auto {} = {}.geometryIn{}side();".format(name,
ig, ig,
which which
...@@ -149,7 +180,7 @@ def define_in_cell_geometry(restriction, name): ...@@ -149,7 +180,7 @@ def define_in_cell_geometry(restriction, name):
def name_in_cell_geometry(restriction): def name_in_cell_geometry(restriction):
assert restriction is not Restriction.NONE assert restriction is not Restriction.NONE
name = "geo_in_{}side".format("in" if restriction is Restriction.INSIDE else "out") name = "geo_in_{}side".format("in" if restriction is Restriction.NEGATIVE else "out")
define_in_cell_geometry(restriction, name) define_in_cell_geometry(restriction, name)
return name return name
...@@ -168,7 +199,7 @@ def apply_in_cell_transformation(name, local, restriction): ...@@ -168,7 +199,7 @@ def apply_in_cell_transformation(name, local, restriction):
def name_in_cell_coordinates(local, basename, restriction): def name_in_cell_coordinates(local, basename, restriction):
name = "{}_in_inside".format(basename) name = "{}_in_inside".format(basename)
temporary_variable(name, shape=(name_dimension(),), shape_impl=("fv",)) temporary_variable(name, shape=(name_dimension(),), shape_impl=("fv",))
apply_in_cell_transformation(name, local, restriction=Restriction.INSIDE) apply_in_cell_transformation(name, local, restriction=Restriction.NEGATIVE)
return name return name
...@@ -179,9 +210,10 @@ def to_cell_coordinates(local, basename): ...@@ -179,9 +210,10 @@ def to_cell_coordinates(local, basename):
if it == 'cell': if it == 'cell':
return local return local
if it == 'exterior_facet': if it == 'exterior_facet':
return name_in_cell_coordinates(local, basename, Restriction.INSIDE) return name_in_cell_coordinates(local, basename, Restriction.NEGATIVE)
if it == 'interior_facet': if it == 'interior_facet':
raise NotImplementedError restriction = get_global_context_value("restriction")
return name_in_cell_coordinates(local, basename, restriction)
@preamble @preamble
...@@ -196,6 +228,44 @@ def name_dimension(): ...@@ -196,6 +228,44 @@ def name_dimension():
return "dim" return "dim"
def evaluate_unit_outer_normal(name):
ig = name_intersection_geometry_wrapper()
qp = name_quadrature_position()
return quadrature_preamble("{} = {}.unitOuterNormal({});".format(name, ig, qp),
assignees=frozenset({name}),
)
@preamble
def declare_normal(name, shape, shape_impl):
ig = name_intersection_geometry_wrapper()
return "auto {} = {}.centerUnitOuterNormal();".format(name, ig)
@symbol
def name_unit_outer_normal():
name = "outer_normal"
temporary_variable(name, shape=(name_dimension(),), decl_method=declare_normal)
evaluate_unit_outer_normal(name)
return "outer_normal"
def evaluate_unit_inner_normal(name):
outer = name_unit_outer_normal()
return quadrature_preamble("auto {} = -1. * {};".format(name, outer),
assignees=frozenset({name}),
read_variables=frozenset({outer}),
)
@symbol
def name_unit_inner_normal():
name = "inner_normal"
temporary_variable(name, shape=(name_dimension(),), decl_method=declare_normal)
evaluate_unit_inner_normal(name)
return "inner_normal"
@symbol @symbol
def type_jacobian_inverse_transposed(): def type_jacobian_inverse_transposed():
geo = type_element_geometry_wrapper() geo = type_element_geometry_wrapper()
......
...@@ -24,12 +24,12 @@ class UFL2PymbolicMapper(MultiFunction): ...@@ -24,12 +24,12 @@ class UFL2PymbolicMapper(MultiFunction):
return Product(tuple(self.call(op) for op in get_operands(o))) return Product(tuple(self.call(op) for op in get_operands(o)))
def multi_index(self, o): def multi_index(self, o):
return tuple(self.call(op) for op in o.ufl_operands) from dune.perftool.pdelab import name_index
return tuple(Variable(name_index(op)) for op in o.indices())
def index(self, o): def index(self, o):
# One might as well take the uflname as string here, but I apply this function # One might as well take the uflname as string here, but I apply this function
from dune.perftool.pdelab import name_index from dune.perftool.pdelab import name_index
return Variable(name_index(o)) return Variable(name_index(o))
def fixed_index(self, o): def fixed_index(self, o):
......
...@@ -18,14 +18,22 @@ class ModifiedTerminalTracker(MultiFunction): ...@@ -18,14 +18,22 @@ class ModifiedTerminalTracker(MultiFunction):
def positive_restricted(self, o): def positive_restricted(self, o):
assert self.restriction == Restriction.NONE assert self.restriction == Restriction.NONE
self.restriction = Restriction.POSITIVE self.restriction = Restriction.POSITIVE
ret = self.call(o.ufl_operands[0])
from dune.perftool.generation import global_context
with global_context(restriction=Restriction.POSITIVE):
ret = self.call(o.ufl_operands[0])
self.restriction = Restriction.NONE self.restriction = Restriction.NONE
return ret return ret
def negative_restricted(self, o): def negative_restricted(self, o):
assert self.restriction == Restriction.NONE assert self.restriction == Restriction.NONE
self.restriction = Restriction.NEGATIVE self.restriction = Restriction.NEGATIVE
ret = self.call(o.ufl_operands[0])
from dune.perftool.generation import global_context
with global_context(restriction=Restriction.NEGATIVE):
ret = self.call(o.ufl_operands[0])
self.restriction = Restriction.NONE self.restriction = Restriction.NONE
return ret return ret
......
...@@ -19,7 +19,7 @@ class _ReplacementDict(dict): ...@@ -19,7 +19,7 @@ class _ReplacementDict(dict):
def __init__(self, good=[], bad=[]): def __init__(self, good=[], bad=[]):
dict.__init__(self) dict.__init__(self)
for a in bad: for a in bad:
self[a] = Zero() self[a] = Zero(shape=a.ufl_shape, free_indices=a.ufl_free_indices, index_dimensions=a.ufl_index_dimensions)
for a in good: for a in good:
self[a] = a self[a] = a
......
Subproject commit 80ef7a7745829f53ae949f68ad458b88c06d66a0
...@@ -4,14 +4,17 @@ V = FiniteElement("DG", cell, 1) ...@@ -4,14 +4,17 @@ V = FiniteElement("DG", cell, 1)
u = TrialFunction(V) u = TrialFunction(V)
v = TestFunction(V) v = TestFunction(V)
n = FacetNormal(cell) n = FacetNormal(cell)('+')
gamma = 1.0 gamma = 1.0
theta = 1.0 theta = 1.0
r = inner(grad(u), grad(v))*dx \ r = inner(grad(u), grad(v))*dx \
- inner(n, avg(grad(u)))*jump(v)*(dS+ds) \ - inner(n, avg(grad(u)))*jump(v)*dS \
+ gamma*jump(u)*jump(v)*(dS+ds) \ + gamma*jump(u)*jump(v)*dS \
- theta*jump(u)*inner(grad(v), n)*(dS+ds) - theta*jump(u)*inner(avg(grad(v)), n)*dS \
- inner(n, grad(u))*v*ds \
+ gamma*u*v*ds \
- theta*u*inner(grad(v), n)*ds
forms = [r] forms = [r]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment