diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 0c2a6ba87dac36ce7e267aebdaea74135a76d80c..9ed825136e2ac1f01f6ef6f0da1aaceb399dc4e2 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -314,10 +314,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): # If this is a gradient, we generate an iname additional_inames = frozenset({}) - if accterm.argument.index: + if accterm.new_indices is not None: from ufl.domain import find_geometric_dimension dim = find_geometric_dimension(accterm.argument.expr) - for i in accterm.argument.index._indices: + for i in accterm.new_indices: if i not in visitor.dimension_indices: additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)})) @@ -326,7 +326,12 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): return # We also traverse the test function to get its pymbolic equivalent - test_expr = visitor(accterm.argument.expr) + if accterm.new_indices is not None: + from ufl.classes import Indexed, MultiIndex + accum_expr = Indexed(accterm.argument.expr, MultiIndex(accterm.new_indices)) + else: + accum_expr = accterm.argument.expr + test_expr = visitor(accum_expr) # Combine expression and test function from pymbolic.primitives import Product @@ -397,11 +402,11 @@ def visit_integrals(integrals): accterms = split_into_accumulation_terms(integrand) # Iterate over the terms and generate a kernel - for term in accterms: + for accterm in accterms: # Adjust the index map for the visitor from copy import deepcopy indexmap = deepcopy(dimension_indices) - for i, j in term.indexmap.items(): + for i, j in accterm.indexmap.items(): if i in indexmap: indexmap[j] = indexmap[i] @@ -422,13 +427,17 @@ def visit_integrals(integrals): set_subst_rule(name, expr) # Ensure CSE on detjac * quadrature weight - domain = term.argument.argexpr.ufl_domain() + domain = accterm.argument.argexpr.ufl_domain() if measure == "cell": - set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain))) - set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain)) + set_subst_rule("integration_factor_cell1", + uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain))) + set_subst_rule("integration_factor_cell2", + uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain)) else: - set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain)) - set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain)) + set_subst_rule("integration_factor_facet1", + uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain)) + set_subst_rule("integration_factor_facet2", + uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain)) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) diff --git a/python/dune/perftool/ufl/extract_accumulation_terms.py b/python/dune/perftool/ufl/extract_accumulation_terms.py index a5af56ec4342367677aab2d3c8c09928dee0b1c5..5a8a4d7a1405d1b5298026c9edc845a85110bef9 100644 --- a/python/dune/perftool/ufl/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/extract_accumulation_terms.py @@ -1,16 +1,18 @@ """ This module defines an UFL transformation, that takes a UFL expression -and transforms it into a sum of accumulation terms. +and transforms it into a list of accumulation terms. """ +from dune.perftool.ufl.flatoperators import construct_binary_operator from dune.perftool.ufl.modified_terminals import extract_modified_arguments from dune.perftool.ufl.transformations import ufl_transformation from dune.perftool.ufl.transformations.replace import replace_expression from dune.perftool.ufl.transformations.identitypropagation import identity_propagation +from dune.perftool.ufl.transformations.zeropropagation import zero_propagation from dune.perftool.ufl.transformations.reindexing import reindexing from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument from dune.perftool.pdelab.restriction import Restriction -from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex +from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex, Product from ufl.core.multiindex import indices from pytools import Record @@ -21,12 +23,14 @@ class AccumulationTerm(Record): term, argument, indexmap={}, + new_indices=None ): assert isinstance(argument, ModifiedArgument) Record.__init__(self, term=term, argument=argument, indexmap=indexmap, + new_indices=new_indices ) @@ -44,14 +48,14 @@ def split_into_accumulation_terms(expr): Arguments: ---------- - exrp: UFL expression we want to split + expr: UFL expression we want to split """ # Store AccumulationTerms in this list ret = [] # Extract a list of modified terminals for the test function # One accumulation instruction will be generated for each of these. - test_args = extract_modified_arguments(expr, argnumber=0) + test_args = extract_modified_arguments(expr, argnumber=0, do_index=False) # Extract a list of modified terminals for the ansatz function # in jacobian forms. @@ -62,7 +66,7 @@ def split_into_accumulation_terms(expr): # about too much stuff in reconstructing the expressions # 1) We first cut the expression to the relevant modified test_function - # Build a replacement dictionary + # by replacing all other test functions with zero replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape, free_indices=ma.expr.ufl_free_indices, index_dimensions=ma.expr.ufl_index_dimensions) @@ -70,25 +74,67 @@ def split_into_accumulation_terms(expr): replacement[test_arg.expr] = test_arg.expr replace_expr = replace_expression(expr, replacemap=replacement) - # 2) Cut the test function itself from the expression + # 2) Propagate indexed zeros to simplify expression + replace_expr = zero_propagation(replace_expr) + + # 3) Cut the test function itself from the expression + # + # This is done by replacing the test function with an + # appropriate product of identity matrices. This way we can + # make sure that the indices of the result will be right. This + # is best explained by an example: + # + # Suppose we have the following expression: + # + # \sum_{i,j} a_{i,j} (\nabla v)_{i,j} + \sum_{k,l} b_{k,l} (\nable v)_{k,l} + # + # If we want to cut the gradient of the test function v we + # need to make sure, that both sums have the right indices: + # + # \sum_{m,n} (a_{m,n} + b_{m,n}) (\nabla v)_{m,n} + # + # and we extract (a_{m,n} + b_{m,n}). We achieve that by the + # following replacements: + # + # (\nabla v)_{i,j} -> I_{m,i} I_{n,j} + # (\nabla v)_{k,l} -> I_{m,k} I_{n,l} + # + # Resulting in: + # + # \sum_{i,j} a_{i,j} I_{m,i} I_{n,j} + \sum_{k,l} b_{k,l} I_{m,k} I_{n,l} + # + # In step 4 this will collaps to: a_{m,n} + b_{m,n} + replacement = {} indexmap = {} - if test_arg.index: - newi = indices(len(test_arg.index)) - identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices)) - indexmap = {i: j for i, j in zip(test_arg.index._indices, newi)} - from dune.perftool.ufl.flatoperators import construct_binary_operator - from ufl.classes import Product - replacement = {test_arg.expr: construct_binary_operator(identities, Product)} - test_arg = analyse_modified_argument(reindexing(test_arg.expr, replacemap=indexmap)) - else: - replacement = {test_arg.expr: IntValue(1)} + newi = None + # Get all appearances of test functions with their indices + indexed_test_args = extract_modified_arguments(replace_expr, argnumber=0, do_index=True) + for indexed_test_arg in indexed_test_args: + # from pudb import set_trace; set_trace() + if indexed_test_arg.index: + # If the test function is indexed, create a new multiindex of this shape + # -> (m,n) in the example above + if newi is None: + newi = indices(len(indexed_test_arg.index)) + # Replace indexed test function with a product of identities. + identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) + for i, j in zip(newi, indexed_test_arg.index._indices)) + replacement.update({indexed_test_arg.expr: + construct_binary_operator(identities, Product)}) + indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)}) + indexed_test_arg = analyse_modified_argument(reindexing(indexed_test_arg.expr, + replacemap=indexmap)) + else: + replacement.update({indexed_test_arg.expr: IntValue(1)}) replace_expr = replace_expression(replace_expr, replacemap=replacement) - # 3) Collapse any identity nodes that may have been introduced by replacing vectors + # 4) Collapse any identity nodes that may have been introduced by replacing vectors replace_expr = identity_propagation(replace_expr) - # 4) Further split according to trial function in jacobian terms + # 5) Further split according to trial function in jacobian terms if all_jacobian_args: + # TODO -> Jacobians not yet implemented! + assert(False) jac_args = extract_modified_arguments(replace_expr, argnumber=1) for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): @@ -100,9 +146,9 @@ def split_into_accumulation_terms(expr): jac_expr = replace_expression(replace_expr, replacemap=replacement) if not isinstance(jac_expr, Zero): - ret.append(AccumulationTerm(jac_expr, test_arg, indexmap)) + ret.append(AccumulationTerm(jac_expr, test_arg, indexmap, newi)) else: if not isinstance(replace_expr, Zero): - ret.append(AccumulationTerm(replace_expr, test_arg, indexmap)) + ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi)) return ret diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py index 4ba1d0a3d34673b19cf04a3819f12ebaee03e353..a73d2610949ec32a94bdbb29516e45a5c2ba8558 100644 --- a/python/dune/perftool/ufl/modified_terminals.py +++ b/python/dune/perftool/ufl/modified_terminals.py @@ -92,7 +92,8 @@ class ModifiedTerminalTracker(MultiFunction): class ModifiedArgumentAnalysis(ModifiedTerminalTracker): - def __init__(self): + def __init__(self, do_index=False): + self.do_index = do_index self.index = None ModifiedTerminalTracker.__init__(self) @@ -101,7 +102,8 @@ class ModifiedArgumentAnalysis(ModifiedTerminalTracker): return self.call(o) def indexed(self, o): - self.index = o.ufl_operands[1] + if self.do_index: + self.index = o.ufl_operands[1] return self.call(o.ufl_operands[0]) def form_argument(self, o): @@ -116,22 +118,23 @@ class ModifiedArgumentAnalysis(ModifiedTerminalTracker): ) -def analyse_modified_argument(expr): - return ModifiedArgumentAnalysis()(expr) +def analyse_modified_argument(expr, **kwargs): + return ModifiedArgumentAnalysis(**kwargs)(expr) class _ModifiedArgumentExtractor(MultiFunction): """ A multifunction that extracts and returns the set of modified arguments """ - def __call__(self, o, argnumber=None, coeffcount=None): + def __call__(self, o, argnumber=None, coeffcount=None, do_index=False): self.argnumber = argnumber self.coeffcount = coeffcount + self.do_index = do_index self.modified_arguments = set() ret = self.call(o) if ret: # This indicates that this entire expression was a modified thing... self.modified_arguments.add(ret) - return tuple(analyse_modified_argument(ma) for ma in self.modified_arguments) + return tuple(analyse_modified_argument(ma, do_index=self.do_index) for ma in self.modified_arguments) def expr(self, o): for op in o.ufl_operands: @@ -143,7 +146,12 @@ class _ModifiedArgumentExtractor(MultiFunction): if self.call(o.ufl_operands[0]): return o - indexed = pass_on + def indexed(self, o): + if self.do_index: + return self.pass_on(o) + else: + self.expr(o) + positive_restricted = pass_on negative_restricted = pass_on grad = pass_on diff --git a/python/dune/perftool/ufl/transformations/identitypropagation.py b/python/dune/perftool/ufl/transformations/identitypropagation.py index aeb21a10790b81cdad8562e671d00e41ead62bdc..d1a2a9e151326d34eebb18c6fa838e7eb3336cf0 100644 --- a/python/dune/perftool/ufl/transformations/identitypropagation.py +++ b/python/dune/perftool/ufl/transformations/identitypropagation.py @@ -1,6 +1,6 @@ """ A transformation to help the form splitting algorithm split -vector and tensor expressions. In a nutshell does: +vector and tensor expressions. In a nutshell it does: \sum_i f(i)I(i,k) => f(k) """ diff --git a/python/dune/perftool/ufl/transformations/zeropropagation.py b/python/dune/perftool/ufl/transformations/zeropropagation.py new file mode 100644 index 0000000000000000000000000000000000000000..14ac17883a34a853781e9cb9199b7691b54f3a6d --- /dev/null +++ b/python/dune/perftool/ufl/transformations/zeropropagation.py @@ -0,0 +1,34 @@ +""" +A transformation that propagates zeros. It transforms: +Indexed(Zero((shape),(),()) -> Zero((),(free_indices),(index_dimensions)) +""" + +from dune.perftool.ufl.transformations import ufl_transformation + +from ufl.algorithms import MultiFunction +from ufl.classes import Zero + + +class ZeroPropagation(MultiFunction): + call = MultiFunction.__call__ + + def __call__(self, expr): + # self.replacemap = GetIndexMap()(expr) + return self.call(expr) + + def expr(self, o): + return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands)) + + def indexed(self, o): + op, i = o.ufl_operands + if isinstance(op, Zero): + return Zero(shape=o.ufl_shape, + free_indices=o.ufl_free_indices, + index_dimensions=o.ufl_index_dimensions) + else: + return self.reuse_if_untouched(o, self.call(op), self.call(i)) + + +@ufl_transformation(name='zero') +def zero_propagation(expr): + return ZeroPropagation()(expr)