diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index 3faf5125572132d2b3d5f7436455a23f734b4f40..e079c137c5434bba5ec068ea590da481106c1f84 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -7,7 +7,6 @@ import loopy import numpy iname = generator_factory(item_tags=("loopy", "kernel", "iname")) -temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True) valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True) pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic")) constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n: loopy.ConstantArg(n)) @@ -28,6 +27,12 @@ def domain(iname, shape): return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape) +@generator_factory(item_tags=("loopy", "kernel", "temporary"), cache_key_generator=lambda n, **kw: n) +def temporary_variable(name, **kwargs): + if 'dtype' not in kwargs: + kwargs['dtype'] = numpy.float64 + return loopy.TemporaryVariable(name, **kwargs) + # Now define generators for instructions. To ease dependency handling of instructions # these generators are a bit more involved... We apply the following procedure: # There is one generator that returns the unique id and forwards to a generator that diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py index f721b6d72c41c26a7b1875304d294f23627baec3..e2b46143f46ccca28c3eb6929c9958bbe95d539c 100644 --- a/python/dune/perftool/pdelab/basis.py +++ b/python/dune/perftool/pdelab/basis.py @@ -37,7 +37,8 @@ def name_localfunctionspace(expr): @cached def evaluate_basis(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=('arg0_n',)) lfs = name_localfunctionspace(element) qp = name_quadraturepoint() instruction(inames=(quadrature_iname(), @@ -58,7 +59,8 @@ def name_basis(element): @cached def evaluate_reference_gradient(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=('arg0_n', 2)) lfs = name_localfunctionspace(element) qp = name_quadraturepoint() instruction(inames=(quadrature_iname(), @@ -79,7 +81,8 @@ def name_reference_gradient(element): @cached def evaluate_basis_gradient(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=('arg0_n', 2)) jac = name_jacobian_inverse_transposed() index = lfs_iname(element) reference_gradients = name_reference_gradient(element) @@ -103,7 +106,8 @@ def name_basis_gradient(element): @cached def evaluate_trialfunction(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=()) lfs = name_localfunctionspace(element) index = lfs_iname(element) basis = name_basis() @@ -122,7 +126,8 @@ def evaluate_trialfunction(element, name): @cached def evaluate_trialfunction_gradient(element, name): - temporary_variable(name) + # TODO this is of course not yet correct + temporary_variable(name, shape=(2,)) lfs = name_localfunctionspace(element) index = lfs_iname(element) basis = name_basis_gradient(element) diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index c6190a2eee1fdcb9b7495c113705e0e7f5938117..b25cdc2e6add781dbb2b29933b4dc36886d92ae0 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -30,6 +30,6 @@ def name_quadraturepoint(): @symbol def name_factor(): - temporary_variable("fac") + temporary_variable("fac", shape=()) define_quadrature_factor("fac") return "fac"