From 88947ddc3d2de04f3c9db5a96361990d3ab7abd3 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Wed, 12 Oct 2016 13:52:38 +0200 Subject: [PATCH] [bugfix] fix splitting of jacobian skeleton terms --- python/dune/perftool/pdelab/localoperator.py | 2 +- python/dune/perftool/ufl/modified_terminals.py | 2 +- .../ufl/transformations/extract_accumulation_terms.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 006685f1..068cae0e 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -427,7 +427,7 @@ def generate_kernel(integrals): # Generate code for the LFS trees present in the form from dune.perftool.ufl.modified_terminals import extract_modified_arguments - test_ma = extract_modified_arguments(integrand,) + test_ma = extract_modified_arguments(integrand, argnumber=0) trial_ma = extract_modified_arguments(integrand, coeffcount=0) apply_ma = extract_modified_arguments(integrand, coeffcount=1) diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py index 97e4c2c7..1bed3056 100644 --- a/python/dune/perftool/ufl/modified_terminals.py +++ b/python/dune/perftool/ufl/modified_terminals.py @@ -144,7 +144,7 @@ class _ModifiedArgumentExtractor(MultiFunction): reference_value = pass_on def argument(self, o): - if self.argnumber is None or o.number() == self.argnumber: + if o.number() == self.argnumber: return o def coefficient(self, o): diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py index 9d8b39d4..7fdf46b3 100644 --- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py @@ -61,10 +61,10 @@ def split_into_accumulation_terms(expr): # TODO Some jacobian terms can be joined replacement = {ma.expr: Zero() for ma in all_jacobian_args} replacement[jac_arg.expr] = jac_arg.expr - replace_expr = replace_expression(replace_expr, replacemap=replacement) + jac_expr = replace_expression(replace_expr, replacemap=replacement) - if not isinstance(replace_expr, Zero): - ret.append(AccumulationTerm(replace_expr, test_arg)) + if not isinstance(jac_expr, Zero): + ret.append(AccumulationTerm(jac_expr, test_arg)) else: if not isinstance(replace_expr, Zero): ret.append(AccumulationTerm(replace_expr, test_arg)) -- GitLab