From 32212687b484b25ef27941e4b26fe78736ef066b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Fri, 24 Mar 2017 10:07:27 +0100 Subject: [PATCH] Extend new splitting to jacobians --- .../ufl/extract_accumulation_terms.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/python/dune/perftool/ufl/extract_accumulation_terms.py b/python/dune/perftool/ufl/extract_accumulation_terms.py index 5a8a4d7a..cefc8275 100644 --- a/python/dune/perftool/ufl/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/extract_accumulation_terms.py @@ -59,7 +59,7 @@ def split_into_accumulation_terms(expr): # Extract a list of modified terminals for the ansatz function # in jacobian forms. - all_jacobian_args = extract_modified_arguments(expr, argnumber=1) + all_jacobian_args = extract_modified_arguments(expr, argnumber=1, do_index=False) for test_arg in test_args: # Do this as a multi step replacement procedure to avoid UFL nagging @@ -110,7 +110,6 @@ def split_into_accumulation_terms(expr): # Get all appearances of test functions with their indices indexed_test_args = extract_modified_arguments(replace_expr, argnumber=0, do_index=True) for indexed_test_arg in indexed_test_args: - # from pudb import set_trace; set_trace() if indexed_test_arg.index: # If the test function is indexed, create a new multiindex of this shape # -> (m,n) in the example above @@ -133,20 +132,31 @@ def split_into_accumulation_terms(expr): # 5) Further split according to trial function in jacobian terms if all_jacobian_args: - # TODO -> Jacobians not yet implemented! - assert(False) - jac_args = extract_modified_arguments(replace_expr, argnumber=1) - - for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): + jac_args = extract_modified_arguments(replace_expr, argnumber=1, do_index=False) + for jac_arg in jac_args: + # 5.1) Cut the expression to this ansatz function replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape, free_indices=ma.expr.ufl_free_indices, index_dimensions=ma.expr.ufl_index_dimensions) - if ma.restriction != restriction else ma.expr for ma in jac_args} + replacement[jac_arg.expr] = jac_arg.expr jac_expr = replace_expression(replace_expr, replacemap=replacement) - if not isinstance(jac_expr, Zero): - ret.append(AccumulationTerm(jac_expr, test_arg, indexmap, newi)) + # 5.2) Propagate indexed zeros to simplify expression + jac_expr = zero_propagation(jac_expr) + + indexed_jac_args = extract_modified_arguments(jac_expr, argnumber=1, do_index=True) + for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): + replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape, + free_indices=ma.expr.ufl_free_indices, + index_dimensions=ma.expr.ufl_index_dimensions) + if ma.restriction != restriction else ma.expr + for ma in indexed_jac_args} + + jac_expr = replace_expression(jac_expr, replacemap=replacement) + + if not isinstance(jac_expr, Zero): + ret.append(AccumulationTerm(jac_expr, test_arg, indexmap, newi)) else: if not isinstance(replace_expr, Zero): ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi)) -- GitLab