diff --git a/python/dune/perftool/generation/cpp.py b/python/dune/perftool/generation/cpp.py index ea8727be7c700898724cc0e7ad903c0a1a1fdd63..797c29a5db3f84df9db8e95a8dfc3964fe20fe39 100644 --- a/python/dune/perftool/generation/cpp.py +++ b/python/dune/perftool/generation/cpp.py @@ -4,86 +4,45 @@ are commonly needed for code generation """ from dune.perftool.generation import generator_factory -from dune.perftool.cgen.clazz import AccessModifier +from dune.perftool.cgen.clazz import AccessModifier, BaseClass, ClassMember -preamble = generator_factory(item_tags=("preamble",), counted=True) +import cgen -def pre_include(pre, filetag=None, pre_include=True): - assert filetag - gen = generator_factory(item_tags=("file", filetag, "pre_include"), no_deco=True) - return gen(pre) +preamble = generator_factory(item_tags=("preamble",), counted=True, context_tags="kernel") +pre_include = generator_factory(item_tags=("file", "pre_include"), context_tags=("filetag",), no_deco=True) +post_include = generator_factory(item_tags=("file", "post_include"), context_tags=("filetag",), no_deco=True) +class_member = generator_factory(item_tags=("clazz", "member"), context_tags=("classtag",), on_store=lambda m: ClassMember(m), counted=True) +template_parameter = generator_factory(item_tags=("clazz", "template_param"), context_tags=("classtag",), counted=True) +class_basename = generator_factory(item_tags=("clazz", "basename"), context_tags=("classtag",)) +constructor_block = generator_factory(item_tags=("clazz", "constructor_block"), context_tags=("classtag",), counted=True) -def post_include(post, filetag=None, pre_include=True): - assert filetag - gen = generator_factory(item_tags=("file", filetag, "post_include"), no_deco=True) - return gen(post) - - -def include_file(include, filetag=None, system=False): - assert filetag - from cgen import Include - gen = generator_factory(on_store=lambda i: Include(i, system=system), item_tags=("file", filetag, "include"), no_deco=True) - return gen(include) +@generator_factory(item_tags=("file", "include"), context_tags=("filetag",)) +def include_file(include, system=False): + return cgen.Include(include, system=system) +@generator_factory(item_tags=("clazz", "initializer"), counted=True, context_tags=("classtag",), cache_key_generator=lambda o, p: o) def initializer_list(obj, params, classtag=None): - assert classtag - gen = generator_factory(item_tags=("clazz", classtag, "initializer"), counted=True, no_deco=True, cache_key_generator=lambda *a: a[0]) - return gen("{}({})".format(obj, ", ".join(params))) - + return "{}({})".format(obj, ", ".join(params)) -def base_class(baseclass, classtag=None, access=AccessModifier.PUBLIC, construction=[]): - assert classtag - - from dune.perftool.cgen.clazz import BaseClass - gen = generator_factory(item_tags=("clazz", "baseclass", classtag), on_store=lambda n: BaseClass(n, inheritance=access), counted=True, no_deco=True) +@generator_factory(item_tags=("clazz", "baseclass"), context_tags=("classtag",), counted=True) +def base_class(baseclass, access=AccessModifier.PUBLIC, construction=[], **kwargs): if construction: - initializer_list(baseclass, construction, classtag=classtag) - - return gen(baseclass) - - -def class_member(classtag=None, access=AccessModifier.PRIVATE): - assert classtag - from cgen import Value - from dune.perftool.cgen.clazz import ClassMember - - return generator_factory(item_tags=("clazz", classtag, "member"), on_store=lambda m: ClassMember(m, access=access), counted=True) - + initializer_list(baseclass, construction, **kwargs) -def constructor_parameter(_type, name, classtag=None, constructortag="default"): - assert classtag - assert constructortag - from cgen import Value + return BaseClass(baseclass, inheritance=access) - gen = generator_factory(item_tags=("clazz", classtag, constructortag, "constructor_param"), counted=True, no_deco=True) - return gen(Value(_type, name)) - -def template_parameter(classtag=None): - assert classtag - - return generator_factory(item_tags=("clazz", classtag, "template_param"), counted=True) - - -def class_basename(classtag=None): - assert classtag - - return generator_factory(item_tags=("clazz", classtag, "basename")) - - -def constructor_block(classtag=None): - assert classtag - from dune.perftool.generation import generator_factory - return generator_factory(item_tags=("clazz", classtag, "constructor_block"), counted=True) +@generator_factory(item_tags=("clazz", "constructor_param"), context_tags=("classtag",), counted=True) +def constructor_parameter(_type, name): + return cgen.Value(_type, name) +@generator_factory(item_tags=("dump_timers",)) def dump_accumulate_timer(name): - gen = generator_factory(item_tags=("dump_timers"), no_deco=True) - from dune.perftool.pdelab.localoperator import (name_time_dumper_os, name_time_dumper_reset, name_time_dumper_t, @@ -95,4 +54,4 @@ def dump_accumulate_timer(name): counter = name_time_dumper_counter() code = "DUMP_AND_ACCUMULATE_TIMER({},{},{},{},{});".format(name, os, reset, t, counter) - return gen(code) + return code diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index ef22840214c757dba272b271d867830dda56cabe..b0e7264b0c65f8b8351f087723a8445502e78c10 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -152,7 +152,9 @@ def transform(trafo, *args): return (trafo, args) -@generator_factory(item_tags=("instruction", "barrier"), cache_key_generator=lambda **kw: kw['id']) +@generator_factory(item_tags=("instruction", "barrier"), + context_tags="kernel", + cache_key_generator=lambda **kw: kw['id']) def _barrier(**kwargs): return lp.BarrierInstruction(**kwargs) diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py index b0d04c92d24f61ccdf7e542afb2bc1077fe93c29..2364b2bcfb1290c0a842342550b5c4402b9987a2 100644 --- a/python/dune/perftool/pdelab/basis.py +++ b/python/dune/perftool/pdelab/basis.py @@ -39,7 +39,7 @@ def type_localbasis_cache(element): return "LocalBasisCacheWithoutReferences<{}>".format(type_gfs(element)) -@class_member("operator") +@class_member(classtag="operator") def define_localbasis_cache(element, name): include_file("dune/perftool/localbasiscache.hh", filetag="operatorfile") t = type_localbasis_cache(element) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index e084853782f0d8c44ab85d984977c18f35a0e753..1c261d94adcb8e374ff3b68cd805aff294ce5d27 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -54,22 +54,22 @@ def name_localoperator_file(formdata, data): return filename -@template_parameter("operator") +@template_parameter(classtag="operator") def lop_template_ansatz_gfs(): return "GFSU" -@template_parameter("operator") +@template_parameter(classtag="operator") def lop_template_test_gfs(): return "GFSV" -@template_parameter("operator") +@template_parameter(classtag="operator") def lop_template_range_field(): return "RF" -@class_member("operator") +@class_member(classtag="operator") def lop_domain_field(name): # TODO: Rethink for not Galerkin Method gfs = lop_template_ansatz_gfs() @@ -100,7 +100,7 @@ def name_initree_constructor_param(): return "iniParams" -@class_member("operator") +@class_member(classtag="operator") def define_initree(name): param_name = name_initree_constructor_param() include_file('dune/common/parametertree.hh', filetag="operatorfile") @@ -117,7 +117,7 @@ def ufl_measure_to_pdelab_measure(which): }.get(which) -@class_member(classtag="operator", access=AccessModifier.PUBLIC) +@class_member(classtag="operator") def _enum_pattern(which): return "enum {{ doPattern{} = true }};".format(which) @@ -138,7 +138,7 @@ def pattern_baseclass(): return _pattern_baseclass(ufl_measure_to_pdelab_measure(integral_type)) -@class_member(classtag="operator", access=AccessModifier.PUBLIC) +@class_member(classtag="operator") def _enum_alpha(which): return "enum {{ doAlpha{} = true }};".format(which) @@ -154,7 +154,7 @@ def name_initree_member(): return "_iniParams" -@class_basename("operator") +@class_basename(classtag="operator") def localoperator_basename(formdata, data): form_name = name_form(formdata, data) return "LocalOperator" + form_name.capitalize() diff --git a/python/dune/perftool/pdelab/parameter.py b/python/dune/perftool/pdelab/parameter.py index 46946434ba75ca652c59fc424dbf522ca2fd1abc..0d3aa6a29d11f8c520687ded1fb839c95797dcd7 100644 --- a/python/dune/perftool/pdelab/parameter.py +++ b/python/dune/perftool/pdelab/parameter.py @@ -25,13 +25,13 @@ from dune.perftool.pdelab.localoperator import (class_type_from_cache, ) -@class_basename("parameterclass") +@class_basename(classtag="parameterclass") def parameterclass_basename(formdata, data): lopbase = localoperator_basename(formdata, data) return "{}Params".format(lopbase) -@class_member("operator") +@class_member(classtag="operator") def define_parameterclass(name): _, t = class_type_from_cache("parameterclass") constructor_parameter("const {}&".format(t), name + "_", classtag="operator") @@ -44,7 +44,7 @@ def name_paramclass(): return "param" -@class_member(classtag="parameterclass", access=AccessModifier.PRIVATE) +@class_member(classtag="parameterclass") def define_time(name): initializer_list(name, ["0.0"], classtag="parameterclass") return "double {};".format(name) @@ -55,7 +55,7 @@ def name_time(): return "t" -@class_member("parameterclass", access=AccessModifier.PUBLIC) +@class_member(classtag="parameterclass") def define_set_time_method(): time_name = name_time() # TODO double? @@ -81,7 +81,7 @@ def component_to_tree_path(element, component): return _flatten(subel) -@class_member("parameterclass", access=AccessModifier.PUBLIC) +@class_member(classtag="parameterclass") def define_parameter_function_class_member(name, expr, baset, shape, cell): t = construct_nested_fieldvector(baset, shape) diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index 8ee1402c1ef5775982e0085a72b568b70fad9f38..f6be15fd3e5f3e0a64f79e2573d28f2943459a40 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -94,7 +94,7 @@ def fill_quadrature_points_cache(name): return "fillQuadraturePointsCache({}, {}, {});".format(geo, quad_order, name) -@class_member("operator") +@class_member(classtag="operator") def typedef_quadrature_points(name): range_field = lop_template_range_field() dim = _local_dim() @@ -107,7 +107,7 @@ def type_quadrature_points(name): return name -@class_member("operator") +@class_member(classtag="operator") def define_quadrature_points(name): qp_type = type_quadrature_points(name) return "mutable std::vector<{}> {};".format(qp_type, name) @@ -147,7 +147,7 @@ def fill_quadrature_weights_cache(name): return "fillQuadratureWeightsCache({}, {}, {});".format(geo, quad_order, name) -@class_member("operator") +@class_member(classtag="operator") def typedef_quadrature_weights(name): range_field = lop_template_range_field() dim = _local_dim() @@ -166,7 +166,7 @@ def type_quadrature_weights(name): return name -@class_member("operator") +@class_member(classtag="operator") def define_quadrature_weights(name): qw_type = type_quadrature_weights(name) return "mutable std::vector<{}> {};".format(qw_type, name) diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py index 1d3fa814f18dba7b819783cdfdfd27bd5e898433..ec0dace8da0ac18a99e541a0c4b76bf5a2b5fec2 100644 --- a/python/dune/perftool/sumfact/amatrix.py +++ b/python/dune/perftool/sumfact/amatrix.py @@ -58,7 +58,7 @@ def colmajoraccess_mangler(target, func, dtypes): return CallMangleInfo(func.name, (NumpyType(numpy.float64),), (NumpyType(numpy.int32), NumpyType(numpy.int32))) -@class_member("operator") +@class_member(classtag="operator") def define_alignment(name): alignment = get_option("sumfact_alignment") return "enum {{ {} = {} }};".format(name, str(alignment)) @@ -83,7 +83,7 @@ def quadrature_points_per_direction(): return nb_qp -@class_member("operator") +@class_member(classtag="operator") def define_number_of_quadrature_points_per_direction(name): number_qp = quadrature_points_per_direction() return "enum {{ {} = {} }};".format(name, str(number_qp)) @@ -104,7 +104,7 @@ def basis_functions_per_direction(): return polynomial_degree() + 1 -@class_member("operator") +@class_member(classtag="operator") def define_number_of_basis_functions_per_direction(name): number_basis = basis_functions_per_direction() return "enum {{ {} = {} }};".format(name, str(number_basis)) @@ -116,7 +116,7 @@ def name_number_of_basis_functions_per_direction(): return name -@class_member("operator") +@class_member(classtag="operator") def define_oned_quadrature_weights(name): range_field = lop_template_range_field() number_qp = name_number_of_quadrature_points_per_direction() @@ -130,7 +130,7 @@ def name_oned_quadrature_weights(): return name -@class_member("operator") +@class_member(classtag="operator") def define_oned_quadrature_points(name): range_field = lop_template_range_field() number_qp = name_number_of_quadrature_points_per_direction() @@ -144,7 +144,7 @@ def name_oned_quadrature_points(): return name -@class_member("operator") +@class_member(classtag="operator") def typedef_polynomials(name): range_field = lop_template_range_field() domain_field = name_domain_field() @@ -169,7 +169,7 @@ def type_polynomials(): return name -@class_member("operator") +@class_member(classtag="operator") def define_polynomials(name): polynomials_type = type_polynomials() return "{} {};".format(polynomials_type, name) @@ -181,7 +181,7 @@ def name_polynomials(): return name -@constructor_block("operator") +@constructor_block(classtag="operator") def sort_quadrature_points_weights(): range_field = lop_template_range_field() domain_field = name_domain_field() @@ -192,7 +192,7 @@ def sort_quadrature_points_weights(): return "onedQuadraturePointsWeights<{}, {}, {}>({}, {});".format(range_field, domain_field, number_qp, qp, qw) -@constructor_block("operator") +@constructor_block(classtag="operator") def construct_theta(name, transpose, derivative): # Make sure that the quadrature points are sorted sort_quadrature_points_weights() @@ -215,7 +215,7 @@ def construct_theta(name, transpose, derivative): "}"] -@class_member("operator") +@class_member(classtag="operator") def typedef_theta(name): include_file("dune/perftool/sumfact/alignedmatvec.hh", filetag="operatorfile") alignment = name_alignment() @@ -229,7 +229,7 @@ def type_theta(): return name -@class_member("operator") +@class_member(classtag="operator") def define_theta(name, shape, transpose, derivative): theta_type = type_theta() initializer_list(name, [str(axis) for axis in shape], classtag="operator")