From 70cc6a680c00d844b8a3fb1ef0d369f87ed0b16e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Mon, 24 Oct 2016 11:13:33 +0200
Subject: [PATCH] Copy coefficients to input buffer

---
 python/dune/perftool/pdelab/argument.py |  2 +-
 python/dune/perftool/sumfact/sumfact.py | 20 ++++++++++++++++----
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py
index d061f079..06d4a75d 100644
--- a/python/dune/perftool/pdelab/argument.py
+++ b/python/dune/perftool/pdelab/argument.py
@@ -103,7 +103,7 @@ def name_trialfunction_gradient(element, restriction, component):
     # anything.
     if get_option("sumfact") and restriction == Restriction.NONE:
         from dune.perftool.sumfact import start_sumfactorization
-        start_sumfactorization()
+        start_sumfactorization(element, container, restriction, component)
 
     evaluate_coefficient_gradient(element, name, container, restriction, component)
     return name
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 0d839d61..78e04406 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -1,3 +1,4 @@
+from dune.perftool.pdelab.argument import pymbolic_coefficient
 from dune.perftool.generation import (domain,
                                       get_counter,
                                       iname,
@@ -8,6 +9,7 @@ from dune.perftool.loopy.buffer import (get_buffer_temporary,
                                         initialize_buffer,
                                         switch_base_storage,
                                         )
+from dune.perftool.pdelab.spaces import name_lfs
 from pymbolic.primitives import (Product,
                                  Subscript,
                                  Variable,
@@ -32,7 +34,7 @@ def sumfact_iname(bound, _type):
 
 
 # TODO this code is WIP and mainly used for experiments.
-def start_sumfactorization():
+def start_sumfactorization(element, container, restriction, component):
     from dune.perftool.sumfact.amatrix import (AMatrix,
                                                quadrature_points_per_direction,
                                                basis_functions_per_direction,
@@ -50,12 +52,22 @@ def start_sumfactorization():
                       num=2
                       )
 
-    shape = (a_matrices[0].n, product(mat.n for mat in a_matrices[1:]))
+
+    number_basis = product(mat.n for mat in a_matrices)
+    shape = (n,)
     inp = get_buffer_temporary("buffer",
-                               shape=shape,
-                               dim_tags="f,f")
+                               shape=shape)
     silenced_warning('read_no_write({})'.format(inp))
 
+    # Write initial coefficients into buffer
+    basisiname = sumfact_iname(number_basis, "basis")
+    lfs = name_lfs(element, restriction, component)
+    coeff = pymbolic_coefficient(container, lfs, basisiname)
+    assignee = Subscript(Variable(inp), (Variable(basisiname),))
+    instruction(assignee = assignee,
+                expression = coeff,
+                )
+
     return sum_factorization_kernel(a_matrices, inp, "buffer")
 
 
-- 
GitLab