Skip to content
Snippets Groups Projects
Commit d22f81e7 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fix jacobians of systems and use pytools.Record for modified arguments

parent 5f1f23f1
No related branches found
No related tags found
No related merge requests found
...@@ -14,7 +14,6 @@ from dune.perftool.generation import (domain, ...@@ -14,7 +14,6 @@ from dune.perftool.generation import (domain,
valuearg, valuearg,
get_global_context_value get_global_context_value
) )
from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
from dune.perftool.pdelab import (name_index, from dune.perftool.pdelab import (name_index,
restricted_name, restricted_name,
) )
......
...@@ -279,6 +279,7 @@ def determine_accumulation_space(expr, number, measure): ...@@ -279,6 +279,7 @@ def determine_accumulation_space(expr, number, measure):
if len(args) == 0: if len(args) == 0:
return AccumulationSpace() return AccumulationSpace()
# There should be but one modified argument, as the splitting eliminated all others.
assert(len(args) == 1) assert(len(args) == 1)
ma, = args ma, = args
...@@ -295,7 +296,7 @@ def determine_accumulation_space(expr, number, measure): ...@@ -295,7 +296,7 @@ def determine_accumulation_space(expr, number, measure):
if len(subel.value_shape()) != 0: if len(subel.value_shape()) != 0:
from dune.perftool.pdelab.geometry import dimension_iname from dune.perftool.pdelab.geometry import dimension_iname
idims = tuple(dimension_iname(context='arg', count=number) for i in range(len(subel.value_shape()))) idims = tuple(dimension_iname(context='arg', count=i) for i in range(len(subel.value_shape())))
lfs = lfs_child(lfs, idims, shape=subel.value_shape(), symmetry=subel.symmetry) lfs = lfs_child(lfs, idims, shape=subel.value_shape(), symmetry=subel.symmetry)
subel = subel.sub_elements()[0] subel = subel.sub_elements()[0]
...@@ -370,7 +371,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -370,7 +371,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
from ufl.domain import find_geometric_dimension from ufl.domain import find_geometric_dimension
dim = find_geometric_dimension(accterm.argument.expr) dim = find_geometric_dimension(accterm.argument.expr)
for i in accterm.argument.index._indices: for i in accterm.argument.index._indices:
additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)})) if i not in visitor.dimension_indices:
additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
# It may happen that an entire accumulation term vanishes. We do nothing in that case # It may happen that an entire accumulation term vanishes. We do nothing in that case
if pymbolic_expr == 0: if pymbolic_expr == 0:
...@@ -441,20 +443,22 @@ def generate_kernel(integrals): ...@@ -441,20 +443,22 @@ def generate_kernel(integrals):
from dune.perftool.pdelab.spaces import traverse_lfs_tree from dune.perftool.pdelab.spaces import traverse_lfs_tree
traverse_lfs_tree(ma) traverse_lfs_tree(ma)
from dune.perftool.options import set_option
set_option('print_transformations', True)
set_option('print_transformations_dir', '.')
# Now split the given integrand into accumulation expressions # Now split the given integrand into accumulation expressions
from dune.perftool.ufl.transformations.extract_accumulation_terms import split_into_accumulation_terms from dune.perftool.ufl.transformations.extract_accumulation_terms import split_into_accumulation_terms
accterms = split_into_accumulation_terms(integrand) accterms = split_into_accumulation_terms(integrand)
# Get a transformer instance for this kernel
from dune.perftool.ufl.visitor import UFL2LoopyVisitor
visitor = UFL2LoopyVisitor(measure, dimension_indices)
# Iterate over the terms and generate a kernel # Iterate over the terms and generate a kernel
for term in accterms: for term in accterms:
# Adjust the index map for the visitor
from copy import deepcopy
indexmap = deepcopy(dimension_indices)
for i, j in term.indexmap.items():
if i in indexmap:
indexmap[j] = indexmap[i]
# Get a transformer instance for this kernel
from dune.perftool.ufl.visitor import UFL2LoopyVisitor
visitor = UFL2LoopyVisitor(measure, indexmap)
generate_accumulation_instruction(visitor, term, measure, subdomain_id) generate_accumulation_instruction(visitor, term, measure, subdomain_id)
# Extract the information, which is needed to create a loopy kernel. # Extract the information, which is needed to create a loopy kernel.
......
...@@ -161,8 +161,8 @@ def type_gfs(element, basetype=None, index_stack=None): ...@@ -161,8 +161,8 @@ def type_gfs(element, basetype=None, index_stack=None):
def traverse_lfs_tree(arg): def traverse_lfs_tree(arg):
from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor from dune.perftool.ufl.modified_terminals import ModifiedArgument
assert isinstance(arg, ModifiedArgumentDescriptor) assert isinstance(arg, ModifiedArgument)
# First we need to determine the basename as given in the signature of # First we need to determine the basename as given in the signature of
# this kernel method! # this kernel method!
......
...@@ -3,6 +3,30 @@ ...@@ -3,6 +3,30 @@
from ufl.algorithms import MultiFunction from ufl.algorithms import MultiFunction
from dune.perftool import Restriction from dune.perftool import Restriction
from ufl.classes import MultiIndex from ufl.classes import MultiIndex
from pytools import Record
class ModifiedArgument(Record):
def __init__(self,
expr=None,
argexpr=None,
grad=False,
index=None,
reference_grad=False,
restriction=Restriction.NONE,
component=MultiIndex(()),
reference=False,
):
Record.__init__(self,
expr=expr,
argexpr=argexpr,
grad=grad,
index=index,
reference_grad=reference_grad,
restriction=restriction,
component=component,
reference=reference,
)
class ModifiedTerminalTracker(MultiFunction): class ModifiedTerminalTracker(MultiFunction):
...@@ -10,6 +34,9 @@ class ModifiedTerminalTracker(MultiFunction): ...@@ -10,6 +34,9 @@ class ModifiedTerminalTracker(MultiFunction):
grad, reference_grad, positive_restricted and negative_restricted. grad, reference_grad, positive_restricted and negative_restricted.
The appearance of those classes changes the internal state of the MF. The appearance of those classes changes the internal state of the MF.
""" """
call = MultiFunction.__call__
def __init__(self): def __init__(self):
MultiFunction.__init__(self) MultiFunction.__init__(self)
self.grad = False self.grad = False
...@@ -59,57 +86,33 @@ class ModifiedTerminalTracker(MultiFunction): ...@@ -59,57 +86,33 @@ class ModifiedTerminalTracker(MultiFunction):
return ret return ret
class ModifiedArgumentDescriptor(MultiFunction): class ModifiedArgumentAnalysis(ModifiedTerminalTracker):
def __init__(self, e): def __init__(self):
MultiFunction.__init__(self)
self.grad = False
self.reference = False
self.reference_grad = False
self.index = None self.index = None
self.restriction = Restriction.NONE ModifiedTerminalTracker.__init__(self)
self.component = MultiIndex(())
self.expr = e
self.__call__(e)
self.__call__ = None
def __eq__(self, other): def __call__(self, o):
return self.expr == other.expr self.call_expr = o
return self.call(o)
def grad(self, o):
self.grad = True
self(o.ufl_operands[0])
def reference_grad(self, o):
self.reference_grad = True
self(o.ufl_operands[0])
def reference_value(self, o):
self.reference = True
self(o.ufl_operands[0])
def positive_restricted(self, o):
self.restriction = Restriction.POSITIVE
self(o.ufl_operands[0])
def negative_restricted(self, o):
self.restriction = Restriction.NEGATIVE
self(o.ufl_operands[0])
def indexed(self, o): def indexed(self, o):
self.index = o.ufl_operands[1] self.index = o.ufl_operands[1]
self(o.ufl_operands[0]) return self.call(o.ufl_operands[0])
def function_view(self, o): def form_argument(self, o):
self.component = o.ufl_operands[1] return ModifiedArgument(expr=self.call_expr,
self(o.ufl_operands[0]) argexpr=o,
index=self.index,
restriction=self.restriction,
component=self.component,
grad=self.grad,
reference_grad=self.reference_grad,
reference=self.reference,
)
def argument(self, o):
self.argexpr = o
def coefficient(self, o): def analyse_modified_argument(expr):
self.argexpr = o return ModifiedArgumentAnalysis()(expr)
class _ModifiedArgumentExtractor(MultiFunction): class _ModifiedArgumentExtractor(MultiFunction):
...@@ -123,7 +126,7 @@ class _ModifiedArgumentExtractor(MultiFunction): ...@@ -123,7 +126,7 @@ class _ModifiedArgumentExtractor(MultiFunction):
if ret: if ret:
# This indicates that this entire expression was a modified thing... # This indicates that this entire expression was a modified thing...
self.modified_arguments.add(ret) self.modified_arguments.add(ret)
return tuple(ModifiedArgumentDescriptor(ma) for ma in self.modified_arguments) return tuple(analyse_modified_argument(ma) for ma in self.modified_arguments)
def expr(self, o): def expr(self, o):
for op in o.ufl_operands: for op in o.ufl_operands:
......
...@@ -7,7 +7,7 @@ from dune.perftool.ufl.transformations import ufl_transformation ...@@ -7,7 +7,7 @@ from dune.perftool.ufl.transformations import ufl_transformation
from dune.perftool.ufl.transformations.replace import replace_expression from dune.perftool.ufl.transformations.replace import replace_expression
from dune.perftool.ufl.transformations.identitypropagation import identity_propagation from dune.perftool.ufl.transformations.identitypropagation import identity_propagation
from dune.perftool.ufl.transformations.reindexing import reindexing from dune.perftool.ufl.transformations.reindexing import reindexing
from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument
from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex
from ufl.core.multiindex import indices from ufl.core.multiindex import indices
...@@ -16,12 +16,21 @@ from pytools import Record ...@@ -16,12 +16,21 @@ from pytools import Record
class AccumulationTerm(Record): class AccumulationTerm(Record):
def __init__(self, term, argument): def __init__(self,
Record.__init__(self, term=term, argument=argument) term,
argument,
indexmap={},
):
assert isinstance(argument, ModifiedArgument)
Record.__init__(self,
term=term,
argument=argument,
indexmap=indexmap,
)
@ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l]) @ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l])
def split_into_accumulation_terms(expr): def split_into_accumulation_terms(expr, indexmap={}):
ret = [] ret = []
# Extract a list of modified terminals for the test function # Extract a list of modified terminals for the test function
...@@ -29,8 +38,7 @@ def split_into_accumulation_terms(expr): ...@@ -29,8 +38,7 @@ def split_into_accumulation_terms(expr):
test_args = extract_modified_arguments(expr, argnumber=0) test_args = extract_modified_arguments(expr, argnumber=0)
# Extract a list of modified terminals for the ansatz function # Extract a list of modified terminals for the ansatz function
# in jacobian forms. Only the restriction of those terminals will # in jacobian forms.
# be used to generate new accumulation terms!
all_jacobian_args = extract_modified_arguments(expr, argnumber=1) all_jacobian_args = extract_modified_arguments(expr, argnumber=1)
for test_arg in test_args: for test_arg in test_args:
...@@ -39,18 +47,23 @@ def split_into_accumulation_terms(expr): ...@@ -39,18 +47,23 @@ def split_into_accumulation_terms(expr):
# 1) We first cut the expression to the relevant modified test_function # 1) We first cut the expression to the relevant modified test_function
# Build a replacement dictionary # Build a replacement dictionary
replacement = {ma.expr: Zero() for ma in test_args} replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
for ma in test_args}
replacement[test_arg.expr] = test_arg.expr replacement[test_arg.expr] = test_arg.expr
replace_expr = replace_expression(expr, replacemap=replacement) replace_expr = replace_expression(expr, replacemap=replacement)
# 2) Cut the test function itself from the expression # 2) Cut the test function itself from the expression
indexmap = {}
if test_arg.index: if test_arg.index:
newi = indices(len(test_arg.index)) newi = indices(len(test_arg.index))
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices)) identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices))
indexmap = {i: j for i, j in zip(test_arg.index._indices, newi)}
from dune.perftool.ufl.flatoperators import construct_binary_operator from dune.perftool.ufl.flatoperators import construct_binary_operator
from ufl.classes import Product from ufl.classes import Product
replacement = {test_arg.expr: construct_binary_operator(identities, Product)} replacement = {test_arg.expr: construct_binary_operator(identities, Product)}
test_arg = ModifiedArgumentDescriptor(reindexing(test_arg.expr, replacemap={i: j for i, j in zip(test_arg.index._indices, newi)})) test_arg = analyse_modified_argument(reindexing(test_arg.expr, replacemap=indexmap))
else: else:
replacement = {test_arg.expr: IntValue(1)} replacement = {test_arg.expr: IntValue(1)}
replace_expr = replace_expression(replace_expr, replacemap=replacement) replace_expr = replace_expression(replace_expr, replacemap=replacement)
...@@ -60,16 +73,21 @@ def split_into_accumulation_terms(expr): ...@@ -60,16 +73,21 @@ def split_into_accumulation_terms(expr):
# 4) Further split according to trial function in jacobian terms # 4) Further split according to trial function in jacobian terms
if all_jacobian_args: if all_jacobian_args:
for jac_arg in all_jacobian_args: # Update the list!
jac_args = extract_modified_arguments(replace_expr, argnumber=1)
for jac_arg in jac_args:
# TODO Some jacobian terms can be joined # TODO Some jacobian terms can be joined
replacement = {ma.expr: Zero() for ma in all_jacobian_args} replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
for ma in jac_args}
replacement[jac_arg.expr] = jac_arg.expr replacement[jac_arg.expr] = jac_arg.expr
jac_expr = replace_expression(replace_expr, replacemap=replacement) jac_expr = replace_expression(replace_expr, replacemap=replacement)
if not isinstance(jac_expr, Zero): if not isinstance(jac_expr, Zero):
ret.append(AccumulationTerm(jac_expr, test_arg)) ret.append(AccumulationTerm(jac_expr, test_arg, indexmap))
else: else:
if not isinstance(replace_expr, Zero): if not isinstance(replace_expr, Zero):
ret.append(AccumulationTerm(replace_expr, test_arg)) ret.append(AccumulationTerm(replace_expr, test_arg, indexmap))
return ret return ret
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment