Skip to content
Snippets Groups Projects
Commit a82b2b6b authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Represent the residual/the jacobian as a virtual loopy global arg

parent a193946f
No related branches found
No related tags found
No related merge requests found
......@@ -13,11 +13,12 @@ pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic"))
constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n: loopy.ConstantArg(n))
@generator_factory(item_tags=("loopy", "kernel", "argument", "globalarg"))
def globalarg(name, shape=loopy.auto):
@generator_factory(item_tags=("loopy", "kernel", "argument", "globalarg"),
cache_key_generator=lambda n, **kw: n)
def globalarg(name, shape=loopy.auto, **kw):
if isinstance(shape, str):
shape = (shape,)
return loopy.GlobalArg(name, numpy.float64, shape)
return loopy.GlobalArg(name, dtype=numpy.float64, shape=shape, **kw)
@generator_factory(item_tags=("loopy", "kernel", "domain"))
......
......@@ -13,6 +13,7 @@ from dune.perftool.generation import (domain,
globalarg,
iname,
instruction,
symbol,
temporary_variable,
valuearg,
)
......@@ -22,12 +23,17 @@ from ufl.algorithms import MultiFunction
import loopy
@symbol
def argument_bound(number):
return "arg{}_n".format(number)
@iname
def argument_iname(arg):
# TODO extract the {iname}_n thing by a preamble
from dune.perftool.ufl.modified_terminals import modified_argument_number
ainame = "arg{}".format(chr(ord("i") + arg.argexpr.number()))
domain(ainame, ainame + "_n")
domain(ainame, argument_bound(arg.argexpr.number()))
return ainame
......@@ -134,6 +140,10 @@ def transform_accumulation_term(term):
from dune.perftool.pdelab.argument import name_residual
residual = name_residual()
# The residual/the jacobian should be represented through a loopy global argument
from dune.perftool.ufl.rank import ufl_rank
globalarg(residual, shape=tuple(argument_bound(i) for i in range(ufl_rank(term))))
from dune.perftool.generation import retrieve_cache_items
inames = retrieve_cache_items("iname")
......@@ -143,5 +153,6 @@ def transform_accumulation_term(term):
", ".join(accumargs),
expr_tv_name,
name_factor()
)
),
assignees=residual,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment