From 542ddd4c9fbef1d3c92e9633ba234ae2952dc52e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Mon, 27 Mar 2017 14:36:13 +0200 Subject: [PATCH] Hopefully last jacobian splitting bugfix In the end I had to combine the last two versions. For jacobians: - We want to split according to the FunctionView (e.g. Stokes). - We don't want to split according to gradients (e.g. sumfactorization). --- .../ufl/extract_accumulation_terms.py | 39 +++++++++++++++---- .../dune/perftool/ufl/modified_terminals.py | 23 +++++++++-- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/python/dune/perftool/ufl/extract_accumulation_terms.py b/python/dune/perftool/ufl/extract_accumulation_terms.py index 4ef70070..b5d473b9 100644 --- a/python/dune/perftool/ufl/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/extract_accumulation_terms.py @@ -131,19 +131,44 @@ def split_into_accumulation_terms(expr): replace_expr = identity_propagation(replace_expr) # 5) Further split according to trial function in jacobian terms + # + # Note: We need to split according to the FunctionView. For + # example in Stokes with test functions v and q we have to + # split between those. + # + # But: We don't want to split according to gradients since + # this would break the input buffers of stage 3 of + # sumfactorization if all_jacobian_args: - indexed_jac_args = extract_modified_arguments(replace_expr, argnumber=1, do_index=True) - for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): + trial_args = extract_modified_arguments(replace_expr, + argnumber=1, + do_index=False, + do_gradient=False) + for trial_arg in trial_args: + # 5.1) Restrict to this trial argument replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape, free_indices=ma.expr.ufl_free_indices, index_dimensions=ma.expr.ufl_index_dimensions) - if ma.restriction != restriction else ma.expr - for ma in indexed_jac_args} - + for ma in trial_args} + replacement[trial_arg.expr] = trial_arg.expr jac_expr = replace_expression(replace_expr, replacemap=replacement) - if not isinstance(jac_expr, Zero): - ret.append(AccumulationTerm(jac_expr, test_arg, indexmap, newi)) + # 5.2) Propagate indexed zeros to simplify expression + jac_expr = zero_propagation(jac_expr) + + # 5.3) Accumulate according to restriction + indexed_jac_args = extract_modified_arguments(jac_expr, argnumber=1, do_index=True) + for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): + replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape, + free_indices=ma.expr.ufl_free_indices, + index_dimensions=ma.expr.ufl_index_dimensions) + if ma.restriction != restriction else ma.expr + for ma in indexed_jac_args} + + acc_expr = replace_expression(jac_expr, replacemap=replacement) + + if not isinstance(jac_expr, Zero): + ret.append(AccumulationTerm(acc_expr, test_arg, indexmap, newi)) else: if not isinstance(replace_expr, Zero): ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi)) diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py index a73d2610..b5d2f2c5 100644 --- a/python/dune/perftool/ufl/modified_terminals.py +++ b/python/dune/perftool/ufl/modified_terminals.py @@ -125,16 +125,20 @@ def analyse_modified_argument(expr, **kwargs): class _ModifiedArgumentExtractor(MultiFunction): """ A multifunction that extracts and returns the set of modified arguments """ - def __call__(self, o, argnumber=None, coeffcount=None, do_index=False): + def __call__(self, o, argnumber=None, coeffcount=None, do_index=False, do_gradient=True): self.argnumber = argnumber self.coeffcount = coeffcount self.do_index = do_index + self.do_gradient = do_gradient self.modified_arguments = set() ret = self.call(o) if ret: # This indicates that this entire expression was a modified thing... self.modified_arguments.add(ret) - return tuple(analyse_modified_argument(ma, do_index=self.do_index) for ma in self.modified_arguments) + return tuple(analyse_modified_argument(ma, + do_index=self.do_index + ) + for ma in self.modified_arguments) def expr(self, o): for op in o.ufl_operands: @@ -152,10 +156,21 @@ class _ModifiedArgumentExtractor(MultiFunction): else: self.expr(o) + def reference_grad(self, o): + if self.do_gradient: + return self.pass_on(o) + else: + self.expr(o) + + def grad(self, o): + if self.do_gradient: + return self.pass_on(o) + else: + self.expr(o) + + positive_restricted = pass_on negative_restricted = pass_on - grad = pass_on - reference_grad = pass_on function_view = pass_on reference_value = pass_on -- GitLab