diff --git a/python/dune/perftool/cgen/clazz.py b/python/dune/perftool/cgen/clazz.py index f265638ec68ba122cd209909b7c2704cd6992aaf..dca212eca2aecee421bee757c900f2f1778949d8 100644 --- a/python/dune/perftool/cgen/clazz.py +++ b/python/dune/perftool/cgen/clazz.py @@ -72,7 +72,7 @@ class Class(Generable): for bc in base_classes: assert isinstance(bc, BaseClass) for mem in members: - assert isinstance(mem, ClassMember) + assert isinstance(mem, Generable) def generate(self): # define the class header diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index 72fad9024cc46d3084b8318a2ff7db4c0c91cc52..89a98fd7d9ead131f628d92edc0e47d97d521660 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -36,6 +36,7 @@ from dune.perftool.generation.loopy import (barrier, globalarg, iname, instruction, + loopy_class_member, kernel_cached, noop_instruction, silenced_warning, diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index c61fec6a21dbe47e9a4147008ae616471af1a5bc..16c9d4f54a4900164d4e3389cb78eeeadfe0a70a 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -164,3 +164,17 @@ def barrier(**kwargs): name = 'barrier_{}'.format(get_counter('barrier')) _barrier(id=name, **kwargs) return name + + +def loopy_class_member(name, classtag=None, **kwargs): + """ A class member is based on loopy! It is + * a temporary variable of the constructor kernel + * a globalarg of the requesting kernel (to make things pass) + """ + assert classtag + temporary_variable(name, kernel=classtag, **kwargs) + silenced_warning("read_no_write({})".format(name), kernel=classtag) + + kwargs.pop("decl_method", None) + # TODO I guess some filtering has to be applied here. 
+ globalarg(name, **kwargs) \ No newline at end of file diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index 57ff1d237c48c0c9ed27ea36698410381d1aa718..f3b7d3e39d2721de7c78f13581f2a2999d646a4b 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -115,11 +115,18 @@ class DuneASTBuilder(CASTBuilder): post_include("#define BARRIER asm volatile(\"\": : :\"memory\")", filetag="operatorfile") return cgen.Line("BARRIER;") + def get_temporary_decls(self, codegen_state, schedule_index): + if self.target.declare_temporaries: + return CASTBuilder.get_temporary_decls(self, codegen_state, schedule_index) + else: + return [] + class DuneTarget(TargetBase): - def __init__(self): + def __init__(self, declare_temporaries=True): # Set fortran_abi to allow reusing CASTBuilder for the moment self.fortran_abi = False + self.declare_temporaries = declare_temporaries def split_kernel_at_global_barriers(self): return False diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 0234b4221c30190e983a2859ff857e4642273fdf..30454153b1e659531ee388cc9df49767229311e9 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -644,11 +644,29 @@ def cgen_class_from_cache(tag, members=[]): # Construct the constructor constructor_knl = extract_kernel_from_cache(tag) + from dune.perftool.loopy.target import DuneTarget + constructor_knl = constructor_knl.copy(target=DuneTarget(declare_temporaries=False)) signature = "{}({})".format(basename, ", ".join(next(iter(p.generate(with_semicolon=False))) for p in constructor_params)) constructor = LoopyKernelMethod([signature], constructor_knl, add_timings=False, initializer_list=il) + # Take any temporary declarations from the kernel and make them class members + target = DuneTarget() + from loopy.codegen import CodeGenerationState + codegen_state = 
CodeGenerationState(kernel=constructor_knl, + implemented_data_info=None, + implemented_domain=None, + implemented_predicates=frozenset(), + seen_dtypes=frozenset(), + seen_functions=frozenset(), + seen_atomic_dtypes=frozenset(), + var_subst_map={}, + allow_complex=False, + is_generating_device_code=True, + ) + decls = target.get_device_ast_builder().get_temporary_decls(codegen_state, 0) + from dune.perftool.cgen import Class - return Class(basename, base_classes=base_classes, members=[constructor] + members + pm, tparam_decls=tparams) + return Class(basename, base_classes=base_classes, members=[constructor] + members + pm + decls, tparam_decls=tparams) def generate_localoperator_kernels(formdata, data): diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py index 387fb400fd186e593fb36954c1582e0ae781d65d..e919951b775ad640281a682216e709ef0de4b621 100644 --- a/python/dune/perftool/sumfact/amatrix.py +++ b/python/dune/perftool/sumfact/amatrix.py @@ -13,6 +13,7 @@ from dune.perftool.generation import (class_member, include_file, initializer_list, instruction, + loopy_class_member, preamble, silenced_warning, temporary_variable, @@ -93,30 +94,30 @@ def basis_functions_per_direction(): return polynomial_degree() + 1 -@class_member(classtag="operator") def define_oned_quadrature_weights(name): - range_field = lop_template_range_field() - number_qp = quadrature_points_per_direction() - return "{} {}[{}];".format(range_field, name, number_qp) + loopy_class_member(name, + dtype=numpy.float64, + classtag="operator", + shape=(quadrature_points_per_direction(),), + ) def name_oned_quadrature_weights(): name = "qw" - globalarg(name, shape=(quadrature_points_per_direction(),), dtype=NumpyType(numpy.float64)) define_oned_quadrature_weights(name) return name -@class_member(classtag="operator") def define_oned_quadrature_points(name): - range_field = lop_template_range_field() - number_qp = quadrature_points_per_direction() - return "{} 
{}[{}];".format(range_field, name, number_qp) + loopy_class_member(name, + dtype=numpy.float64, + classtag="operator", + shape=(quadrature_points_per_direction(),), + ) def name_oned_quadrature_points(): name = "qp" - globalarg(name, shape=(quadrature_points_per_direction(),), dtype=NumpyType(numpy.float64)) define_oned_quadrature_points(name) return name