From 923f2bbb3033904ff350c67d58bcab8342ecec92 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 8 Nov 2016 20:42:56 +0100 Subject: [PATCH] Have unmanaged GlobalArgs - needed for quadrature points Conflicts: python/dune/perftool/pdelab/quadrature.py --- python/dune/perftool/generation/loopy.py | 8 ++++++-- python/dune/perftool/loopy/target.py | 3 ++- python/dune/perftool/pdelab/quadrature.py | 2 ++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index bdca6a9f..bec30032 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -17,13 +17,17 @@ function_mangler = generator_factory(item_tags=("mangler",)) silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True) +class DuneGlobalArg(loopy.GlobalArg): + allowed_extra_kwargs = loopy.GlobalArg.allowed_extra_kwargs + ["managed"] + + @generator_factory(item_tags=("argument", "globalarg"), cache_key_generator=lambda n, **kw: n) -def globalarg(name, shape=loopy.auto, **kw): +def globalarg(name, shape=loopy.auto, managed=True, **kw): if isinstance(shape, str): shape = (shape,) dtype = kw.pop("dtype", numpy.float64) - return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw) + return DuneGlobalArg(name, dtype=dtype, shape=shape, managed=managed, **kw) @generator_factory(item_tags=("argument", "constantarg"), diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index a28a10ad..a1b5a07b 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -1,5 +1,6 @@ from dune.perftool.generation import post_include +from dune.perftool.generation.loopy import DuneGlobalArg from dune.perftool.loopy.temporary import DuneTemporaryVariable from dune.perftool.loopy.vcl import VCLTypeRegistry from dune.perftool.generation import (include_file, @@ -29,7 +30,7 @@ _registry = {'float32': 'float', class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): def map_subscript(self, expr, type_context): arr = self.find_array(expr) - if isinstance(arr, DuneTemporaryVariable) and not arr.managed: + if isinstance(arr, (DuneTemporaryVariable, DuneGlobalArg)) and not arr.managed: # If there is but one index, we do not need to handle this if isinstance(expr.index, (prim.Variable, int)): return expr diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index c0ddc6a9..72a42aa7 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -118,6 +118,8 @@ def name_quadrature_points(): """Name of vector storing quadrature points as class member""" dim = _local_dim() name = "qp_order" + str(dim) + shape = (name_quadrature_bound(), dim) + globalarg(name, shape=shape, dtype=numpy.float64, managed=False) define_quadrature_points(name) fill_quadrature_points_cache(name) return name -- GitLab