Skip to content
Snippets Groups Projects
Commit 989e0163 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Make vectorization opportunities deterministic again

By implementing repr such that instance addresses are
not part of sorting anymore
parent bc78c2ba
No related branches found
No related tags found
No related merge requests found
......@@ -104,6 +104,9 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
trial_element_index=trial_element_index,
)
def __repr__(self):
return ImmutableRecord.__repr__(self)
@property
def within_inames(self):
if self.trial_element is None:
......
......@@ -33,12 +33,18 @@ class SumfactKernelInputBase(object):
def realize_direct(self, inames):
raise NotImplementedError
def __repr__(self):
return "SumfactKernelInputBase()"
class VectorSumfactKernelInput(SumfactKernelInputBase):
def __init__(self, inputs):
assert(isinstance(inputs, tuple))
self.inputs = inputs
def __repr__(self):
return "_".join(repr(i) for i in self.inputs)
@property
def direct_input_is_possible(self):
return all(i.direct_input_is_possible for i in self.inputs)
......@@ -87,11 +93,17 @@ class SumfactKernelOutputBase(object):
def realize_direct(self, result, inames, shape, args):
raise NotImplementedError
def __repr__(self):
return "SumfactKernelOutputBase()"
class VectorSumfactKernelOutput(SumfactKernelOutputBase):
def __init__(self, outputs):
self.outputs = outputs
def __repr__(self):
return "_".join(repr(o) for o in self.outputs)
def _add_hadd(self, o, result):
hadd_function = "horizontal_add"
if len(set(self.outputs)) > 1:
......@@ -293,7 +305,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
work on the same input coefficient (stage 1) or accumulate
into the same thing (stage 3)
"""
return (self.input, self.output)
return (repr(self.input), repr(self.output))
@property
def group_name(self):
......
......@@ -285,7 +285,8 @@ def _level2_optimal_vectorization_strategy_generator(sumfacts, width, qp, alread
# Find the number of input coefficients we can work on
keys = frozenset(sf.inout_key for sf in sumfacts)
inoutkey_sumfacts = [tuple(sorted(filter(lambda sf: sf.inout_key == key, sumfacts))) for key in keys]
inoutkey_sumfacts = [tuple(sorted(filter(lambda sf: sf.inout_key == key, sumfacts))) for key in sorted(keys)]
for parallel in (1, 2):
if parallel > len(keys):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment