Skip to content
Snippets Groups Projects
Commit 1b3f71af authored by Dominic Kempf's avatar Dominic Kempf
Browse files

A 'transform' generator for delayed kernel transformation application

parent 3c269b93
No related branches found
No related tags found
No related merge requests found
......@@ -39,6 +39,7 @@ from dune.perftool.generation.loopy import (constantarg,
noop_instruction,
silenced_warning,
temporary_variable,
transform,
valuearg,
)
......
......@@ -124,3 +124,10 @@ def instruction(code=None, expression=None, **kwargs):
@generator_factory(item_tags=("instruction",), cache_key_generator=lambda **kw: kw['id'])
def noop_instruction(**kwargs):
return loopy.NoOpInstruction(**kwargs)
@generator_factory(item_tags=("transformation",),
cache_key_generator=no_caching,
)
def transform(transform, *args):
return (transform, args)
......@@ -482,6 +482,7 @@ def generate_kernel(integrals):
arguments = [i for i in retrieve_cache_items("argument")]
manglers = retrieve_cache_functions("mangler")
silenced = [l for l in retrieve_cache_items("silenced_warning")]
transformations = [t for t in retrieve_cache_items("transformation")]
# Construct an options object
from loopy import Options
......@@ -502,6 +503,10 @@ def generate_kernel(integrals):
from loopy import make_reduction_inames_unique
kernel = make_reduction_inames_unique(kernel)
# Apply the transformations that were gathered during tree traversals
for trafo in transformations:
kernel = trafo[0](kernel, *trafo[1])
kernel = preprocess_kernel(kernel)
from dune.perftool.loopy.duplicate import heuristic_duplication
......
......@@ -27,6 +27,18 @@ from pymbolic.primitives import (Call,
import numpy
def nest_quadrature_loops(kernel, inames):
from loopy import find_instructions
insns = []
for insn in find_instructions(kernel, "tag:quad"):
insns.append(insn.copy(within_inames=insn.within_inames.union(frozenset(inames)),
tags=insn.tags - frozenset({"quad"})
)
)
return kernel.copy(instructions=insns)
class BaseWeight(FunctionIdentifier):
def __init__(self, accumobj):
self.accumobj = accumobj
......@@ -70,6 +82,7 @@ def define_recursive_quadrature_weight(name, dir):
assignee=Variable(name),
forced_iname_deps=frozenset(quadrature_inames()[dir:]),
forced_iname_deps_is_final=True,
tags=frozenset({"quad"}),
)
......@@ -96,6 +109,7 @@ def define_quadrature_position(name):
assignee=Subscript(Variable(name), (i,)),
forced_iname_deps=frozenset(quadrature_inames()),
forced_iname_deps_is_final=True,
tags=frozenset({"quad"}),
)
......
......@@ -6,16 +6,19 @@ from dune.perftool.generation import (backend,
domain,
function_mangler,
get_counter,
get_global_context_value,
globalarg,
iname,
instruction,
silenced_warning,
temporary_variable,
transform,
)
from dune.perftool.loopy.buffer import (get_buffer_temporary,
initialize_buffer,
switch_base_storage,
)
from dune.perftool.sumfact.quadrature import nest_quadrature_loops
from dune.perftool.pdelab.spaces import name_lfs
from dune.perftool.sumfact.amatrix import (AMatrix,
quadrature_points_per_direction,
......@@ -127,6 +130,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
depends_on=frozenset({stage_insn(3)}),
)
# Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
transform(nest_quadrature_loops, visitor.inames)
def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({}), additional_inames=frozenset({})):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment