Skip to content
Snippets Groups Projects
Commit 8ae78f05 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Reweite the temporary_variable generator

parent 887cb5de
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,6 @@ import loopy ...@@ -7,7 +7,6 @@ import loopy
import numpy import numpy
iname = generator_factory(item_tags=("loopy", "kernel", "iname")) iname = generator_factory(item_tags=("loopy", "kernel", "iname"))
temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True)
valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True) valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True)
pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic")) pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic"))
constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n: loopy.ConstantArg(n)) constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n: loopy.ConstantArg(n))
...@@ -28,6 +27,12 @@ def domain(iname, shape): ...@@ -28,6 +27,12 @@ def domain(iname, shape):
return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape) return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)
@generator_factory(item_tags=("loopy", "kernel", "temporary"), cache_key_generator=lambda n, **kw: n)
def temporary_variable(name, **kwargs):
if 'dtype' not in kwargs:
kwargs['dtype'] = numpy.float64
return loopy.TemporaryVariable(name, **kwargs)
# Now define generators for instructions. To ease dependency handling of instructions # Now define generators for instructions. To ease dependency handling of instructions
# these generators are a bit more involved... We apply the following procedure: # these generators are a bit more involved... We apply the following procedure:
# There is one generator that returns the unique id and forwards to a generator that # There is one generator that returns the unique id and forwards to a generator that
......
...@@ -37,7 +37,8 @@ def name_localfunctionspace(expr): ...@@ -37,7 +37,8 @@ def name_localfunctionspace(expr):
@cached @cached
def evaluate_basis(element, name): def evaluate_basis(element, name):
temporary_variable(name) # TODO this is of course not yet correct
temporary_variable(name, shape=('arg0_n',))
lfs = name_localfunctionspace(element) lfs = name_localfunctionspace(element)
qp = name_quadraturepoint() qp = name_quadraturepoint()
instruction(inames=(quadrature_iname(), instruction(inames=(quadrature_iname(),
...@@ -58,7 +59,8 @@ def name_basis(element): ...@@ -58,7 +59,8 @@ def name_basis(element):
@cached @cached
def evaluate_reference_gradient(element, name): def evaluate_reference_gradient(element, name):
temporary_variable(name) # TODO this is of course not yet correct
temporary_variable(name, shape=('arg0_n', 2))
lfs = name_localfunctionspace(element) lfs = name_localfunctionspace(element)
qp = name_quadraturepoint() qp = name_quadraturepoint()
instruction(inames=(quadrature_iname(), instruction(inames=(quadrature_iname(),
...@@ -79,7 +81,8 @@ def name_reference_gradient(element): ...@@ -79,7 +81,8 @@ def name_reference_gradient(element):
@cached @cached
def evaluate_basis_gradient(element, name): def evaluate_basis_gradient(element, name):
temporary_variable(name) # TODO this is of course not yet correct
temporary_variable(name, shape=('arg0_n', 2))
jac = name_jacobian_inverse_transposed() jac = name_jacobian_inverse_transposed()
index = lfs_iname(element) index = lfs_iname(element)
reference_gradients = name_reference_gradient(element) reference_gradients = name_reference_gradient(element)
...@@ -103,7 +106,8 @@ def name_basis_gradient(element): ...@@ -103,7 +106,8 @@ def name_basis_gradient(element):
@cached @cached
def evaluate_trialfunction(element, name): def evaluate_trialfunction(element, name):
temporary_variable(name) # TODO this is of course not yet correct
temporary_variable(name, shape=())
lfs = name_localfunctionspace(element) lfs = name_localfunctionspace(element)
index = lfs_iname(element) index = lfs_iname(element)
basis = name_basis() basis = name_basis()
...@@ -122,7 +126,8 @@ def evaluate_trialfunction(element, name): ...@@ -122,7 +126,8 @@ def evaluate_trialfunction(element, name):
@cached @cached
def evaluate_trialfunction_gradient(element, name): def evaluate_trialfunction_gradient(element, name):
temporary_variable(name) # TODO this is of course not yet correct
temporary_variable(name, shape=(2,))
lfs = name_localfunctionspace(element) lfs = name_localfunctionspace(element)
index = lfs_iname(element) index = lfs_iname(element)
basis = name_basis_gradient(element) basis = name_basis_gradient(element)
......
...@@ -30,6 +30,6 @@ def name_quadraturepoint(): ...@@ -30,6 +30,6 @@ def name_quadraturepoint():
@symbol @symbol
def name_factor(): def name_factor():
temporary_variable("fac") temporary_variable("fac", shape=())
define_quadrature_factor("fac") define_quadrature_factor("fac")
return "fac" return "fac"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment