diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index ea94e8258aed4595a900ebe64bebf27bdd721780..05bba36afa71ed428da7371fed5ee624e41a2638 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -39,6 +39,7 @@ from dune.perftool.generation.loopy import (constantarg, noop_instruction, silenced_warning, temporary_variable, + transform, valuearg, ) diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index b3fda210635a08f0c5ca56f115d08d8b1e141c8b..c5932bfa5b110457135977f87e014b4240820f9d 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -124,3 +124,10 @@ def instruction(code=None, expression=None, **kwargs): @generator_factory(item_tags=("instruction",), cache_key_generator=lambda **kw: kw['id']) def noop_instruction(**kwargs): return loopy.NoOpInstruction(**kwargs) + + +@generator_factory(item_tags=("transformation",), + cache_key_generator=no_caching, + ) +def transform(transform, *args): + return (transform, args) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index fca6efacaeaff1d3cabc9aedfdf922db199915ec..dcf26e15a29404dde2c997b9ba2a9cc90c7b08ae 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -482,6 +482,7 @@ def generate_kernel(integrals): arguments = [i for i in retrieve_cache_items("argument")] manglers = retrieve_cache_functions("mangler") silenced = [l for l in retrieve_cache_items("silenced_warning")] + transformations = [t for t in retrieve_cache_items("transformation")] # Construct an options object from loopy import Options @@ -502,6 +503,10 @@ def generate_kernel(integrals): from loopy import make_reduction_inames_unique kernel = make_reduction_inames_unique(kernel) + # Apply the transformations that were gathered during tree traversals + for trafo in transformations: + kernel = trafo[0](kernel, *trafo[1]) + kernel = preprocess_kernel(kernel) from dune.perftool.loopy.duplicate import heuristic_duplication diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py index 408c9c0b03245fd73184d9c71268e80fb2cddfc2..4949fac418acbc31bda0759baa159f4364a00d84 100644 --- a/python/dune/perftool/sumfact/quadrature.py +++ b/python/dune/perftool/sumfact/quadrature.py @@ -27,6 +27,18 @@ from pymbolic.primitives import (Call, import numpy +def nest_quadrature_loops(kernel, inames): + from loopy import find_instructions + insns = [] + for insn in find_instructions(kernel, "tag:quad"): + insns.append(insn.copy(within_inames=insn.within_inames.union(frozenset(inames)), + tags=insn.tags - frozenset({"quad"}) + ) + ) + + return kernel.copy(instructions=insns) + + class BaseWeight(FunctionIdentifier): def __init__(self, accumobj): self.accumobj = accumobj @@ -70,6 +82,7 @@ def define_recursive_quadrature_weight(name, dir): assignee=Variable(name), forced_iname_deps=frozenset(quadrature_inames()[dir:]), forced_iname_deps_is_final=True, + tags=frozenset({"quad"}), ) @@ -96,6 +109,7 @@ def define_quadrature_position(name): assignee=Subscript(Variable(name), (i,)), forced_iname_deps=frozenset(quadrature_inames()), forced_iname_deps_is_final=True, + tags=frozenset({"quad"}), ) diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 3d52a87915999f32fb8033c1f975cc7097c34dc2..62326d19d7fa7302120f7bcd0efb2d096a38107d 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -6,16 +6,19 @@ from dune.perftool.generation import (backend, domain, function_mangler, get_counter, + get_global_context_value, globalarg, iname, instruction, silenced_warning, temporary_variable, + transform, ) from dune.perftool.loopy.buffer import (get_buffer_temporary, initialize_buffer, switch_base_storage, ) +from dune.perftool.sumfact.quadrature import nest_quadrature_loops from dune.perftool.pdelab.spaces import name_lfs from dune.perftool.sumfact.amatrix import (AMatrix, quadrature_points_per_direction, @@ -127,6 +130,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): depends_on=frozenset({stage_insn(3)}), ) + # Mark the transformation that moves the quadrature loop inside the trialfunction loops for application + transform(nest_quadrature_loops, visitor.inames) + def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({}), additional_inames=frozenset({})): """