Skip to content
Snippets Groups Projects
Commit d22f81e7 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fix jacobians of systems and use pytools.Record for modified arguments

parent 5f1f23f1
No related branches found
No related tags found
No related merge requests found
......@@ -14,7 +14,6 @@ from dune.perftool.generation import (domain,
valuearg,
get_global_context_value
)
from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
from dune.perftool.pdelab import (name_index,
restricted_name,
)
......
......@@ -279,6 +279,7 @@ def determine_accumulation_space(expr, number, measure):
if len(args) == 0:
return AccumulationSpace()
# There should be but one modified argument, as the splitting eliminated all others.
assert(len(args) == 1)
ma, = args
......@@ -295,7 +296,7 @@ def determine_accumulation_space(expr, number, measure):
if len(subel.value_shape()) != 0:
from dune.perftool.pdelab.geometry import dimension_iname
idims = tuple(dimension_iname(context='arg', count=number) for i in range(len(subel.value_shape())))
idims = tuple(dimension_iname(context='arg', count=i) for i in range(len(subel.value_shape())))
lfs = lfs_child(lfs, idims, shape=subel.value_shape(), symmetry=subel.symmetry)
subel = subel.sub_elements()[0]
......@@ -370,7 +371,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
from ufl.domain import find_geometric_dimension
dim = find_geometric_dimension(accterm.argument.expr)
for i in accterm.argument.index._indices:
additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
if i not in visitor.dimension_indices:
additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
# It may happen that an entire accumulation term vanishes. We do nothing in that case
if pymbolic_expr == 0:
......@@ -441,20 +443,22 @@ def generate_kernel(integrals):
from dune.perftool.pdelab.spaces import traverse_lfs_tree
traverse_lfs_tree(ma)
from dune.perftool.options import set_option
set_option('print_transformations', True)
set_option('print_transformations_dir', '.')
# Now split the given integrand into accumulation expressions
from dune.perftool.ufl.transformations.extract_accumulation_terms import split_into_accumulation_terms
accterms = split_into_accumulation_terms(integrand)
# Get a transformer instance for this kernel
from dune.perftool.ufl.visitor import UFL2LoopyVisitor
visitor = UFL2LoopyVisitor(measure, dimension_indices)
# Iterate over the terms and generate a kernel
for term in accterms:
# Adjust the index map for the visitor
from copy import deepcopy
indexmap = deepcopy(dimension_indices)
for i, j in term.indexmap.items():
if i in indexmap:
indexmap[j] = indexmap[i]
# Get a transformer instance for this kernel
from dune.perftool.ufl.visitor import UFL2LoopyVisitor
visitor = UFL2LoopyVisitor(measure, indexmap)
generate_accumulation_instruction(visitor, term, measure, subdomain_id)
# Extract the information, which is needed to create a loopy kernel.
......
......@@ -161,8 +161,8 @@ def type_gfs(element, basetype=None, index_stack=None):
def traverse_lfs_tree(arg):
from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
assert isinstance(arg, ModifiedArgumentDescriptor)
from dune.perftool.ufl.modified_terminals import ModifiedArgument
assert isinstance(arg, ModifiedArgument)
# First we need to determine the basename as given in the signature of
# this kernel method!
......
......@@ -3,6 +3,30 @@
from ufl.algorithms import MultiFunction
from dune.perftool import Restriction
from ufl.classes import MultiIndex
from pytools import Record
class ModifiedArgument(Record):
def __init__(self,
expr=None,
argexpr=None,
grad=False,
index=None,
reference_grad=False,
restriction=Restriction.NONE,
component=MultiIndex(()),
reference=False,
):
Record.__init__(self,
expr=expr,
argexpr=argexpr,
grad=grad,
index=index,
reference_grad=reference_grad,
restriction=restriction,
component=component,
reference=reference,
)
class ModifiedTerminalTracker(MultiFunction):
......@@ -10,6 +34,9 @@ class ModifiedTerminalTracker(MultiFunction):
grad, reference_grad, positive_restricted and negative_restricted.
The appearance of those classes changes the internal state of the MF.
"""
call = MultiFunction.__call__
def __init__(self):
MultiFunction.__init__(self)
self.grad = False
......@@ -59,57 +86,33 @@ class ModifiedTerminalTracker(MultiFunction):
return ret
class ModifiedArgumentDescriptor(MultiFunction):
def __init__(self, e):
MultiFunction.__init__(self)
self.grad = False
self.reference = False
self.reference_grad = False
class ModifiedArgumentAnalysis(ModifiedTerminalTracker):
def __init__(self):
self.index = None
self.restriction = Restriction.NONE
self.component = MultiIndex(())
self.expr = e
self.__call__(e)
self.__call__ = None
ModifiedTerminalTracker.__init__(self)
def __eq__(self, other):
return self.expr == other.expr
def grad(self, o):
self.grad = True
self(o.ufl_operands[0])
def reference_grad(self, o):
self.reference_grad = True
self(o.ufl_operands[0])
def reference_value(self, o):
self.reference = True
self(o.ufl_operands[0])
def positive_restricted(self, o):
self.restriction = Restriction.POSITIVE
self(o.ufl_operands[0])
def negative_restricted(self, o):
self.restriction = Restriction.NEGATIVE
self(o.ufl_operands[0])
def __call__(self, o):
self.call_expr = o
return self.call(o)
def indexed(self, o):
self.index = o.ufl_operands[1]
self(o.ufl_operands[0])
return self.call(o.ufl_operands[0])
def function_view(self, o):
self.component = o.ufl_operands[1]
self(o.ufl_operands[0])
def form_argument(self, o):
return ModifiedArgument(expr=self.call_expr,
argexpr=o,
index=self.index,
restriction=self.restriction,
component=self.component,
grad=self.grad,
reference_grad=self.reference_grad,
reference=self.reference,
)
def argument(self, o):
self.argexpr = o
def coefficient(self, o):
self.argexpr = o
def analyse_modified_argument(expr):
return ModifiedArgumentAnalysis()(expr)
class _ModifiedArgumentExtractor(MultiFunction):
......@@ -123,7 +126,7 @@ class _ModifiedArgumentExtractor(MultiFunction):
if ret:
# This indicates that this entire expression was a modified thing...
self.modified_arguments.add(ret)
return tuple(ModifiedArgumentDescriptor(ma) for ma in self.modified_arguments)
return tuple(analyse_modified_argument(ma) for ma in self.modified_arguments)
def expr(self, o):
for op in o.ufl_operands:
......
......@@ -7,7 +7,7 @@ from dune.perftool.ufl.transformations import ufl_transformation
from dune.perftool.ufl.transformations.replace import replace_expression
from dune.perftool.ufl.transformations.identitypropagation import identity_propagation
from dune.perftool.ufl.transformations.reindexing import reindexing
from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument
from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex
from ufl.core.multiindex import indices
......@@ -16,12 +16,21 @@ from pytools import Record
class AccumulationTerm(Record):
def __init__(self, term, argument):
Record.__init__(self, term=term, argument=argument)
def __init__(self,
term,
argument,
indexmap={},
):
assert isinstance(argument, ModifiedArgument)
Record.__init__(self,
term=term,
argument=argument,
indexmap=indexmap,
)
@ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l])
def split_into_accumulation_terms(expr):
def split_into_accumulation_terms(expr, indexmap={}):
ret = []
# Extract a list of modified terminals for the test function
......@@ -29,8 +38,7 @@ def split_into_accumulation_terms(expr):
test_args = extract_modified_arguments(expr, argnumber=0)
# Extract a list of modified terminals for the ansatz function
# in jacobian forms. Only the restriction of those terminals will
# be used to generate new accumulation terms!
# in jacobian forms.
all_jacobian_args = extract_modified_arguments(expr, argnumber=1)
for test_arg in test_args:
......@@ -39,18 +47,23 @@ def split_into_accumulation_terms(expr):
# 1) We first cut the expression to the relevant modified test_function
# Build a replacement dictionary
replacement = {ma.expr: Zero() for ma in test_args}
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
for ma in test_args}
replacement[test_arg.expr] = test_arg.expr
replace_expr = replace_expression(expr, replacemap=replacement)
# 2) Cut the test function itself from the expression
indexmap = {}
if test_arg.index:
newi = indices(len(test_arg.index))
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices))
indexmap = {i: j for i, j in zip(test_arg.index._indices, newi)}
from dune.perftool.ufl.flatoperators import construct_binary_operator
from ufl.classes import Product
replacement = {test_arg.expr: construct_binary_operator(identities, Product)}
test_arg = ModifiedArgumentDescriptor(reindexing(test_arg.expr, replacemap={i: j for i, j in zip(test_arg.index._indices, newi)}))
test_arg = analyse_modified_argument(reindexing(test_arg.expr, replacemap=indexmap))
else:
replacement = {test_arg.expr: IntValue(1)}
replace_expr = replace_expression(replace_expr, replacemap=replacement)
......@@ -60,16 +73,21 @@ def split_into_accumulation_terms(expr):
# 4) Further split according to trial function in jacobian terms
if all_jacobian_args:
for jac_arg in all_jacobian_args:
# Update the list!
jac_args = extract_modified_arguments(replace_expr, argnumber=1)
for jac_arg in jac_args:
# TODO Some jacobian terms can be joined
replacement = {ma.expr: Zero() for ma in all_jacobian_args}
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
for ma in jac_args}
replacement[jac_arg.expr] = jac_arg.expr
jac_expr = replace_expression(replace_expr, replacemap=replacement)
if not isinstance(jac_expr, Zero):
ret.append(AccumulationTerm(jac_expr, test_arg))
ret.append(AccumulationTerm(jac_expr, test_arg, indexmap))
else:
if not isinstance(replace_expr, Zero):
ret.append(AccumulationTerm(replace_expr, test_arg))
ret.append(AccumulationTerm(replace_expr, test_arg, indexmap))
return ret
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment