From 542ddd4c9fbef1d3c92e9633ba234ae2952dc52e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Mon, 27 Mar 2017 14:36:13 +0200
Subject: [PATCH] Hopefully last jacobian splitting bugfix

In the end I had to combine the last two versions. For jacobians:
- We want to split according to the FunctionView (e.g. Stokes).
- We don't want to split according to gradients (e.g. sumfactorization).
---
 .../ufl/extract_accumulation_terms.py         | 39 +++++++++++++++----
 .../dune/perftool/ufl/modified_terminals.py   | 23 +++++++++--
 2 files changed, 51 insertions(+), 11 deletions(-)

diff --git a/python/dune/perftool/ufl/extract_accumulation_terms.py b/python/dune/perftool/ufl/extract_accumulation_terms.py
index 4ef70070..b5d473b9 100644
--- a/python/dune/perftool/ufl/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/extract_accumulation_terms.py
@@ -131,19 +131,44 @@ def split_into_accumulation_terms(expr):
         replace_expr = identity_propagation(replace_expr)
 
         # 5) Further split according to trial function in jacobian terms
+        #
+        # Note: We need to split according to the FunctionView. For
+        # example in Stokes with test functions v and q we have to
+        # split between those.
+        #
+        # But: We don't want to split according to gradients since
+        # this would break the input buffers of stage 3 of
+        # sumfactorization
         if all_jacobian_args:
-            indexed_jac_args = extract_modified_arguments(replace_expr, argnumber=1, do_index=True)
-            for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
+            trial_args = extract_modified_arguments(replace_expr,
+                                                    argnumber=1,
+                                                    do_index=False,
+                                                    do_gradient=False)
+            for trial_arg in trial_args:
+                # 5.1) Restrict to this trial argument
                 replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
                                              free_indices=ma.expr.ufl_free_indices,
                                              index_dimensions=ma.expr.ufl_index_dimensions)
-                               if ma.restriction != restriction else ma.expr
-                               for ma in indexed_jac_args}
-
+                               for ma in trial_args}
+                replacement[trial_arg.expr] = trial_arg.expr
                 jac_expr = replace_expression(replace_expr, replacemap=replacement)
 
-                if not isinstance(jac_expr, Zero):
-                    ret.append(AccumulationTerm(jac_expr, test_arg, indexmap, newi))
+                # 5.2) Propagate indexed zeros to simplify expression
+                jac_expr = zero_propagation(jac_expr)
+
+                # 5.3) Accumulate according to restriction
+                indexed_jac_args = extract_modified_arguments(jac_expr, argnumber=1, do_index=True)
+                for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
+                    replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
+                                                 free_indices=ma.expr.ufl_free_indices,
+                                                 index_dimensions=ma.expr.ufl_index_dimensions)
+                                   if ma.restriction != restriction else ma.expr
+                                   for ma in indexed_jac_args}
+
+                    acc_expr = replace_expression(jac_expr, replacemap=replacement)
+
+                    if not isinstance(jac_expr, Zero):
+                        ret.append(AccumulationTerm(acc_expr, test_arg, indexmap, newi))
         else:
             if not isinstance(replace_expr, Zero):
                 ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi))
diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py
index a73d2610..b5d2f2c5 100644
--- a/python/dune/perftool/ufl/modified_terminals.py
+++ b/python/dune/perftool/ufl/modified_terminals.py
@@ -125,16 +125,20 @@ def analyse_modified_argument(expr, **kwargs):
 class _ModifiedArgumentExtractor(MultiFunction):
     """ A multifunction that extracts and returns the set of modified arguments """
 
-    def __call__(self, o, argnumber=None, coeffcount=None, do_index=False):
+    def __call__(self, o, argnumber=None, coeffcount=None, do_index=False, do_gradient=True):
         self.argnumber = argnumber
         self.coeffcount = coeffcount
         self.do_index = do_index
+        self.do_gradient = do_gradient
         self.modified_arguments = set()
         ret = self.call(o)
         if ret:
             # This indicates that this entire expression was a modified thing...
             self.modified_arguments.add(ret)
-        return tuple(analyse_modified_argument(ma, do_index=self.do_index) for ma in self.modified_arguments)
+        return tuple(analyse_modified_argument(ma,
+                                               do_index=self.do_index
+                                               )
+                     for ma in self.modified_arguments)
 
     def expr(self, o):
         for op in o.ufl_operands:
@@ -152,10 +156,21 @@ class _ModifiedArgumentExtractor(MultiFunction):
         else:
             self.expr(o)
 
+    def reference_grad(self, o):
+        if self.do_gradient:
+            return self.pass_on(o)
+        else:
+            self.expr(o)
+
+    def grad(self, o):
+        if self.do_gradient:
+            return self.pass_on(o)
+        else:
+            self.expr(o)
+
+
     positive_restricted = pass_on
     negative_restricted = pass_on
-    grad = pass_on
-    reference_grad = pass_on
     function_view = pass_on
     reference_value = pass_on
 
-- 
GitLab