diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py index 9566efb71e358dcaf346bd3a24b94858e3a56f26..a01fb8eac5db2b8045be31e472f4a8b909a18ffc 100644 --- a/python/dune/perftool/sumfact/accumulation.py +++ b/python/dune/perftool/sumfact/accumulation.py @@ -32,7 +32,7 @@ from dune.perftool.sumfact.tabulation import (basis_functions_per_direction, from dune.perftool.sumfact.switch import (get_facedir, get_facemod, ) -from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase +from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelOutputBase from dune.perftool.ufl.modified_terminals import extract_modified_arguments from dune.perftool.tools import get_pymbolic_basename from dune.perftool.error import PerftoolError @@ -81,21 +81,35 @@ def accum_iname(element, bound, i): return sumfact_iname(bound, "accum{}".format(suffix)) -class AlreadyAssembledInput(SumfactKernelInputBase): - def __init__(self, index): - self.index = index - - def __eq__(self, other): - return type(self) == type(other) and self.index == other.index - - def __repr__(self): - return "AlreadyAssembledInput({})".format(self.index) - - def __hash__(self): - return hash(self.index) +class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): + def __init__(self, + accumvar=None, + restriction=None, + test_element=None, + test_element_index=None, + trial_element=None, + trial_element_index=None, + ): + #TODO: Isnt accumvar superfluous in the presence of all the other infos? + ImmutableRecord.__init__(self, + accumvar=accumvar, + restriction=None, + test_element=test_element, + test_element_index=test_element_index, + trial_element=trial_element, + trial_element_index=trial_element_index, + ) - def __str__(self): - return "Input{}".format(self.index[0]) + @property + def within_inames(self): + if self.trial_element is None: + return () + else: + from dune.perftool.sumfact.basis import lfs_inames + element = self.trial_element + if isinstance(element, MixedElement): + element = element.extract_component(self.trial_element_index)[1] + return lfs_inames(element, self.restriction) class SumfactAccumulationInfo(ImmutableRecord): @@ -266,16 +280,18 @@ def generate_accumulation_instruction(expr, visitor): if priority is None: priority = 3 + output = AccumulationOutput(accumvar=accumvar, + restriction=(test_info.restriction, trial_info.restriction), + test_element=test_info.element, + test_element_index=test_info.element_index, + trial_element=trial_info.element, + trial_element_index=trial_info.element_index, + ) + sf = SumfactKernel(matrix_sequence=matrix_sequence, - restriction=(test_info.restriction, trial_info.restriction), stage=3, position_priority=priority, - accumvar=accumvar, - test_element=test_info.element, - test_element_index=test_info.element_index, - trial_element=trial_info.element, - trial_element_index=trial_info.element_index, - input=AlreadyAssembledInput(index=(test_info.element_index,)), + output=output, predicates=predicates, ) diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 09d597bba2ad3fd08621ed49a56be57864d0daa1..eb9df93108748e7bd68d583cad0aa8b6dbf1c3ad 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -24,12 +24,28 @@ class SumfactKernelInputBase(object): return False def realize(self, sf, dep, index=0): - return frozenset() + return dep def realize_direct(self, inames): raise NotImplementedError +class SumfactKernelOutputBase(object): + @property + def direct_output_is_possible(self): + return False + + @property + def within_inames(self): + return () + + def realize(self, sf, dep): + return dep + + def realize_direct(self): + raise NotImplementedError + + class VectorSumfactKernelInput(SumfactKernelInputBase): def __init__(self, inputs): assert(isinstance(inputs, tuple)) @@ -82,14 +98,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): buffer=None, stage=1, position_priority=None, - restriction=None, insn_dep=frozenset(), - input=None, - accumvar=None, - test_element=None, - test_element_index=None, - trial_element=None, - trial_element_index=None, + input=SumfactKernelInputBase(), + output=SumfactKernelOutputBase(), predicates=frozenset(), ): """Create a sum factorization kernel @@ -165,11 +176,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): assert stage in (1, 3) - if stage == 1: - assert isinstance(input, SumfactKernelInputBase) - - if stage == 3: - assert isinstance(restriction, tuple) + assert isinstance(input, SumfactKernelInputBase) + assert isinstance(output, SumfactKernelOutputBase) assert isinstance(insn_dep, frozenset) @@ -215,21 +223,28 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def cache_key(self): """ The cache key that can be used in generation magic Any two sum factorization kernels having the same cache_key - are realized simulatenously! + are realized simultaneously! """ - return (self.matrix_sequence, self.restriction, self.stage, self.buffer, self.test_element_index) + if self.buffer is None: + # During dry run, we return something unique to this kernel + return repr(self) + else: + # Later we identify parallely implemented kernels by the assigned buffer + return self.buffer +# return (self.matrix_sequence, self.restriction, self.stage, self.buffer, self.test_element_index) @property - def input_key(self): + def inout_key(self): """ A cache key for the input coefficients Any two sum factorization kernels having the same input_key - work on the same input coefficient + work on the same input coefficient (stage 1) or accumulate + into the same thing (stage 3) """ - return (self.input, self.restriction, self.accumvar, self.trial_element_index) + return (self.input, self.output) @property def group_name(self): - return "sfgroup_{}_{}_{}_{}".format(self.input, self.restriction, self.accumvar, self.trial_element_index) + return "sfgroup_{}_{}".format(self.input, self.output) # # Some convenience methods to extract information about the sum factorization kernel @@ -238,7 +253,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def __lt__(self, other): if self.parallel_key != other.parallel_key: return self.parallel_key < other.parallel_key - if self.input_key != other.input_key: + if self.inout_key != other.inout_key: return self.input_key < other.input_key if self.position_priority == other.position_priority: return repr(self) < repr(other) @@ -263,14 +278,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): @property def within_inames(self): - if self.trial_element is None: - return () - else: - from dune.perftool.sumfact.basis import lfs_inames - element = self.trial_element - if isinstance(element, MixedElement): - element = element.extract_component(self.trial_element_index)[1] - return lfs_inames(element, self.restriction) + return self.output.within_inames def vec_index(self, sf): """ Map an unvectorized sumfact kernel object to its position @@ -553,8 +561,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) return (tuple(k.cache_key for k in self.kernels), self.buffer) @property - def input_key(self): - return tuple(k.input_key for k in self.kernels) + def inout_key(self): + return tuple(k.inout_key for k in self.kernels) @property def group_name(self): @@ -570,7 +578,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def horizontal_index(self, sf): for i, k in enumerate(self.kernels): - if sf.input_key == k.input_key: + if sf.inout_key == k.inout_key: if tuple(mat.derivative for mat in sf.matrix_sequence) == tuple(mat.derivative for mat in k.matrix_sequence): return i diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index e4a044edf0e62b8b2176efe4755a21617f0fcd8c..8d0081dbf4739005c7fb9c5ffc9fa7aa48806c93 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -284,8 +284,8 @@ def _level2_optimal_vectorization_strategy_generator(sumfacts, width, qp, alread yielded = False # Find the number of input coefficients we can work on - keys = frozenset(sf.input_key for sf in sumfacts) - inputkey_sumfacts = [tuple(sorted(filter(lambda sf: sf.input_key == key, sumfacts))) for key in keys] + keys = frozenset(sf.inout_key for sf in sumfacts) + inoutkey_sumfacts = [tuple(sorted(filter(lambda sf: sf.inout_key == key, sumfacts))) for key in keys] for parallel in (1, 2): if parallel == 2 and next(iter(sumfacts)).stage == 3: @@ -294,7 +294,7 @@ def _level2_optimal_vectorization_strategy_generator(sumfacts, width, qp, alread it.permutations(range(len(keys)), parallel)): horizontal = 1 while horizontal <= width // parallel: - combo = sum((inputkey_sumfacts[part][:horizontal] for part in which), ()) + combo = sum((inoutkey_sumfacts[part][:horizontal] for part in which), ()) vecdict = get_vectorization_dict(combo, width // (horizontal * parallel), horizontal * parallel, qp) horizontal *= 2