Skip to content
Snippets Groups Projects
Commit c722dffb authored by Dominic Kempf
Browse files

Introduce a stage mechanism and use it in sum factorization

Needs a small fix in loopy (MR opened).
parent 191729fd
No related branches found
No related tags found
No related merge requests found
......@@ -32,6 +32,7 @@ from dune.perftool.generation.loopy import (constantarg,
globalarg,
iname,
instruction,
noop_instruction,
pymbolic_expr,
silenced_warning,
temporary_variable,
......
......@@ -196,3 +196,8 @@ def instruction(code=None, expression=None, **kwargs):
# return the ID, as it is the only useful information to the user
return id
@generator_factory(item_tags=("instruction",), cache_key_generator=lambda **kw: kw['id'])
def noop_instruction(**kwargs):
    """Create a loopy no-op instruction.

    Every keyword argument is forwarded unchanged to
    ``loopy.NoOpInstruction``. The cache key generator given to the
    decorator looks up ``kw['id']``, so callers are expected to pass
    an ``id`` keyword.
    """
    insn = loopy.NoOpInstruction(**kwargs)
    return insn
\ No newline at end of file
""" loopy instructions to mark stages of computations """
from dune.perftool.generation import noop_instruction
def stage_insn(n, **kwargs):
    """Issue the no-op instruction that marks stage ``n`` and return its ID.

    Stage instructions are chained: for ``n > 0`` the instruction of
    stage ``n`` depends on the instruction of stage ``n - 1``, which is
    issued recursively with the same keyword arguments the caller gave.

    Arguments:
    ----------
    n: The number of the stage
    kwargs: Forwarded to noop_instruction. Must not contain 'id', as
        the ID is derived from the stage number. A given 'depends_on'
        is extended by the ID of the previous stage instruction.
    """
    assert 'id' not in kwargs

    # The ID is fully determined by the stage number
    name = 'stage_insn_{}'.format(n)

    if n > 0:
        # The recursive call is evaluated before 'depends_on' is
        # overwritten, so the previous stage sees the caller's
        # original keyword arguments.
        kwargs['depends_on'] = kwargs.get('depends_on', frozenset()).union(
            [stage_insn(n - 1, **kwargs)]
        )

    # Issue the instruction; only its ID is of interest to the caller
    noop_instruction(id=name, **kwargs)
    return name
......@@ -59,7 +59,7 @@ def start_sumfactorization():
return sum_factorization_kernel(a_matrices, inp, "buffer")
def sum_factorization_kernel(a_matrices, inp, buffer):
def sum_factorization_kernel(a_matrices, inp, buffer, stage=0):
"""
Calculate a sum factorization matrix product.
......@@ -76,6 +76,10 @@ def sum_factorization_kernel(a_matrices, inp, buffer):
buffer: A string identifying the flip flop buffer in use
for intermediate results.
"""
# Get the stage instruction
from dune.perftool.loopy.stages import stage_insn
insn_dep = stage_insn(stage)
for l, a_matrix in enumerate(a_matrices):
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the amatrix loop
......@@ -108,10 +112,12 @@ def sum_factorization_kernel(a_matrices, inp, buffer):
))
# Issue the reduction instruction that implements the multiplication
instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j))),
expression=Reduction("sum", k, prod),
forced_iname_deps=frozenset({i, j}),
forced_iname_deps_is_final=True,
)
# at the same time store the instruction ID for the next instruction to depend on
insn_dep = instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j))),
expression=Reduction("sum", k, prod),
forced_iname_deps=frozenset({i, j}),
forced_iname_deps_is_final=True,
depends_on=frozenset({insn_dep}),
)
return out
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment