Skip to content
Snippets Groups Projects
Commit 3cbaeb7f authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Reimplement instrumentation as a loopy transformation

This is currently only used for level 4 instrumentation,
as these start and stop markers were messing heavily with
dependencies in the sum factorization code. This way, the necessary but
ugly work of adding instrumentation is separated much more
nicely.
parent 9f3e977d
No related branches found
No related tags found
No related merge requests found
""" Add instrumentation instructions to a kernel """
from dune.perftool.generation import (dump_accumulate_timer,
post_include,
)
from dune.perftool.options import get_option
import loopy as lp
def _intersect(a):
""" Return intersection of a given tuple of frozensets. Also works for empty tuple """
if len(a) == 0:
return frozenset()
return frozenset.intersection(*a)
def _union(a):
""" Return union of a given tuple of frozensets. Also works for empty tuple """
if len(a) == 0:
return frozenset()
return frozenset.union(*a)
def add_instrumentation(knl, match, identifier, level, filetag='operatorfile', operator=False):
    """ Transform loopy kernel to contain instrumentation code

    The matched instructions are enclosed between an HP_TIMER_START and an
    HP_TIMER_STOP C instruction. Dependencies are rewired such that the start
    marker is scheduled before and the stop marker after the matched block.

    Arguments:
    knl : The loopy kernel, follows the loopy transformation convention
    match : A loopy match object or a string (interpreted as instruction ID or tag) to describe
            which instructions should be wrapped in an instrumentation block.
    identifier : The name of the counter to start and stop
    level : The instrumentation level this measurement is defined at
    filetag : The tag of the file that should contain the counter definitions
    operator : Not used by this function. Kept in the signature for backwards compatibility.

    Returns:
    The transformed kernel, or the unmodified kernel if the instrumentation
    level is too low or no instruction matches.
    """
    # If the instrumentation level is not high enough, this is a no-op
    if level > get_option("instrumentation_level"):
        return knl

    # If a string was given for match, heuristically make it a match object
    if isinstance(match, str):
        match = lp.match.Or((lp.match.Id(match), lp.match.Tagged(match)))

    # Find the instructions to wrap in instrumentation
    insns = lp.find_instructions(knl, match)

    # If the match is empty, this is also no op
    if not insns:
        return knl

    # These are needed several times below, so they are only calculated once
    insn_ids = frozenset(i.id for i in insns)
    unmatched_insns = lp.find_instructions(knl, lp.match.Not(match))

    # Determine the iname nesting of the timing block
    insn_inames = _intersect(tuple(i.within_inames for i in insns))
    other_inames = _union(tuple(i.within_inames for i in unmatched_insns))
    within = _intersect((insn_inames, other_inames))

    # Get a unique identifier - note that the same timer could be started and stopped several times
    # within one kernel. Underscores are appended until neither the start nor the stop ID is
    # taken, so an arbitrary number of blocks may share one timer.
    existing_ids = frozenset(i.id for i in knl.instructions)
    ident = identifier
    while "{}_start".format(ident) in existing_ids or "{}_stop".format(ident) in existing_ids:
        ident = "{}_".format(ident)

    # Define the start instruction and correct dependencies for it:
    # It inherits all dependencies of the block that point outside of the block.
    start_id = "{}_start".format(ident)
    start_depends = _union(tuple(i.depends_on for i in insns)).difference(insn_ids)
    start_insn = lp.CInstruction([],
                                 "HP_TIMER_START({});".format(identifier),
                                 id=start_id,
                                 within_inames=within,
                                 depends_on=start_depends,
                                 boostable_into=frozenset(),
                                 )

    # Add dependencies on the timing instructions
    rewritten_insns = [i.copy(depends_on=i.depends_on.union(frozenset({start_id}))) for i in insns]

    # Define the stop instruction and correct dependencies for it
    stop_id = "{}_stop".format(ident)
    stop_insn = lp.CInstruction([],
                                "HP_TIMER_STOP({});".format(identifier),
                                id=stop_id,
                                within_inames=within,
                                depends_on=insn_ids,
                                boostable_into=frozenset(),
                                )

    # Find all the instructions that should depend on stop, i.e. those outside
    # of the block which depend on at least one instruction of the block
    rewritten_insns.extend(i.copy(depends_on=i.depends_on.union(frozenset({stop_id})))
                           for i in unmatched_insns
                           if insn_ids.intersection(i.depends_on))

    # Trigger code generation on the file/operator level
    post_include('HP_DECLARE_TIMER({});'.format(identifier), filetag=filetag)
    dump_accumulate_timer(identifier)

    # Filter all the instructions which were untouched
    rewritten_ids = frozenset(i.id for i in rewritten_insns)
    other_insns = [i for i in knl.instructions if i.id not in rewritten_ids]

    # Add all the modified instructions into the kernel object
    return knl.copy(instructions=rewritten_insns + other_insns + [start_insn, stop_insn])
......@@ -206,6 +206,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
within_inames=common_inames.union(frozenset({outer_iname, vec_iname})),
within_inames_is_final=True,
id="{}_rotate{}".format(quantity, suffix),
tags=frozenset({"sumfact_stage2"}),
))
# Add substitution rules
......@@ -267,7 +268,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
within_inames=common_inames.union(frozenset({outer_iname, vec_iname})),
within_inames_is_final=True,
id=insn.id,
tags=frozenset({"vec_write{}".format(suffix)})
tags=frozenset({"vec_write{}".format(suffix), "sumfact_stage2"})
)
)
......@@ -282,6 +283,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
within_inames=common_inames.union(frozenset({outer_iname, vec_iname})),
within_inames_is_final=True,
id="{}_rotateback{}".format(lhsname, suffix),
tags=frozenset({"sumfact_stage2"}),
))
# Add the necessary vector indices
......@@ -296,6 +298,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
within_inames=common_inames,
within_inames_is_final=True,
id="assign_{}{}".format(name, suffix),
tags=frozenset({"sumfact_stage2"}),
))
new_insns.append(lp.Assignment(prim.Variable(name), # assignee
prim.Sum((prim.Variable(name), increment)), # expression
......@@ -304,6 +307,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
depends_on=frozenset({Tagged("vec_write{}".format(suffix)), "assign_{}{}".format(name, suffix)}),
depends_on_is_final=True,
id="update_{}{}".format(name, suffix),
tags=frozenset({"sumfact_stage2"}),
))
from loopy.kernel.creation import resolve_dependencies
......
......@@ -562,6 +562,14 @@ def extract_kernel_from_cache(tag, name, signature, wrap_in_cgen=True, add_timin
from dune.perftool.loopy.transformations.matchfma import match_fused_multiply_add
kernel = match_fused_multiply_add(kernel)
# Add instrumentation to the kernel
from dune.perftool.loopy.transformations.instrumentation import add_instrumentation
if add_timings and get_form_option("sumfact"):
from dune.perftool.pdelab.signatures import assembler_routine_name
kernel = add_instrumentation(kernel, lp.match.Tagged("sumfact_stage1"), "{}_kernel_stage1".format(assembler_routine_name()), 4)
kernel = add_instrumentation(kernel, lp.match.Tagged("sumfact_stage2"), "{}_kernel_quadratureloop".format(assembler_routine_name()), 4)
kernel = add_instrumentation(kernel, lp.match.Tagged("sumfact_stage3"), "{}_kernel_stage3".format(assembler_routine_name()), 4)
if wrap_in_cgen:
# Wrap the kernel in something which can generate code
if signature is None:
......@@ -668,7 +676,7 @@ def cgen_class_from_cache(tag, members=[]):
tparams = [i for i in retrieve_cache_items('{} and template_param'.format(tag))]
# Construct the constructor
constructor_knl = extract_kernel_from_cache(tag, "constructor_kernel", None, wrap_in_cgen=False)
constructor_knl = extract_kernel_from_cache(tag, "constructor_kernel", None, wrap_in_cgen=False, add_timings=False)
from dune.perftool.loopy.target import DuneTarget
constructor_knl = constructor_knl.copy(target=DuneTarget(declare_temporaries=False))
signature = "{}({})".format(basename, ", ".join(next(iter(p.generate(with_semicolon=False))) for p in constructor_params))
......
......@@ -475,23 +475,6 @@ def generate_accumulation_instruction(expr, visitor):
tags=frozenset(["quadvec", "gradvec"]),
)
# Write timing stuff for jacobian (for alpha methods it is done at the end of stage 1)
timer_dep = frozenset()
if get_option("instrumentation_level") >= 4:
timer_name = "{}_kernel_stage1".format(assembler_routine_name())
timer_dep = frozenset({instruction(code="HP_TIMER_STOP({});".format(timer_name),
depends_on=frozenset({lp.match.Tagged("sumfact_stage1"), 'hptimerstart_{}'.format(timer_name)}),
id="hptimerstop_{}".format(timer_name)
)}
)
timer_name = '{}_kernel_quadratureloop'.format(assembler_routine_name())
post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile')
dump_accumulate_timer(timer_name)
timer_dep = frozenset({instruction(code="HP_TIMER_START({});".format(timer_name),
within_inames=frozenset(jacobian_inames),
id="hptimerstart_{}".format(timer_name),
depends_on=timer_dep)})
# Determine dependencies
from loopy.match import Or, Writes
from loopy.symbolic import DependencyMapper
......@@ -506,19 +489,13 @@ def generate_accumulation_instruction(expr, visitor):
expression=expr,
forced_iname_deps=frozenset(quadrature_inames(trial_leaf_element) + jacobian_inames),
forced_iname_deps_is_final=True,
tags=frozenset({"quadvec"}).union(vectag),
depends_on=frozenset({deps}).union(timer_dep).union(frozenset({lp.match.Tagged("sumfact_stage1")})),
tags=frozenset({"quadvec", "sumfact_stage2"}).union(vectag),
depends_on=frozenset({deps}).union(frozenset({lp.match.Tagged("sumfact_stage1")})),
)
if insn_dep is None:
insn_dep = frozenset({contrib_dep})
if get_option("instrumentation_level") >= 4:
insn_dep = insn_dep.union(frozenset({instruction(code="HP_TIMER_STOP({});".format(timer_name),
depends_on=insn_dep,
within_inames=frozenset(jacobian_inames),
id="hptimerstop_{}".format(timer_name))}))
# Add a sum factorization kernel that implements the multiplication
# with the test function (stage 3)
from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
......@@ -526,11 +503,3 @@ def generate_accumulation_instruction(expr, visitor):
if not get_form_option("fastdg"):
insn_dep = vsf.interface.realize(vsf, result, insn_dep)
if get_option("instrumentation_level") >= 4:
assert vsf.stage == 3
timer_name = '{}_kernel_stage{}'.format(assembler_routine_name(), vsf.stage)
insn_dep = frozenset({instruction(code="HP_TIMER_STOP({});".format(timer_name),
depends_on=frozenset({lp.match.Tagged("sumfact_stage3")}),
within_inames=frozenset(jacobian_inames),
id="hptimerstop_{}".format(timer_name))})
......@@ -67,32 +67,6 @@ def name_buffer_storage(buff, which):
def _realize_sum_factorization_kernel(sf):
insn_dep = sf.insn_dep
# Measure times and count operations in c++ code
if get_option("instrumentation_level") >= 4:
setuptimer = '{}_kernel_setup'.format(assembler_routine_name())
timer_dep = frozenset({instruction(code='HP_TIMER_STOP({});'.format(setuptimer),
id="hptimerstop_{}".format(setuptimer))})
timer_name = '{}_kernel_stage1'.format(assembler_routine_name())
post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile')
dump_accumulate_timer(timer_name)
timer_dep = timer_dep.union(frozenset({instruction(code="HP_TIMER_START({});".format(timer_name),
id="hptimerstart_{}".format(timer_name),
depends_on=timer_dep,
),
}))
timer_name = '{}_kernel_stage{}'.format(assembler_routine_name(), sf.stage)
post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile')
dump_accumulate_timer(timer_name)
timer_dep = timer_dep.union(frozenset({instruction(code="HP_TIMER_START({});".format(timer_name),
id="hptimerstart_{}".format(timer_name),
within_inames=frozenset(sf.within_inames),
depends_on=timer_dep.union(insn_dep),
),
}))
insn_dep = insn_dep.union(timer_dep)
# Get all the necessary pieces for a function call
buffers = tuple(name_buffer_storage(sf.buffer, i) for i in range(2))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment