Skip to content
Snippets Groups Projects
Commit 7c6bbade authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Introduce an output object into the symbolic sum fact kernel representation

It does not yet do all the jobs it can do, but it exists and we can
run tests with it.
parent 79450475
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,7 @@ from dune.perftool.sumfact.tabulation import (basis_functions_per_direction, ...@@ -32,7 +32,7 @@ from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
from dune.perftool.sumfact.switch import (get_facedir, from dune.perftool.sumfact.switch import (get_facedir,
get_facemod, get_facemod,
) )
from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelOutputBase
from dune.perftool.ufl.modified_terminals import extract_modified_arguments from dune.perftool.ufl.modified_terminals import extract_modified_arguments
from dune.perftool.tools import get_pymbolic_basename from dune.perftool.tools import get_pymbolic_basename
from dune.perftool.error import PerftoolError from dune.perftool.error import PerftoolError
...@@ -81,21 +81,35 @@ def accum_iname(element, bound, i): ...@@ -81,21 +81,35 @@ def accum_iname(element, bound, i):
return sumfact_iname(bound, "accum{}".format(suffix)) return sumfact_iname(bound, "accum{}".format(suffix))
class AlreadyAssembledInput(SumfactKernelInputBase): class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
def __init__(self, index): def __init__(self,
self.index = index accumvar=None,
restriction=None,
def __eq__(self, other): test_element=None,
return type(self) == type(other) and self.index == other.index test_element_index=None,
trial_element=None,
def __repr__(self): trial_element_index=None,
return "AlreadyAssembledInput({})".format(self.index) ):
# TODO: Isn't accumvar superfluous given all the other information?
def __hash__(self): ImmutableRecord.__init__(self,
return hash(self.index) accumvar=accumvar,
restriction=restriction,
test_element=test_element,
test_element_index=test_element_index,
trial_element=trial_element,
trial_element_index=trial_element_index,
)
def __str__(self): @property
return "Input{}".format(self.index[0]) def within_inames(self):
if self.trial_element is None:
return ()
else:
from dune.perftool.sumfact.basis import lfs_inames
element = self.trial_element
if isinstance(element, MixedElement):
element = element.extract_component(self.trial_element_index)[1]
return lfs_inames(element, self.restriction)
class SumfactAccumulationInfo(ImmutableRecord): class SumfactAccumulationInfo(ImmutableRecord):
...@@ -266,16 +280,18 @@ def generate_accumulation_instruction(expr, visitor): ...@@ -266,16 +280,18 @@ def generate_accumulation_instruction(expr, visitor):
if priority is None: if priority is None:
priority = 3 priority = 3
output = AccumulationOutput(accumvar=accumvar,
restriction=(test_info.restriction, trial_info.restriction),
test_element=test_info.element,
test_element_index=test_info.element_index,
trial_element=trial_info.element,
trial_element_index=trial_info.element_index,
)
sf = SumfactKernel(matrix_sequence=matrix_sequence, sf = SumfactKernel(matrix_sequence=matrix_sequence,
restriction=(test_info.restriction, trial_info.restriction),
stage=3, stage=3,
position_priority=priority, position_priority=priority,
accumvar=accumvar, output=output,
test_element=test_info.element,
test_element_index=test_info.element_index,
trial_element=trial_info.element,
trial_element_index=trial_info.element_index,
input=AlreadyAssembledInput(index=(test_info.element_index,)),
predicates=predicates, predicates=predicates,
) )
......
...@@ -24,12 +24,28 @@ class SumfactKernelInputBase(object): ...@@ -24,12 +24,28 @@ class SumfactKernelInputBase(object):
return False return False
def realize(self, sf, dep, index=0): def realize(self, sf, dep, index=0):
return frozenset() return dep
def realize_direct(self, inames): def realize_direct(self, inames):
raise NotImplementedError raise NotImplementedError
class SumfactKernelOutputBase(object):
@property
def direct_output_is_possible(self):
return False
@property
def within_inames(self):
return ()
def realize(self, sf, dep):
return dep
def realize_direct(self):
raise NotImplementedError
class VectorSumfactKernelInput(SumfactKernelInputBase): class VectorSumfactKernelInput(SumfactKernelInputBase):
def __init__(self, inputs): def __init__(self, inputs):
assert(isinstance(inputs, tuple)) assert(isinstance(inputs, tuple))
...@@ -82,14 +98,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -82,14 +98,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
buffer=None, buffer=None,
stage=1, stage=1,
position_priority=None, position_priority=None,
restriction=None,
insn_dep=frozenset(), insn_dep=frozenset(),
input=None, input=SumfactKernelInputBase(),
accumvar=None, output=SumfactKernelOutputBase(),
test_element=None,
test_element_index=None,
trial_element=None,
trial_element_index=None,
predicates=frozenset(), predicates=frozenset(),
): ):
"""Create a sum factorization kernel """Create a sum factorization kernel
...@@ -165,11 +176,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -165,11 +176,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
assert stage in (1, 3) assert stage in (1, 3)
if stage == 1: assert isinstance(input, SumfactKernelInputBase)
assert isinstance(input, SumfactKernelInputBase) assert isinstance(output, SumfactKernelOutputBase)
if stage == 3:
assert isinstance(restriction, tuple)
assert isinstance(insn_dep, frozenset) assert isinstance(insn_dep, frozenset)
...@@ -215,21 +223,28 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -215,21 +223,28 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
def cache_key(self): def cache_key(self):
""" The cache key that can be used in generation magic """ The cache key that can be used in generation magic
Any two sum factorization kernels having the same cache_key Any two sum factorization kernels having the same cache_key
are realized simulatenously! are realized simultaneously!
""" """
return (self.matrix_sequence, self.restriction, self.stage, self.buffer, self.test_element_index) if self.buffer is None:
# During dry run, we return something unique to this kernel
return repr(self)
else:
# Later, kernels that are implemented in parallel are identified by their assigned buffer
return self.buffer
# return (self.matrix_sequence, self.restriction, self.stage, self.buffer, self.test_element_index)
@property @property
def input_key(self): def inout_key(self):
""" A cache key for the input coefficients """ A cache key for the input coefficients
Any two sum factorization kernels having the same input_key Any two sum factorization kernels having the same inout_key
work on the same input coefficient work on the same input coefficient (stage 1) or accumulate
into the same thing (stage 3)
""" """
return (self.input, self.restriction, self.accumvar, self.trial_element_index) return (self.input, self.output)
@property @property
def group_name(self): def group_name(self):
return "sfgroup_{}_{}_{}_{}".format(self.input, self.restriction, self.accumvar, self.trial_element_index) return "sfgroup_{}_{}".format(self.input, self.output)
# #
# Some convenience methods to extract information about the sum factorization kernel # Some convenience methods to extract information about the sum factorization kernel
...@@ -238,7 +253,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -238,7 +253,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
def __lt__(self, other): def __lt__(self, other):
if self.parallel_key != other.parallel_key: if self.parallel_key != other.parallel_key:
return self.parallel_key < other.parallel_key return self.parallel_key < other.parallel_key
if self.input_key != other.input_key: if self.inout_key != other.inout_key:
return self.input_key < other.input_key return self.inout_key < other.inout_key
if self.position_priority == other.position_priority: if self.position_priority == other.position_priority:
return repr(self) < repr(other) return repr(self) < repr(other)
...@@ -263,14 +278,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -263,14 +278,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
@property @property
def within_inames(self): def within_inames(self):
if self.trial_element is None: return self.output.within_inames
return ()
else:
from dune.perftool.sumfact.basis import lfs_inames
element = self.trial_element
if isinstance(element, MixedElement):
element = element.extract_component(self.trial_element_index)[1]
return lfs_inames(element, self.restriction)
def vec_index(self, sf): def vec_index(self, sf):
""" Map an unvectorized sumfact kernel object to its position """ Map an unvectorized sumfact kernel object to its position
...@@ -553,8 +561,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -553,8 +561,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
return (tuple(k.cache_key for k in self.kernels), self.buffer) return (tuple(k.cache_key for k in self.kernels), self.buffer)
@property @property
def input_key(self): def inout_key(self):
return tuple(k.input_key for k in self.kernels) return tuple(k.inout_key for k in self.kernels)
@property @property
def group_name(self): def group_name(self):
...@@ -570,7 +578,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -570,7 +578,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def horizontal_index(self, sf): def horizontal_index(self, sf):
for i, k in enumerate(self.kernels): for i, k in enumerate(self.kernels):
if sf.input_key == k.input_key: if sf.inout_key == k.inout_key:
if tuple(mat.derivative for mat in sf.matrix_sequence) == tuple(mat.derivative for mat in k.matrix_sequence): if tuple(mat.derivative for mat in sf.matrix_sequence) == tuple(mat.derivative for mat in k.matrix_sequence):
return i return i
......
...@@ -284,8 +284,8 @@ def _level2_optimal_vectorization_strategy_generator(sumfacts, width, qp, alread ...@@ -284,8 +284,8 @@ def _level2_optimal_vectorization_strategy_generator(sumfacts, width, qp, alread
yielded = False yielded = False
# Find the number of input coefficients we can work on # Find the number of input coefficients we can work on
keys = frozenset(sf.input_key for sf in sumfacts) keys = frozenset(sf.inout_key for sf in sumfacts)
inputkey_sumfacts = [tuple(sorted(filter(lambda sf: sf.input_key == key, sumfacts))) for key in keys] inoutkey_sumfacts = [tuple(sorted(filter(lambda sf: sf.inout_key == key, sumfacts))) for key in keys]
for parallel in (1, 2): for parallel in (1, 2):
if parallel == 2 and next(iter(sumfacts)).stage == 3: if parallel == 2 and next(iter(sumfacts)).stage == 3:
...@@ -294,7 +294,7 @@ def _level2_optimal_vectorization_strategy_generator(sumfacts, width, qp, alread ...@@ -294,7 +294,7 @@ def _level2_optimal_vectorization_strategy_generator(sumfacts, width, qp, alread
it.permutations(range(len(keys)), parallel)): it.permutations(range(len(keys)), parallel)):
horizontal = 1 horizontal = 1
while horizontal <= width // parallel: while horizontal <= width // parallel:
combo = sum((inputkey_sumfacts[part][:horizontal] for part in which), ()) combo = sum((inoutkey_sumfacts[part][:horizontal] for part in which), ())
vecdict = get_vectorization_dict(combo, width // (horizontal * parallel), horizontal * parallel, qp) vecdict = get_vectorization_dict(combo, width // (horizontal * parallel), horizontal * parallel, qp)
horizontal *= 2 horizontal *= 2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment