From 8ae78f0530467cbed21f32f3964adf811abfdaa0 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Wed, 13 Apr 2016 10:06:36 +0200 Subject: [PATCH] Reweite the temporary_variable generator --- python/dune/perftool/generation/loopy.py | 7 ++++++- python/dune/perftool/pdelab/basis.py | 15 ++++++++++----- python/dune/perftool/pdelab/quadrature.py | 2 +- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index 3faf5125..e079c137 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -7,7 +7,6 @@ import loopy import numpy iname = generator_factory(item_tags=("loopy", "kernel", "iname")) -temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True) valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True) pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic")) constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n: loopy.ConstantArg(n)) @@ -28,6 +27,12 @@ def domain(iname, shape): return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape) +@generator_factory(item_tags=("loopy", "kernel", "temporary"), cache_key_generator=lambda n, **kw: n) +def temporary_variable(name, **kwargs): + if 'dtype' not in kwargs: + kwargs['dtype'] = numpy.float64 + return loopy.TemporaryVariable(name, **kwargs) + # Now define generators for instructions. To ease dependency handling of instructions # these generators are a bit more involved... We apply the following procedure: # There is one generator that returns the unique id and forwards to a generator that diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py index f721b6d7..e2b46143 100644 --- a/python/dune/perftool/pdelab/basis.py +++ b/python/dune/perftool/pdelab/basis.py @@ -37,7 +37,8 @@ def name_localfunctionspace(expr): @cached def evaluate_basis(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=('arg0_n',)) lfs = name_localfunctionspace(element) qp = name_quadraturepoint() instruction(inames=(quadrature_iname(), @@ -58,7 +59,8 @@ def name_basis(element): @cached def evaluate_reference_gradient(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=('arg0_n', 2)) lfs = name_localfunctionspace(element) qp = name_quadraturepoint() instruction(inames=(quadrature_iname(), @@ -79,7 +81,8 @@ def name_reference_gradient(element): @cached def evaluate_basis_gradient(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=('arg0_n', 2)) jac = name_jacobian_inverse_transposed() index = lfs_iname(element) reference_gradients = name_reference_gradient(element) @@ -103,7 +106,8 @@ def name_basis_gradient(element): @cached def evaluate_trialfunction(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=()) lfs = name_localfunctionspace(element) index = lfs_iname(element) basis = name_basis() @@ -122,7 +126,8 @@ def evaluate_trialfunction(element, name): @cached def evaluate_trialfunction_gradient(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=(2,)) lfs = name_localfunctionspace(element) index = lfs_iname(element) basis = name_basis_gradient(element) diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index c6190a2e..b25cdc2e 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -30,6 +30,6 @@ def name_quadraturepoint(): @symbol def name_factor(): - temporary_variable("fac") + temporary_variable("fac", shape=()) define_quadrature_factor("fac") return "fac" -- GitLab