diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 006685f1eefbfc4c09a2b01661eff7fe310b6616..068cae0e4a85593348b2b83583f78828e6bf0888 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -427,7 +427,7 @@ def generate_kernel(integrals): # Generate code for the LFS trees present in the form from dune.perftool.ufl.modified_terminals import extract_modified_arguments - test_ma = extract_modified_arguments(integrand,) + test_ma = extract_modified_arguments(integrand, argnumber=0) trial_ma = extract_modified_arguments(integrand, coeffcount=0) apply_ma = extract_modified_arguments(integrand, coeffcount=1) diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py index 97e4c2c75e1a8171cfbf9bb6f192763e8f983b65..1bed30564ac382c1d6d242cd93d2617c4063b597 100644 --- a/python/dune/perftool/ufl/modified_terminals.py +++ b/python/dune/perftool/ufl/modified_terminals.py @@ -144,7 +144,7 @@ class _ModifiedArgumentExtractor(MultiFunction): reference_value = pass_on def argument(self, o): - if self.argnumber is None or o.number() == self.argnumber: + if o.number() == self.argnumber: return o def coefficient(self, o): diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py index 9d8b39d4f3d784e6629aed9145f89b2b2d4c2ba7..7fdf46b35e57cf9096916c72aa2bc4625709e399 100644 --- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py @@ -61,10 +61,10 @@ def split_into_accumulation_terms(expr): # TODO Some jacobian terms can be joined replacement = {ma.expr: Zero() for ma in all_jacobian_args} replacement[jac_arg.expr] = jac_arg.expr - replace_expr = replace_expression(replace_expr, replacemap=replacement) + jac_expr = replace_expression(replace_expr, replacemap=replacement) - if not isinstance(replace_expr, Zero): - ret.append(AccumulationTerm(replace_expr, test_arg)) + if not isinstance(jac_expr, Zero): + ret.append(AccumulationTerm(jac_expr, test_arg)) else: if not isinstance(replace_expr, Zero): ret.append(AccumulationTerm(replace_expr, test_arg))