From 1b02ba97470017dc7d8afa0e793bbd9acf5d21cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Fri, 24 Mar 2017 15:47:59 +0100
Subject: [PATCH] Fix another bug in jacobian splitting

---
 python/dune/perftool/sumfact/sumfact.py       | 19 ++++++++-----
 .../ufl/extract_accumulation_terms.py         | 27 +++++--------------
 2 files changed, 19 insertions(+), 27 deletions(-)

diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 4580d46a..73952843 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -159,8 +159,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
             index = ()
             vectag = frozenset()
 
+        base_storage_size = product(max(mat.rows, mat.cols) for mat in a_matrices)
         temp = initialize_buffer(buf,
-                                 base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
+                                 base_storage_size=base_storage_size,
                                  num=2
                                  ).get_temporary(shape=shape,
                                                  dim_tags=dim_tags,
@@ -168,9 +169,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                                                  )
 
         # Those input fields, that are padded need to be set to zero
-        # in order to do a horizontal_add lateron
+        # in order to do a horizontal_add later on
         for pad in padding:
-            instruction(assignee=prim.Subscript(prim.Variable(temp), tuple(Variable(i) for i in quadrature_inames()) + (pad,)),
+            assignee = prim.Subscript(prim.Variable(temp),
+                                      tuple(Variable(i) for i in quadrature_inames()) + (pad,))
+            instruction(assignee=assignee,
                         expression=0,
                         forced_iname_deps=frozenset(quadrature_inames() + visitor.inames),
                         forced_iname_deps_is_final=True,
@@ -316,13 +319,15 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                         depends_on=insn_dep,
                         )
 
-        # Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
+        # Mark the transformation that moves the quadrature loop
+        # inside the trialfunction loops for application
         transform(nest_quadrature_loops, visitor.inames)
 
         return insn_dep
 
     # Extract the restrictions on argument-1:
-    jac_restrictions = frozenset(tuple(ma.restriction for ma in extract_modified_arguments(accterm.term, argnumber=1)))
+    jac_restrictions = frozenset(tuple(ma.restriction for ma in
+                                       extract_modified_arguments(accterm.term, argnumber=1, do_index=True)))
     if not jac_restrictions:
         jac_restrictions = frozenset({0})
 
@@ -461,9 +466,9 @@ def sum_factorization_kernel(a_matrices,
 
     It can make sense to permute the order of directions. If you have
     a small m_l (e.g. stage 1 on faces) it is better to do direction l
-    first. This can be done permuting:
+    first. This can be done by:
 
-    - The order of the A matrices.
+    - Permuting the order of the A matrices.
     - Permuting the input tensor.
     - Permuting the output tensor (this assures that the directions of
       the output tensor are again ordered from 0 to d-1).
diff --git a/python/dune/perftool/ufl/extract_accumulation_terms.py b/python/dune/perftool/ufl/extract_accumulation_terms.py
index 33032376..4ef70070 100644
--- a/python/dune/perftool/ufl/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/extract_accumulation_terms.py
@@ -132,31 +132,18 @@ def split_into_accumulation_terms(expr):
 
         # 5) Further split according to trial function in jacobian terms
         if all_jacobian_args:
-            jac_args = extract_modified_arguments(replace_expr, argnumber=1, do_index=False)
-            for jac_arg in jac_args:
-                # 5.1) Cut the expression to this ansatz function
+            indexed_jac_args = extract_modified_arguments(replace_expr, argnumber=1, do_index=True)
+            for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
                 replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
                                              free_indices=ma.expr.ufl_free_indices,
                                              index_dimensions=ma.expr.ufl_index_dimensions)
-                               for ma in jac_args}
-                replacement[jac_arg.expr] = jac_arg.expr
-                jac_expr = replace_expression(replace_expr, replacemap=replacement)
-
-                # 5.2) Propagate indexed zeros to simplify expression
-                jac_expr = zero_propagation(jac_expr)
+                               if ma.restriction != restriction else ma.expr
+                               for ma in indexed_jac_args}
 
-                indexed_jac_args = extract_modified_arguments(jac_expr, argnumber=1, do_index=True)
-                for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
-                    replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
-                                                 free_indices=ma.expr.ufl_free_indices,
-                                                 index_dimensions=ma.expr.ufl_index_dimensions)
-                                   if ma.restriction != restriction else ma.expr
-                                   for ma in indexed_jac_args}
-
-                    jac_accum_expr = replace_expression(jac_expr, replacemap=replacement)
+                jac_expr = replace_expression(replace_expr, replacemap=replacement)
 
-                    if not isinstance(jac_expr, Zero):
-                        ret.append(AccumulationTerm(jac_accum_expr, test_arg, indexmap, newi))
+                if not isinstance(jac_expr, Zero):
+                    ret.append(AccumulationTerm(jac_expr, test_arg, indexmap, newi))
         else:
             if not isinstance(replace_expr, Zero):
                 ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi))
-- 
GitLab