diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index bdca6a9f2b99af29aee25fb181a0e55f45adb292..bec300327de47280c26e16914d775363aab4603b 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -17,13 +17,17 @@ function_mangler = generator_factory(item_tags=("mangler",)) silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True) +class DuneGlobalArg(loopy.GlobalArg): + allowed_extra_kwargs = loopy.GlobalArg.allowed_extra_kwargs + ["managed"] + + @generator_factory(item_tags=("argument", "globalarg"), cache_key_generator=lambda n, **kw: n) -def globalarg(name, shape=loopy.auto, **kw): +def globalarg(name, shape=loopy.auto, managed=True, **kw): if isinstance(shape, str): shape = (shape,) dtype = kw.pop("dtype", numpy.float64) - return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw) + return DuneGlobalArg(name, dtype=dtype, shape=shape, managed=managed, **kw) @generator_factory(item_tags=("argument", "constantarg"), diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index a28a10adc9a107ea753b20b130dada055bb54bc8..a1b5a07b400b4ac4d61e4e4d010d0f5b19dae785 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -1,5 +1,6 @@ from dune.perftool.generation import post_include +from dune.perftool.generation.loopy import DuneGlobalArg from dune.perftool.loopy.temporary import DuneTemporaryVariable from dune.perftool.loopy.vcl import VCLTypeRegistry from dune.perftool.generation import (include_file, @@ -29,7 +30,7 @@ _registry = {'float32': 'float', class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): def map_subscript(self, expr, type_context): arr = self.find_array(expr) - if isinstance(arr, DuneTemporaryVariable) and not arr.managed: + if isinstance(arr, (DuneTemporaryVariable, DuneGlobalArg)) and not arr.managed: # If there is but one index, we do not need to handle this if isinstance(expr.index, (prim.Variable, int)): return expr diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index c0ddc6a99dac2c0f2f179e366e7d76846f28c7c4..72a42aa73ca20e0543f390ccbd193b5fc2895943 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -118,6 +118,8 @@ def name_quadrature_points(): """Name of vector storing quadrature points as class member""" dim = _local_dim() name = "qp_order" + str(dim) + shape = (name_quadrature_bound(), dim) + globalarg(name, shape=shape, dtype=numpy.float64, managed=False) define_quadrature_points(name) fill_quadrature_points_cache(name) return name