Skip to content
Snippets Groups Projects
Commit f5b7508f authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Simplify AccumulationArgument trick

parent 270208ee
No related branches found
No related tags found
No related merge requests found
......@@ -15,20 +15,13 @@ function_mangler = generator_factory(item_tags=("mangler",))
silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True)
class AccumulationGlobalArg(loopy.GlobalArg):
allowed_extra_kwargs = loopy.GlobalArg.allowed_extra_kwargs + ['transform']
@generator_factory(item_tags=("argument", "globalarg"),
cache_key_generator=lambda n, **kw: n)
def globalarg(name, shape=loopy.auto, **kw):
def globalarg(name, shape=loopy.auto, argtype=loopy.GlobalArg, **kw):
if isinstance(shape, str):
shape = (shape,)
dtype = kw.pop("dtype", numpy.float64)
if 'transform' in kw:
return AccumulationGlobalArg(name, dtype=dtype, shape=shape, **kw)
else:
return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw)
return argtype(name, dtype=dtype, shape=shape, **kw)
@generator_factory(item_tags=("argument", "constantarg"),
......
from dune.perftool.loopy.temporary import DuneTemporaryVariable
from dune.perftool.generation.loopy import AccumulationGlobalArg
from dune.perftool.sumfact.sumfact import AccumulationArg
from dune.perftool.pdelab.spaces import LFSLocalIndex
from loopy.target import (TargetBase,
ASTBuilderBase,
......@@ -32,11 +33,10 @@ class MyMapper(ExpressionToCExpressionMapper):
for i in expr.index:
ret = Subscript(ret, i)
return ret
elif isinstance(arr, AccumulationGlobalArg):
assert isinstance(arr.transform, FunctionIdentifier)
elif isinstance(arr, AccumulationArg):
pseudo_subscript = Subscript(expr.aggregate, expr.index)
flattened = ExpressionToCExpressionMapper.map_subscript(self, pseudo_subscript, type_context)
transformed = ExpressionToCExpressionMapper.map_call(self, Call(arr.transform, (flattened.index,)), type_context)
transformed = ExpressionToCExpressionMapper.map_call(self, Call(LFSLocalIndex(arr.lfs), (flattened.index,)), type_context)
return Subscript(Variable(flattened.aggregate.name + '.base()'), transformed)
else:
return ExpressionToCExpressionMapper.map_subscript(self, expr, type_context)
......
......@@ -30,12 +30,16 @@ from pymbolic.primitives import (Call,
Variable,
)
from dune.perftool.sumfact.quadrature import quadrature_inames
from loopy import Reduction
from loopy import Reduction, GlobalArg
from loopy.symbolic import FunctionIdentifier
from pytools import product
class AccumulationArg(GlobalArg):
allowed_extra_kwargs = GlobalArg.allowed_extra_kwargs + ["lfs"]
@iname
def _sumfact_iname(bound, _type, count):
name = "sf_{}_{}".format(_type, str(count))
......@@ -122,8 +126,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
result = sum_factorization_kernel(a_matrices, "reffub", 2)
from dune.perftool.pdelab.spaces import LFSLocalIndex
# Now write all this into the correct residual
lfs = name_lfs(accterm.argument.argexpr.ufl_element(),
accterm.argument.restriction,
......@@ -132,7 +134,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices)
globalarg("r",
shape=(basis_functions_per_direction(),) * dim,
transform=LFSLocalIndex(lfs),
argtype=AccumulationArg,
lfs=lfs,
)
instruction(expression=Subscript(Variable(result), tuple(Variable(i) for i in inames)),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment