From d388acca82b8ddf69a3687c7d9e6bd18dab30417 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 3 Nov 2016 10:42:06 +0100
Subject: [PATCH] Introduce a module for basis related sumfac stuff

---
 python/dune/perftool/generation/backend.py |  2 +-
 python/dune/perftool/sumfact/__init__.py   |  2 +-
 python/dune/perftool/sumfact/amatrix.py    |  6 ++--
 python/dune/perftool/sumfact/basis.py      | 38 ++++++++++++++++++++++
 python/dune/perftool/sumfact/sumfact.py    | 19 -----------
 5 files changed, 43 insertions(+), 24 deletions(-)
 create mode 100644 python/dune/perftool/sumfact/basis.py

diff --git a/python/dune/perftool/generation/backend.py b/python/dune/perftool/generation/backend.py
index 0e7b7ec2..6bd6544c 100644
--- a/python/dune/perftool/generation/backend.py
+++ b/python/dune/perftool/generation/backend.py
@@ -39,6 +39,6 @@ def get_backend(interface=None, selector=option_switch("sumfact"), **kwargs):
     assert interface and selector
 
     select = selector(**kwargs)
-    assert select in _backend_mapping[interface]
+    assert select in _backend_mapping[interface], "Implementation '{}' for interface '{}' missing!".format(select, interface)
 
     return _backend_mapping[interface][select]
diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py
index e263c742..82516df9 100644
--- a/python/dune/perftool/sumfact/__init__.py
+++ b/python/dune/perftool/sumfact/__init__.py
@@ -2,7 +2,7 @@ from dune.perftool.sumfact.quadrature import (quadrature_inames,
                                               quadrature_weight,
                                               )
 
-from dune.perftool.sumfact.sumfact import pymbolic_trialfunction
+from dune.perftool.sumfact.basis import pymbolic_trialfunction
 
 from dune.perftool.pdelab import PDELabInterface
 
diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py
index 8dfc180c..b6ab9d3d 100644
--- a/python/dune/perftool/sumfact/amatrix.py
+++ b/python/dune/perftool/sumfact/amatrix.py
@@ -13,20 +13,20 @@ from dune.perftool.generation import (class_member,
                                       iname,
                                       include_file,
                                       initializer_list,
+                                      silenced_warning,
                                       temporary_variable,
                                       valuearg
                                       )
-
+from dune.perftool.loopy.buffer import get_buffer_temporary
 from dune.perftool.pdelab.localoperator import (name_domain_field,
                                                 lop_template_range_field,
                                                 )
 from dune.perftool.pdelab.quadrature import quadrature_order
-
 from loopy import CallMangleInfo
 from loopy.symbolic import FunctionIdentifier
 from loopy.types import NumpyType
 
-from pytools import Record
+from pytools import Record, product
 
 import numpy
 
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
new file mode 100644
index 00000000..008e8853
--- /dev/null
+++ b/python/dune/perftool/sumfact/basis.py
@@ -0,0 +1,38 @@
+""" Generator functions to evaluate trial and test functions
+
+NB: Basis evaluation is only needed for the trial function argument in jacobians, as the
+multiplication withthe test function is part of the sum factorization kernel.
+"""
+from dune.perftool.sumfact.amatrix import (AMatrix,
+                                           basis_functions_per_direction,
+                                           name_theta,
+                                           quadrature_points_per_direction,
+                                           )
+from dune.perftool.sumfact.sumfact import (setup_theta,
+                                           sum_factorization_kernel,
+                                           )
+from dune.perftool.sumfact.quadrature import quadrature_inames
+from dune.perftool.loopy.buffer import initialize_buffer
+
+from pytools import product
+
+from pymbolic.primitives import Subscript, Variable
+
+
+def pymbolic_trialfunction(element, restriction, component):
+    theta = name_theta()
+    rows = quadrature_points_per_direction()
+    cols = basis_functions_per_direction()
+    a_matrix = AMatrix(theta, rows, cols)
+    a_matrices = (a_matrix, a_matrix)
+
+    # Do stage 1
+    initialize_buffer("buffer",
+                      base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
+                      num=2
+                      )
+
+    insn_dep = setup_theta(element, restriction, component, a_matrices)
+    var = sum_factorization_kernel(a_matrices, "buffer", 0, frozenset({insn_dep}))
+
+    return Subscript(Variable(var), tuple(Variable(i) for i in quadrature_inames()))
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index e1dff60d..1186d415 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -72,25 +72,6 @@ def setup_theta(element, restriction, component, a_matrices):
                        )
 
 
-def pymbolic_trialfunction(element, restriction, component):
-    theta = name_theta()
-    rows = quadrature_points_per_direction()
-    cols = basis_functions_per_direction()
-    a_matrix = AMatrix(theta, rows, cols)
-    a_matrices = (a_matrix, a_matrix)
-
-    # Do stage 1
-    initialize_buffer("buffer",
-                      base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
-                      num=2
-                      )
-
-    insn_dep = setup_theta(element, restriction, component, a_matrices)
-    var = sum_factorization_kernel(a_matrices, "buffer", 0, frozenset({insn_dep}))
-
-    return Subscript(Variable(var), tuple(Variable(i) for i in quadrature_inames()))
-
-
 @backend(interface="accum_insn", name="sumfact")
 def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     pymbolic_expr = visitor(accterm.term)
-- 
GitLab