Skip to content
Snippets Groups Projects
Commit 923f2bbb authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Have unmanaged GlobalArgs - needed for quadrature points

Conflicts:
	python/dune/perftool/pdelab/quadrature.py
parent 15caafb4
No related branches found
No related tags found
No related merge requests found
......@@ -17,13 +17,17 @@ function_mangler = generator_factory(item_tags=("mangler",))
silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True)
class DuneGlobalArg(loopy.GlobalArg):
allowed_extra_kwargs = loopy.GlobalArg.allowed_extra_kwargs + ["managed"]
@generator_factory(item_tags=("argument", "globalarg"),
cache_key_generator=lambda n, **kw: n)
def globalarg(name, shape=loopy.auto, **kw):
def globalarg(name, shape=loopy.auto, managed=True, **kw):
if isinstance(shape, str):
shape = (shape,)
dtype = kw.pop("dtype", numpy.float64)
return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw)
return DuneGlobalArg(name, dtype=dtype, shape=shape, managed=managed, **kw)
@generator_factory(item_tags=("argument", "constantarg"),
......
from dune.perftool.generation import post_include
from dune.perftool.generation.loopy import DuneGlobalArg
from dune.perftool.loopy.temporary import DuneTemporaryVariable
from dune.perftool.loopy.vcl import VCLTypeRegistry
from dune.perftool.generation import (include_file,
......@@ -29,7 +30,7 @@ _registry = {'float32': 'float',
class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
def map_subscript(self, expr, type_context):
arr = self.find_array(expr)
if isinstance(arr, DuneTemporaryVariable) and not arr.managed:
if isinstance(arr, (DuneTemporaryVariable, DuneGlobalArg)) and not arr.managed:
# If there is but one index, we do not need to handle this
if isinstance(expr.index, (prim.Variable, int)):
return expr
......
......@@ -118,6 +118,8 @@ def name_quadrature_points():
"""Name of vector storing quadrature points as class member"""
dim = _local_dim()
name = "qp_order" + str(dim)
shape = (name_quadrature_bound(), dim)
globalarg(name, shape=shape, dtype=numpy.float64, managed=False)
define_quadrature_points(name)
fill_quadrature_points_cache(name)
return name
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment