From ca5474795e7464021a65c89ad4949d999bc187c2 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 12 Dec 2016 16:33:08 +0100
Subject: [PATCH] Fix sumfactorized jacobians on facets

---
 python/dune/perftool/compile.py               |  3 +-
 python/dune/perftool/loopy/symbolic.py        |  6 ++-
 python/dune/perftool/pdelab/localoperator.py  |  5 +--
 python/dune/perftool/sumfact/basis.py         |  4 +-
 python/dune/perftool/sumfact/sumfact.py       | 42 ++++++++++++-------
 python/dune/perftool/sumfact/vectorization.py | 21 +++++++---
 .../extract_accumulation_terms.py             |  8 ++--
 test/poisson/CMakeLists.txt                   |  8 +++-
 test/poisson/poisson_dg_quadrilateral.ufl     |  4 +-
 test/sumfact/poisson/poisson_dg.mini          |  2 +-
 test/sumfact/poisson/poisson_dg.ufl           |  4 +-
 11 files changed, 67 insertions(+), 40 deletions(-)

diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py
index d158a022..7e09f2fe 100644
--- a/python/dune/perftool/compile.py
+++ b/python/dune/perftool/compile.py
@@ -4,7 +4,6 @@ The methods to run the parts of the form compiler
 Should also contain the entrypoint methods.
 """
 from __future__ import absolute_import
-from os.path import basename, splitext
 
 import loopy
 
@@ -21,6 +20,8 @@ from dune.perftool.pdelab.localoperator import (generate_localoperator_basefile,
                                                 name_localoperator_file)
 from dune.perftool.ufl.preprocess import preprocess_form
 
+import os.path
+
 
 # Disable loopy caching before we do anything else!
 loopy.CACHING_ENABLED = False
diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py
index 1ba315a3..42f0420d 100644
--- a/python/dune/perftool/loopy/symbolic.py
+++ b/python/dune/perftool/loopy/symbolic.py
@@ -20,21 +20,23 @@ class SumfactKernel(prim.Variable):
                  buffer,
                  stage,
                  preferred_position,
+                 restriction,
                  ):
         self.a_matrices = a_matrices
         self.buffer = buffer
         self.stage = stage
         self.preferred_position = preferred_position
+        self.restriction = restriction
 
         prim.Variable.__init__(self, "SUMFACT")
 
     def __getinitargs__(self):
-        return (self.a_matrices, self.buffer, self.stage, self.preferred_position)
+        return (self.a_matrices, self.buffer, self.stage, self.preferred_position, self.restriction)
 
     def stringifier(self):
         return lp.symbolic.StringifyMapper
 
-    init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position")
+    init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position", "restriction")
 
     mapper_method = "map_sumfact_kernel"
 
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index ffa16a10..e5342213 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -215,10 +215,7 @@ def determine_accumulation_space(expr, number, measure):
     # If this is a residual term we return a dummy object
     if len(args) == 0:
         return AccumulationSpace()
-
-    # There should be but one modified argument, as the splitting eliminated all others.
-    assert(len(args) == 1)
-    ma, = args
+    ma = next(iter(args))
 
     # Extract information on the finite element
     from ufl.functionview import select_subelement
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index b346a2f9..fe9ef21b 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -75,7 +75,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
 
         # Get the vectorization info. If this happens during the dry run, we get dummies
         from dune.perftool.sumfact.vectorization import get_vectorization_info
-        a_matrices, buffer, input, index = get_vectorization_info(a_matrices)
+        a_matrices, buffer, input, index = get_vectorization_info(a_matrices, 0)
 
         # Initialize the buffer for the sum fact kernel
         shape = (product(mat.cols for mat in a_matrices),)
@@ -140,7 +140,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
 
     # Get the vectorization info. If this happens during the dry run, we get dummies
     from dune.perftool.sumfact.vectorization import get_vectorization_info
-    a_matrices, buffer, input, index = get_vectorization_info(a_matrices)
+    a_matrices, buffer, input, index = get_vectorization_info(a_matrices, 0)
 
     # Flip flop buffers for sumfactorization
     shape = (product(mat.cols for mat in a_matrices),)
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 9c48e3d4..afae030a 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -47,6 +47,7 @@ from dune.perftool.sumfact.switch import (get_facedir,
                                           get_facemod,
                                           )
 from dune.perftool.loopy.symbolic import SumfactKernel
+from dune.perftool.ufl.modified_terminals import extract_modified_arguments
 from dune.perftool.tools import get_pymbolic_basename
 from dune.perftool.error import PerftoolError
 from pymbolic.primitives import (Call,
@@ -109,9 +110,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
 
     dim = world_dimension()
 
-    # Collect buffers we need
-    buffers = []
-
     # If this is a gradient, we find the gradient iname
     additional_inames = frozenset({})
     if accterm.argument.index:
@@ -119,13 +117,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
             if i not in visitor.dimension_indices:
                 from dune.perftool.pdelab.localoperator import grad_iname
                 additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
-                for i in range(dim):
-                    buffers.append(name_test_function_contribution(accterm.argument))
-    else:
-        buffers.append(name_test_function_contribution(accterm.argument))
 
-    insn_dep = None
-    for i, buf in enumerate(buffers):
+    def emit_sumfact_kernel(i, restriction, insn_dep):
         # Construct the matrix sequence for this sum factorization
         a_matrices = construct_amatrix_sequence(transpose=True,
                                                 derivative=i if accterm.argument.index else None,
@@ -135,10 +128,13 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
 
         # Get the vectorization info. If this happens during the dry run, we get dummies
         from dune.perftool.sumfact.vectorization import get_vectorization_info
-        a_matrices, buffer, input, index = get_vectorization_info(a_matrices)
+        a_matrices, buffer, input, index = get_vectorization_info(a_matrices, restriction)
 
         # Initialize a base storage for this buffer and get a temporay pointing to it
-        shape = tuple(mat.cols for mat in a_matrices if mat.cols != 1)
+        try:
+            shape = tuple(mat.cols for mat in a_matrices if mat.cols != 1)
+        except (TypeError, AttributeError):  # a_matrices is not an iterable of objects with .cols
+            raise PerftoolError("Cannot determine sumfact buffer shape from a_matrices: {!r}".format(a_matrices))
         dim_tags = ",".join(['f'] * local_dimension())
         if index is not None:
             shape = shape + (4,)
@@ -207,7 +203,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         ansatz_lfs = determine_accumulation_space(accterm.term, 1, measure)
         rank = 2 if visitor.inames else 1
         if rank == 2:
-            ansatz_lfs.index = flatten_index(tuple(Variable(i) for i in visitor.inames),
+            # TODO: Obtain these inames explicitly instead of taking the first world_dimension() entries of visitor.inames. This is *NOT* robust (but works right now)
+            ansatz_lfs.index = flatten_index(tuple(Variable(visitor.inames[i]) for i in range(world_dimension())),
                                              (basis_functions_per_direction(),) * dim,
                                              order="f"
                                              )
@@ -230,13 +227,30 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         # Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
         transform(nest_quadrature_loops, visitor.inames)
 
+        return insn_dep
+
+    # Extract the restrictions on argument-1:
+    jac_restrictions = frozenset(tuple(ma.restriction for ma in extract_modified_arguments(accterm.term, argnumber=1)))
+    if not jac_restrictions:
+        jac_restrictions = frozenset({0})
+
+    insn_dep = None
+    for restriction in jac_restrictions:
+        if accterm.argument.index:
+            for i in range(world_dimension()):
+                insn_dep = emit_sumfact_kernel(i, restriction, insn_dep)
+        else:
+            insn_dep = emit_sumfact_kernel(None, restriction, insn_dep)
+
 
 @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s))
 def sum_factorization_kernel(a_matrices, buf, stage,
                              insn_dep=frozenset({}),
                              additional_inames=frozenset({}),
                              preferred_position=None,
-                             outshape=None):
+                             outshape=None,
+                             restriction=0,
+                             ):
     """
     Calculate a sum factorization matrix product.
 
@@ -257,7 +271,7 @@ def sum_factorization_kernel(a_matrices, buf, stage,
         other.
     """
     if get_global_context_value("dry_run", False):
-        return SumfactKernel(a_matrices, buf, stage, preferred_position), frozenset()
+        return SumfactKernel(a_matrices, buf, stage, preferred_position, restriction), frozenset()
 
     ftags = "f,f"
     ctags = "c,c"
diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
index 3747a930..a7e8eae0 100644
--- a/python/dune/perftool/sumfact/vectorization.py
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -3,31 +3,40 @@
 from dune.perftool.generation import (generator_factory,
                                       get_counted_variable,
                                       )
+from dune.perftool.pdelab.restriction import (Restriction,
+                                              restricted_name,
+                                              )
 from dune.perftool.error import PerftoolError
 from dune.perftool.options import get_option
 
 import loopy as lp
 
 
-@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda a, *args: a)
-def vectorization_info(a_matrices, new_a_matrices, buffer, input, index):
+@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda a, r, *args: (a, r))
+def vectorization_info(a_matrices, restriction, new_a_matrices, buffer, input, index):
     return (new_a_matrices, buffer, input, index)
 
 
-def get_vectorization_info(a_matrices):
+def get_vectorization_info(a_matrices, restriction):
     from dune.perftool.generation import get_global_context_value
     if get_global_context_value("dry_run"):
         # Return dummy data
         return (a_matrices, get_counted_variable("buffer"), get_counted_variable("input"), None)
     try:
-        return vectorization_info(a_matrices, None, None, None, None)
+        return vectorization_info(a_matrices, restriction, None, None, None, None)
     except TypeError:
         raise PerftoolError("Sumfact Vectorization data should have been collected in dry run, but wasnt")
 
 
 def no_vectorization(sumfacts):
     for sumf in sumfacts:
-        vectorization_info(sumf.a_matrices, sumf.a_matrices, get_counted_variable("buffer"), get_counted_variable("input"), None)
+        for res in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
+            vectorization_info(sumf.a_matrices,
+                               res,
+                               sumf.a_matrices,
+                               get_counted_variable("buffer"),
+                               get_counted_variable(restricted_name("input", sumf.restriction)),
+                               None)
 
 
 def decide_stage_vectorization_strategy(sumfacts, stage):
@@ -74,7 +83,7 @@ def decide_stage_vectorization_strategy(sumfacts, stage):
             large_a_matrices.append(large)
 
         for sumf in stage_sumfacts:
-            vectorization_info(sumf.a_matrices, tuple(large_a_matrices), buffer, input, position_mapping[sumf])
+            vectorization_info(sumf.a_matrices, sumf.restriction, tuple(large_a_matrices), buffer, input, position_mapping[sumf])
     else:
         # Disable vectorization strategy
         no_vectorization(stage_sumfacts)
diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
index 0429c1c9..522ba129 100644
--- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
@@ -8,6 +8,7 @@ from dune.perftool.ufl.transformations.replace import replace_expression
 from dune.perftool.ufl.transformations.identitypropagation import identity_propagation
 from dune.perftool.ufl.transformations.reindexing import reindexing
 from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument
+from dune.perftool.pdelab.restriction import Restriction
 
 from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex
 from ufl.core.multiindex import indices
@@ -73,15 +74,14 @@ def split_into_accumulation_terms(expr, indexmap={}):
 
         # 4) Further split according to trial function in jacobian terms
         if all_jacobian_args:
-            # Update the list!
             jac_args = extract_modified_arguments(replace_expr, argnumber=1)
-            for jac_arg in jac_args:
-                # TODO Some jacobian terms can be joined
+
+            for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
                 replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
                                              free_indices=ma.expr.ufl_free_indices,
                                              index_dimensions=ma.expr.ufl_index_dimensions)
+                               if ma.restriction != restriction else ma.expr
                                for ma in jac_args}
-                replacement[jac_arg.expr] = jac_arg.expr
                 jac_expr = replace_expression(replace_expr, replacemap=replacement)
 
                 if not isinstance(jac_expr, Zero):
diff --git a/test/poisson/CMakeLists.txt b/test/poisson/CMakeLists.txt
index c1312699..932fdafb 100644
--- a/test/poisson/CMakeLists.txt
+++ b/test/poisson/CMakeLists.txt
@@ -49,12 +49,16 @@ dune_add_formcompiler_system_test(UFLFILE poisson_cellwise_constant.ufl
                                   )
 
 # 8. Poisson with operator counting
-dune_add_formcompiler_system_test(UFLFILE poisson_dg_quadrilateral.ufl
+dune_add_formcompiler_system_test(UFLFILE opcount_poisson_dg.ufl
                                   BASENAME opcount_poisson_dg_symdiff
                                   INIFILE opcount_poisson_dg_symdiff.mini
                                   )
 
-
+# 9. Poisson Test Case: DG on quadrilaterals, f + pure Dirichlet
+dune_add_formcompiler_system_test(UFLFILE poisson_dg_quadrilateral.ufl
+                                  BASENAME poisson_dg_quadrilateral
+                                  INIFILE poisson_dg_quadrilateral.mini
+                                  )
 
 # the reference vtk file
 add_executable(poisson_dg_ref reference_main.cc)
diff --git a/test/poisson/poisson_dg_quadrilateral.ufl b/test/poisson/poisson_dg_quadrilateral.ufl
index 1ca65e1e..e5372870 100644
--- a/test/poisson/poisson_dg_quadrilateral.ufl
+++ b/test/poisson/poisson_dg_quadrilateral.ufl
@@ -1,7 +1,7 @@
 cell = "quadrilateral"
 
-f = Expression("return -2.0*x.size();", cell=cell)
-g = Expression("return x.two_norm2();", on_intersection=True, cell=cell)
+f = Expression("Dune::FieldVector<double,2> c(0.5); c-= x; return 4.*(1.-c.two_norm2())*std::exp(-1.*c.two_norm2());", cell=cell)
+g = Expression("Dune::FieldVector<double,2> c(0.5); c-= x; return std::exp(-1.*c.two_norm2());", on_intersection=True, cell=cell)
 
 V = FiniteElement("DG", cell, 1)
 
diff --git a/test/sumfact/poisson/poisson_dg.mini b/test/sumfact/poisson/poisson_dg.mini
index eeca1791..aa5e41bc 100644
--- a/test/sumfact/poisson/poisson_dg.mini
+++ b/test/sumfact/poisson/poisson_dg.mini
@@ -1,7 +1,7 @@
 __name = sumfact_poisson_dg_{__exec_suffix}
 __exec_suffix = numdiff, symdiff | expand num
 
-cells = 1 1
+cells = 16 16
 extension = 1. 1.
 
 [wrapper.vtkcompare]
diff --git a/test/sumfact/poisson/poisson_dg.ufl b/test/sumfact/poisson/poisson_dg.ufl
index a0b5049d..90194675 100644
--- a/test/sumfact/poisson/poisson_dg.ufl
+++ b/test/sumfact/poisson/poisson_dg.ufl
@@ -1,7 +1,7 @@
 cell = "quadrilateral"
 
-f = Expression("return -2.0*x.size();", cell=cell)
-g = Expression("return x.two_norm2();", on_intersection=True, cell=cell)
+f = Expression("Dune::FieldVector<double,2> c(0.5); c-= x; return 4.*(1.-c.two_norm2())*std::exp(-1.*c.two_norm2());", cell=cell)
+g = Expression("Dune::FieldVector<double,2> c(0.5); c-= x; return std::exp(-1.*c.two_norm2());", on_intersection=True, cell=cell)
 
 V = FiniteElement("DG", cell, 1)
 
-- 
GitLab