From 86d24865fac79ac46331af440ac0a1a4f4108310 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 26 Oct 2016 16:40:15 +0200
Subject: [PATCH] Split quadrature loop for sum factorization and implement
 weight multiplication

---
 python/dune/perftool/pdelab/__init__.py      | 11 +++++---
 python/dune/perftool/pdelab/argument.py      | 13 ++++++++-
 python/dune/perftool/pdelab/basis.py         |  4 +--
 python/dune/perftool/pdelab/localoperator.py | 14 ++++++----
 python/dune/perftool/pdelab/quadrature.py    | 12 ++++-----
 python/dune/perftool/pdelab/spaces.py        | 18 +++++++++++++
 python/dune/perftool/sumfact/__init__.py     | 22 +++++++++++----
 python/dune/perftool/sumfact/amatrix.py      |  1 +
 python/dune/perftool/sumfact/sumfact.py      | 28 +++++++++++---------
 python/dune/perftool/ufl/visitor.py          |  2 +-
 10 files changed, 88 insertions(+), 37 deletions(-)

diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py
index 76aaff91..7055fe70 100644
--- a/python/dune/perftool/pdelab/__init__.py
+++ b/python/dune/perftool/pdelab/__init__.py
@@ -22,7 +22,8 @@ from dune.perftool.pdelab.index import (name_index,
 from dune.perftool.pdelab.parameter import (cell_parameter_function,
                                             intersection_parameter_function,
                                             )
-from dune.perftool.pdelab.quadrature import (name_quadrature_weight,
+from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight,
+                                             quadrature_inames,
                                              )
 from dune.perftool.pdelab.spaces import (lfs_iname,
                                          )
@@ -102,5 +103,9 @@ class PDELabInterface(object):
     # Quadrature related generator functions
     #
 
-    def name_quadrature_weight(self):
-        return name_quadrature_weight()
+    def pymbolic_quadrature_weight(self):
+        return pymbolic_quadrature_weight()
+
+    # TODO Should this be part of interface or not?
+    def quadrature_inames(self):
+        return quadrature_inames()
\ No newline at end of file
diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py
index afcb0e95..5b72c9b9 100644
--- a/python/dune/perftool/pdelab/argument.py
+++ b/python/dune/perftool/pdelab/argument.py
@@ -169,11 +169,22 @@ def type_residual():
     return "R"
 
 
-def name_accumulation_variable(restrictions=(Restriction.NONE,)):
+def name_accumulation_variable(restrictions=None):
     ft = get_global_context_value("form_type")
+    measure = get_global_context_value("integral_type")
     if ft == 'residual' or ft == 'jacobian_apply':
+        if restrictions is None:
+            if measure == "cell":
+                restrictions = (Restriction.NONE,)
+            else:
+                restrictions = (Restriction.OUTSIDE,)
         return name_residual(*restrictions)
     if ft == 'jacobian':
+        if restrictions is None:
+            if measure == "cell":
+                restrictions = (Restriction.NONE, Restriction.NONE)
+            else:
+                restrictions = (Restriction.OUTSIDE, Restriction.OUTSIDE)
         return name_jacobian(*restrictions)
     assert False
 
diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index ad587091..09ecddb0 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -1,7 +1,6 @@
 """ Generators for basis evaluations """
 
-from dune.perftool.generation import (backend,
-                                      cached,
+from dune.perftool.generation import (cached,
                                       class_member,
                                       generator_factory,
                                       include_file,
@@ -124,7 +123,6 @@ def shape_as_pymbolic(shape):
     return tuple(_shape_as_pymbolic(s) for s in shape)
 
 
-@backend(interface="eval_coefficient")
 @cached
 def evaluate_coefficient(element, name, container, restriction, component):
     from ufl.functionview import select_subelement
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 5f9ac0d9..33af5b7a 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -403,13 +403,13 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                 (ansatz_lfs.get_args() + test_lfs.get_args() + (pymbolic_expr,))
                 )
 
-    from dune.perftool.generation import get_backend, instruction
+    from dune.perftool.generation import instruction
     from dune.perftool.options import option_switch
-    quad_inames = get_backend(interface="quadinames", selector=option_switch("sumfac"))()
+    quad_inames = visitor.interface.quadrature_inames()
 
     instruction(assignees=(),
                 expression=expr,
-                forced_iname_deps=additional_inames.union(frozenset(visitor.inames).union(quad_inames)),
+                forced_iname_deps=additional_inames.union(frozenset(visitor.inames).union(frozenset(quad_inames))),
                 forced_iname_deps_is_final=True,
                 predicates=predicates
                 )
@@ -459,8 +459,12 @@ def generate_kernel(integrals):
                     indexmap[j] = indexmap[i]
 
             # Get a transformer instance for this kernel
-            from dune.perftool.pdelab import PDELabInterface
-            interface = PDELabInterface()
+            if get_option('sumfact'):
+                from dune.perftool.sumfact import SumFactInterface
+                interface = SumFactInterface()
+            else:
+                from dune.perftool.pdelab import PDELabInterface
+                interface = PDELabInterface()
             from dune.perftool.ufl.visitor import UFL2LoopyVisitor
             visitor = UFL2LoopyVisitor(interface, measure, indexmap)
             generate_accumulation_instruction(visitor, term, measure, subdomain_id)
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index b50c8248..ea55c546 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -1,5 +1,4 @@
-from dune.perftool.generation import (backend,
-                                      cached,
+from dune.perftool.generation import (cached,
                                       domain,
                                       get_global_context_value,
                                       iname,
@@ -10,6 +9,8 @@ from dune.perftool.generation import (backend,
 from dune.perftool.options import get_option
 from dune.perftool.ufl.modified_terminals import Restriction
 
+from pymbolic.primitives import Variable
+
 
 @iname
 def quadrature_iname():
@@ -17,9 +18,8 @@ def quadrature_iname():
     return "q"
 
 
-@backend(interface="quadinames")
 def quadrature_inames():
-    return frozenset({quadrature_iname()})
+    return (quadrature_iname(),)
 
 
 def quadrature_preamble(code, **kw):
@@ -67,11 +67,11 @@ def define_quadrature_weight(name):
                                )
 
 
-def name_quadrature_weight():
+def pymbolic_quadrature_weight():
     name = 'weight'
     temporary_variable(name, shape=())
     define_quadrature_weight(name)
-    return name
+    return Variable(name)
 
 
 def estimate_quadrature_order():
diff --git a/python/dune/perftool/pdelab/spaces.py b/python/dune/perftool/pdelab/spaces.py
index 07403a86..0e7b0ce2 100644
--- a/python/dune/perftool/pdelab/spaces.py
+++ b/python/dune/perftool/pdelab/spaces.py
@@ -212,6 +212,24 @@ def lfs_iname(element, restriction, count=None, context=''):
     return _lfs_iname(element, restriction, context)
 
 
+class LFSLocalIndex(FunctionIdentifier):
+    def __init__(self, lfs):
+        self.lfs = lfs
+
+    def __getinitargs__(self):
+        return (self.lfs,)
+
+    @property
+    def name(self):
+        return '{}.local_index'.format(self.lfs)
+
+
+@function_mangler
+def lfs_localindex_mangler(target, func, dtypes):
+    if isinstance(func, LFSLocalIndex):
+        return CallMangleInfo(func.name, (NumpyType(numpy.int32),), (NumpyType(numpy.int32),))
+
+
 def name_testfunctionspace(restriction):
     return restricted_name("lfsv", restriction)
 
diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py
index 3f081643..a4abda6f 100644
--- a/python/dune/perftool/sumfact/__init__.py
+++ b/python/dune/perftool/sumfact/__init__.py
@@ -1,6 +1,18 @@
-# Trigger some imports that are needed to have all backend implementations visible
-# to the selection mechanisms
-import dune.perftool.sumfact.amatrix
-import dune.perftool.sumfact.sumfact
+from dune.perftool.sumfact.quadrature import (quadrature_inames,
+                                              quadrature_weight,
+                                              )
 
-from dune.perftool.sumfact.sumfact import start_sumfactorization
+from dune.perftool.sumfact.sumfact import pymbolic_trialfunction
+
+from dune.perftool.pdelab import PDELabInterface
+
+
+class SumFactInterface(PDELabInterface):
+    def pymbolic_trialfunction(self, element, restriction, component):
+        return pymbolic_trialfunction(element, restriction, component)
+
+    def quadrature_inames(self):
+        return quadrature_inames()
+
+    def pymbolic_quadrature_weight(self):
+        return quadrature_weight()
\ No newline at end of file
diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py
index 968f4f16..9ff78641 100644
--- a/python/dune/perftool/sumfact/amatrix.py
+++ b/python/dune/perftool/sumfact/amatrix.py
@@ -125,6 +125,7 @@ def define_oned_quadrature_weights(name):
 
 def name_oned_quadrature_weights():
     name = "qw"
+    globalarg(name, shape=(quadrature_points_per_direction(),), dtype=NumpyType(numpy.float64))
     define_oned_quadrature_weights(name)
     return name
 
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 5a3e364d..37be6d7d 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -1,4 +1,6 @@
-from dune.perftool.pdelab.argument import pymbolic_coefficient
+from dune.perftool.pdelab.argument import (name_coefficientcontainer,
+                                           pymbolic_coefficient,
+                                           )
 from dune.perftool.generation import (backend,
                                       domain,
                                       get_counter,
@@ -23,7 +25,7 @@ from pymbolic.primitives import (Call,
                                  Subscript,
                                  Variable,
                                  )
-
+from dune.perftool.sumfact.quadrature import quadrature_inames
 from loopy import Reduction
 
 from pytools import product
@@ -42,7 +44,7 @@ def sumfact_iname(bound, _type):
     return name
 
 
-def setup_theta(element, container, restriction, component, a_matrices):
+def setup_theta(element, restriction, component, a_matrices):
     number_basis = product(mat.cols for mat in a_matrices)
     shape = (number_basis,)
     inp = get_buffer_temporary("buffer",
@@ -52,6 +54,7 @@ def setup_theta(element, container, restriction, component, a_matrices):
     # Write initial coefficients into buffer
     basisiname = sumfact_iname(number_basis, "basis")
     lfs = name_lfs(element, restriction, component)
+    container = name_coefficientcontainer(restriction)
     coeff = pymbolic_coefficient(container, lfs, basisiname)
     assignee = Subscript(Variable(inp), (Variable(basisiname),))
     return instruction(assignee=assignee,
@@ -60,9 +63,7 @@ def setup_theta(element, container, restriction, component, a_matrices):
                        )
 
 
-# TODO this code is WIP and mainly used for experiments.
-@backend(interface="eval_coefficient", name="sumfact")
-def start_sumfactorization(element, name, container, restriction, component):
+def pymbolic_trialfunction(element, restriction, component):
     theta = name_theta()
     rows = quadrature_points_per_direction()
     cols = basis_functions_per_direction()
@@ -75,16 +76,17 @@ def start_sumfactorization(element, name, container, restriction, component):
                       num=2
                       )
 
-    insn_dep = setup_theta(element, container, restriction, component, a_matrices)
+    insn_dep = setup_theta(element, restriction, component, a_matrices)
+    var = sum_factorization_kernel(a_matrices, "buffer", 0, frozenset({insn_dep}))
 
-    sum_factorization_kernel(a_matrices, "buffer", 0, frozenset({insn_dep}))
+    return Subscript(Variable(var), tuple(Variable(i) for i in quadrature_inames()))
 
-    # Do stage 3 (for f=u => mass matrix)
-    theta_transposed = name_theta_transposed()
-    a_matrix_transposed = AMatrix(theta_transposed, cols, rows)
-    a_matrices_transposed = (a_matrix_transposed, a_matrix_transposed)
 
-    return sum_factorization_kernel(a_matrices_transposed, "buffer", 2)
+#     # Do stage 3 (for f=u => mass matrix)
+#     theta_transposed = name_theta_transposed()
+#     a_matrix_transposed = AMatrix(theta_transposed, cols, rows)
+#     a_matrices_transposed = (a_matrix_transposed, a_matrix_transposed)
+#     var = sum_factorization_kernel(a_matrices_transposed, "buffer", 2)
 
 
 def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({})):
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index d87c5dcf..7a7ff15a 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -256,7 +256,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
             return Variable(self.interface.name_unit_inner_normal())
 
     def quadrature_weight(self, o):
-        return Variable(self.interface.name_quadrature_weight())
+        return self.interface.pymbolic_quadrature_weight()
 
     def jacobian_determinant(self, o):
         return Variable(self.interface.name_jacobian_determinant())
-- 
GitLab