Skip to content
Snippets Groups Projects
Commit f786d623 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Also have instruciton dependencies on the sumfact node

parent 81aa4458
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,7 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
input=None,
padding=frozenset(),
index=None,
insn_dep=frozenset(),
):
# Check the input and apply defaults where necessary
assert isinstance(a_matrices, tuple)
......@@ -43,6 +44,8 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
assert isinstance(within_inames, tuple)
assert isinstance(insn_dep, frozenset)
ImmutableRecord.__init__(self,
a_matrices=a_matrices,
buffer=buffer,
......@@ -53,6 +56,7 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
input=input,
padding=padding,
index=index,
insn_dep=insn_dep,
)
prim.Variable.__init__(self, "SUMFACT")
......@@ -61,12 +65,12 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
# The methods/fields needed to get a well-formed pymbolic node
#
def __getinitargs__(self):
return (self.a_matrices, self.buffer, self.stage, self.preferred_position, self.restriction, self.within_inames, self.input, self.padding, self.index)
return (self.a_matrices, self.buffer, self.stage, self.preferred_position, self.restriction, self.within_inames, self.input, self.padding, self.index, self.insn_dep)
def stringifier(self):
return lp.symbolic.StringifyMapper
init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position", "restriction", "within_inames", "input", "padding", "index")
init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position", "restriction", "within_inames", "input", "padding", "index", "insn_dep")
mapper_method = "map_sumfact_kernel"
......
......@@ -50,6 +50,10 @@ def realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, di
# # attached in dune.perftool.sumfact.vectorization
# sf = attach_vectorization_info(sf)
# Get the instruction dependencies of the sumfact kernel. This variable will be
# updated throughout this function.
insn_dep = insn_dep.union(sf.insn_dep)
# Prepare some dim_tags/shapes for later use
ftags = ",".join(["f"] * sf.length)
novec_ftags = ftags
......
......@@ -84,7 +84,9 @@ def decide_stage_vectorization_strategy(sumfacts, stage, restriction):
buffer=buf,
input=inp,
index=position_mapping[sumf],
padding=frozenset(available))
padding=frozenset(available),
insn_dep=frozenset().union(sf.insn_dep for sf in stage_sumfacts),
)
)
else:
# Disable vectorization strategy
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment