From 8ae78f0530467cbed21f32f3964adf811abfdaa0 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 13 Apr 2016 10:06:36 +0200
Subject: [PATCH] Reweite the temporary_variable generator

---
 python/dune/perftool/generation/loopy.py  |  7 ++++++-
 python/dune/perftool/pdelab/basis.py      | 15 ++++++++++-----
 python/dune/perftool/pdelab/quadrature.py |  2 +-
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index 3faf5125..e079c137 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -7,7 +7,6 @@ import loopy
 import numpy
 
 iname = generator_factory(item_tags=("loopy", "kernel", "iname"))
-temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True)
 valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True)
 pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic"))
 constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n: loopy.ConstantArg(n))
@@ -28,6 +27,12 @@ def domain(iname, shape):
     return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)
 
 
+@generator_factory(item_tags=("loopy", "kernel", "temporary"), cache_key_generator=lambda n, **kw: n)
+def temporary_variable(name, **kwargs):
+    if 'dtype' not in kwargs:
+        kwargs['dtype'] = numpy.float64
+    return loopy.TemporaryVariable(name, **kwargs)
+
 # Now define generators for instructions. To ease dependency handling of instructions
 # these generators are a bit more involved... We apply the following procedure:
 # There is one generator that returns the unique id and forwards to a generator that
diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index f721b6d7..e2b46143 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -37,7 +37,8 @@ def name_localfunctionspace(expr):
 
 @cached
 def evaluate_basis(element, name):
-    temporary_variable(name)
+    # TODO this is of course not yet correct
+    temporary_variable(name, shape=('arg0_n',))
     lfs = name_localfunctionspace(element)
     qp = name_quadraturepoint()
     instruction(inames=(quadrature_iname(),
@@ -58,7 +59,8 @@ def name_basis(element):
 
 @cached
 def evaluate_reference_gradient(element, name):
-    temporary_variable(name)
+    # TODO this is of course not yet correct
+    temporary_variable(name, shape=('arg0_n', 2))
     lfs = name_localfunctionspace(element)
     qp = name_quadraturepoint()
     instruction(inames=(quadrature_iname(),
@@ -79,7 +81,8 @@ def name_reference_gradient(element):
 
 @cached
 def evaluate_basis_gradient(element, name):
-    temporary_variable(name)
+    # TODO this is of course not yet correct
+    temporary_variable(name, shape=('arg0_n', 2))
     jac = name_jacobian_inverse_transposed()
     index = lfs_iname(element)
     reference_gradients = name_reference_gradient(element)
@@ -103,7 +106,8 @@ def name_basis_gradient(element):
 
 @cached
 def evaluate_trialfunction(element, name):
-    temporary_variable(name)
+    # TODO this is of course not yet correct
+    temporary_variable(name, shape=())
     lfs = name_localfunctionspace(element)
     index = lfs_iname(element)
     basis = name_basis()
@@ -122,7 +126,8 @@ def evaluate_trialfunction(element, name):
 
 @cached
 def evaluate_trialfunction_gradient(element, name):
-    temporary_variable(name)
+    # TODO this is of course not yet correct
+    temporary_variable(name, shape=(2,))
     lfs = name_localfunctionspace(element)
     index = lfs_iname(element)
     basis = name_basis_gradient(element)
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index c6190a2e..b25cdc2e 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -30,6 +30,6 @@ def name_quadraturepoint():
 
 @symbol
 def name_factor():
-    temporary_variable("fac")
+    temporary_variable("fac", shape=())
     define_quadrature_factor("fac")
     return "fac"
-- 
GitLab