diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 006685f1eefbfc4c09a2b01661eff7fe310b6616..068cae0e4a85593348b2b83583f78828e6bf0888 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -427,7 +427,7 @@ def generate_kernel(integrals):
 
         # Generate code for the LFS trees present in the form
         from dune.perftool.ufl.modified_terminals import extract_modified_arguments
-        test_ma = extract_modified_arguments(integrand,)
+        test_ma = extract_modified_arguments(integrand, argnumber=0)
         trial_ma = extract_modified_arguments(integrand, coeffcount=0)
         apply_ma = extract_modified_arguments(integrand, coeffcount=1)
 
diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py
index 97e4c2c75e1a8171cfbf9bb6f192763e8f983b65..1bed30564ac382c1d6d242cd93d2617c4063b597 100644
--- a/python/dune/perftool/ufl/modified_terminals.py
+++ b/python/dune/perftool/ufl/modified_terminals.py
@@ -144,7 +144,7 @@ class _ModifiedArgumentExtractor(MultiFunction):
     reference_value = pass_on
 
     def argument(self, o):
-        if self.argnumber is None or o.number() == self.argnumber:
+        if o.number() == self.argnumber:
             return o
 
     def coefficient(self, o):
diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
index 9d8b39d4f3d784e6629aed9145f89b2b2d4c2ba7..7fdf46b35e57cf9096916c72aa2bc4625709e399 100644
--- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
@@ -61,10 +61,10 @@ def split_into_accumulation_terms(expr):
                 # TODO Some jacobian terms can be joined
                 replacement = {ma.expr: Zero() for ma in all_jacobian_args}
                 replacement[jac_arg.expr] = jac_arg.expr
-                replace_expr = replace_expression(replace_expr, replacemap=replacement)
+                jac_expr = replace_expression(replace_expr, replacemap=replacement)
 
-                if not isinstance(replace_expr, Zero):
-                    ret.append(AccumulationTerm(replace_expr, test_arg))
+                if not isinstance(jac_expr, Zero):
+                    ret.append(AccumulationTerm(jac_expr, test_arg))
         else:
             if not isinstance(replace_expr, Zero):
                 ret.append(AccumulationTerm(replace_expr, test_arg))