From d22f81e7db61d56516197e4efcd37093db383251 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 18 Oct 2016 13:16:08 +0200 Subject: [PATCH] Fix jacobians of systems and use pytools.Record for modified arguments --- python/dune/perftool/pdelab/argument.py | 1 - python/dune/perftool/pdelab/localoperator.py | 24 +++-- python/dune/perftool/pdelab/spaces.py | 4 +- .../dune/perftool/ufl/modified_terminals.py | 91 ++++++++++--------- .../extract_accumulation_terms.py | 42 ++++++--- 5 files changed, 93 insertions(+), 69 deletions(-) diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py index d62f884d..a55a485c 100644 --- a/python/dune/perftool/pdelab/argument.py +++ b/python/dune/perftool/pdelab/argument.py @@ -14,7 +14,6 @@ from dune.perftool.generation import (domain, valuearg, get_global_context_value ) -from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor from dune.perftool.pdelab import (name_index, restricted_name, ) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 71f552cc..d7a2e545 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -279,6 +279,7 @@ def determine_accumulation_space(expr, number, measure): if len(args) == 0: return AccumulationSpace() + # There should be but one modified argument, as the splitting eliminated all others. assert(len(args) == 1) ma, = args @@ -295,7 +296,7 @@ def determine_accumulation_space(expr, number, measure): if len(subel.value_shape()) != 0: from dune.perftool.pdelab.geometry import dimension_iname - idims = tuple(dimension_iname(context='arg', count=number) for i in range(len(subel.value_shape()))) + idims = tuple(dimension_iname(context='arg', count=i) for i in range(len(subel.value_shape()))) lfs = lfs_child(lfs, idims, shape=subel.value_shape(), symmetry=subel.symmetry) subel = subel.sub_elements()[0] @@ -370,7 +371,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): from ufl.domain import find_geometric_dimension dim = find_geometric_dimension(accterm.argument.expr) for i in accterm.argument.index._indices: - additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)})) + if i not in visitor.dimension_indices: + additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)})) # It may happen that an entire accumulation term vanishes. We do nothing in that case if pymbolic_expr == 0: @@ -441,20 +443,22 @@ def generate_kernel(integrals): from dune.perftool.pdelab.spaces import traverse_lfs_tree traverse_lfs_tree(ma) - from dune.perftool.options import set_option - set_option('print_transformations', True) - set_option('print_transformations_dir', '.') - # Now split the given integrand into accumulation expressions from dune.perftool.ufl.transformations.extract_accumulation_terms import split_into_accumulation_terms accterms = split_into_accumulation_terms(integrand) - # Get a transformer instance for this kernel - from dune.perftool.ufl.visitor import UFL2LoopyVisitor - visitor = UFL2LoopyVisitor(measure, dimension_indices) - # Iterate over the terms and generate a kernel for term in accterms: + # Adjust the index map for the visitor + from copy import deepcopy + indexmap = deepcopy(dimension_indices) + for i, j in term.indexmap.items(): + if i in indexmap: + indexmap[j] = indexmap[i] + + # Get a transformer instance for this kernel + from dune.perftool.ufl.visitor import UFL2LoopyVisitor + visitor = UFL2LoopyVisitor(measure, indexmap) generate_accumulation_instruction(visitor, term, measure, subdomain_id) # Extract the information, which is needed to create a loopy kernel. diff --git a/python/dune/perftool/pdelab/spaces.py b/python/dune/perftool/pdelab/spaces.py index 3812cc1f..e439425f 100644 --- a/python/dune/perftool/pdelab/spaces.py +++ b/python/dune/perftool/pdelab/spaces.py @@ -161,8 +161,8 @@ def type_gfs(element, basetype=None, index_stack=None): def traverse_lfs_tree(arg): - from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor - assert isinstance(arg, ModifiedArgumentDescriptor) + from dune.perftool.ufl.modified_terminals import ModifiedArgument + assert isinstance(arg, ModifiedArgument) # First we need to determine the basename as given in the signature of # this kernel method! diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py index 1bed3056..33319b27 100644 --- a/python/dune/perftool/ufl/modified_terminals.py +++ b/python/dune/perftool/ufl/modified_terminals.py @@ -3,6 +3,30 @@ from ufl.algorithms import MultiFunction from dune.perftool import Restriction from ufl.classes import MultiIndex +from pytools import Record + + +class ModifiedArgument(Record): + def __init__(self, + expr=None, + argexpr=None, + grad=False, + index=None, + reference_grad=False, + restriction=Restriction.NONE, + component=MultiIndex(()), + reference=False, + ): + Record.__init__(self, + expr=expr, + argexpr=argexpr, + grad=grad, + index=index, + reference_grad=reference_grad, + restriction=restriction, + component=component, + reference=reference, + ) class ModifiedTerminalTracker(MultiFunction): @@ -10,6 +34,9 @@ class ModifiedTerminalTracker(MultiFunction): grad, reference_grad, positive_restricted and negative_restricted. The appearance of those classes changes the internal state of the MF. """ + + call = MultiFunction.__call__ + def __init__(self): MultiFunction.__init__(self) self.grad = False @@ -59,57 +86,33 @@ class ModifiedTerminalTracker(MultiFunction): return ret -class ModifiedArgumentDescriptor(MultiFunction): - def __init__(self, e): - MultiFunction.__init__(self) - - self.grad = False - self.reference = False - self.reference_grad = False +class ModifiedArgumentAnalysis(ModifiedTerminalTracker): + def __init__(self): self.index = None - self.restriction = Restriction.NONE - self.component = MultiIndex(()) - self.expr = e - - self.__call__(e) - self.__call__ = None + ModifiedTerminalTracker.__init__(self) - def __eq__(self, other): - return self.expr == other.expr - - def grad(self, o): - self.grad = True - self(o.ufl_operands[0]) - - def reference_grad(self, o): - self.reference_grad = True - self(o.ufl_operands[0]) - - def reference_value(self, o): - self.reference = True - self(o.ufl_operands[0]) - - def positive_restricted(self, o): - self.restriction = Restriction.POSITIVE - self(o.ufl_operands[0]) - - def negative_restricted(self, o): - self.restriction = Restriction.NEGATIVE - self(o.ufl_operands[0]) + def __call__(self, o): + self.call_expr = o + return self.call(o) def indexed(self, o): self.index = o.ufl_operands[1] - self(o.ufl_operands[0]) + return self.call(o.ufl_operands[0]) - def function_view(self, o): - self.component = o.ufl_operands[1] - self(o.ufl_operands[0]) + def form_argument(self, o): + return ModifiedArgument(expr=self.call_expr, + argexpr=o, + index=self.index, + restriction=self.restriction, + component=self.component, + grad=self.grad, + reference_grad=self.reference_grad, + reference=self.reference, + ) - def argument(self, o): - self.argexpr = o - def coefficient(self, o): - self.argexpr = o +def analyse_modified_argument(expr): + return ModifiedArgumentAnalysis()(expr) class _ModifiedArgumentExtractor(MultiFunction): @@ -123,7 +126,7 @@ class _ModifiedArgumentExtractor(MultiFunction): if ret: # This indicates that this entire expression was a modified thing... self.modified_arguments.add(ret) - return tuple(ModifiedArgumentDescriptor(ma) for ma in self.modified_arguments) + return tuple(analyse_modified_argument(ma) for ma in self.modified_arguments) def expr(self, o): for op in o.ufl_operands: diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py index 3d53f9c2..0429c1c9 100644 --- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py @@ -7,7 +7,7 @@ from dune.perftool.ufl.transformations import ufl_transformation from dune.perftool.ufl.transformations.replace import replace_expression from dune.perftool.ufl.transformations.identitypropagation import identity_propagation from dune.perftool.ufl.transformations.reindexing import reindexing -from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor +from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex from ufl.core.multiindex import indices @@ -16,12 +16,21 @@ from pytools import Record class AccumulationTerm(Record): - def __init__(self, term, argument): - Record.__init__(self, term=term, argument=argument) + def __init__(self, + term, + argument, + indexmap={}, + ): + assert isinstance(argument, ModifiedArgument) + Record.__init__(self, + term=term, + argument=argument, + indexmap=indexmap, + ) @ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l]) -def split_into_accumulation_terms(expr): +def split_into_accumulation_terms(expr, indexmap={}): ret = [] # Extract a list of modified terminals for the test function @@ -29,8 +38,7 @@ def split_into_accumulation_terms(expr): test_args = extract_modified_arguments(expr, argnumber=0) # Extract a list of modified terminals for the ansatz function - # in jacobian forms. Only the restriction of those terminals will - # be used to generate new accumulation terms! + # in jacobian forms. all_jacobian_args = extract_modified_arguments(expr, argnumber=1) for test_arg in test_args: @@ -39,18 +47,23 @@ def split_into_accumulation_terms(expr): # 1) We first cut the expression to the relevant modified test_function # Build a replacement dictionary - replacement = {ma.expr: Zero() for ma in test_args} + replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape, + free_indices=ma.expr.ufl_free_indices, + index_dimensions=ma.expr.ufl_index_dimensions) + for ma in test_args} replacement[test_arg.expr] = test_arg.expr replace_expr = replace_expression(expr, replacemap=replacement) # 2) Cut the test function itself from the expression + indexmap = {} if test_arg.index: newi = indices(len(test_arg.index)) identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices)) + indexmap = {i: j for i, j in zip(test_arg.index._indices, newi)} from dune.perftool.ufl.flatoperators import construct_binary_operator from ufl.classes import Product replacement = {test_arg.expr: construct_binary_operator(identities, Product)} - test_arg = ModifiedArgumentDescriptor(reindexing(test_arg.expr, replacemap={i: j for i, j in zip(test_arg.index._indices, newi)})) + test_arg = analyse_modified_argument(reindexing(test_arg.expr, replacemap=indexmap)) else: replacement = {test_arg.expr: IntValue(1)} replace_expr = replace_expression(replace_expr, replacemap=replacement) @@ -60,16 +73,21 @@ def split_into_accumulation_terms(expr): # 4) Further split according to trial function in jacobian terms if all_jacobian_args: - for jac_arg in all_jacobian_args: + # Update the list! + jac_args = extract_modified_arguments(replace_expr, argnumber=1) + for jac_arg in jac_args: # TODO Some jacobian terms can be joined - replacement = {ma.expr: Zero() for ma in all_jacobian_args} + replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape, + free_indices=ma.expr.ufl_free_indices, + index_dimensions=ma.expr.ufl_index_dimensions) + for ma in jac_args} replacement[jac_arg.expr] = jac_arg.expr jac_expr = replace_expression(replace_expr, replacemap=replacement) if not isinstance(jac_expr, Zero): - ret.append(AccumulationTerm(jac_expr, test_arg)) + ret.append(AccumulationTerm(jac_expr, test_arg, indexmap)) else: if not isinstance(replace_expr, Zero): - ret.append(AccumulationTerm(replace_expr, test_arg)) + ret.append(AccumulationTerm(replace_expr, test_arg, indexmap)) return ret -- GitLab