diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py index 21713c6768c0dbf3cde22b271b566ba147ce2203..ff5f18f1b13bdbaa3b5e93db6b14a074bd3545d4 100644 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ b/python/dune/perftool/loopy/transformations/collect_rotate.py @@ -289,16 +289,16 @@ def collect_vector_data_rotate(knl): ) ) - # Rotate back! - if rotating: - new_insns.append(lp.CallInstruction((), # assignees - prim.Call(prim.Variable("transpose_reg"), - tuple(prim.Subscript(prim.Variable(lhsname), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))), - depends_on=frozenset({Tagged("vec_write")}), - within_inames=common_inames.union(inames).union(frozenset({new_iname})), - within_inames_is_final=True, - id="{}_rotateback".format(lhsname), - )) + # Rotate back! + if rotating and "{}_rotateback".format(lhsname) not in [i.id for i in new_insns]: + new_insns.append(lp.CallInstruction((), # assignees + prim.Call(prim.Variable("transpose_reg"), + tuple(prim.Subscript(prim.Variable(lhsname), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))), + depends_on=frozenset({Tagged("vec_write")}), + within_inames=common_inames.union(inames).union(frozenset({new_iname})), + within_inames_is_final=True, + id="{}_rotateback".format(lhsname), + )) from loopy.kernel.creation import resolve_dependencies return resolve_dependencies(knl.copy(instructions=new_insns + other_insns)) diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py index d82ea08af1a7ad439562be9732114b098848d6c1..cba13702efbf8f8a7333dc919e48c8de87048232 100644 --- a/python/dune/perftool/sumfact/__init__.py +++ b/python/dune/perftool/sumfact/__init__.py @@ -24,10 +24,14 @@ class SumFactInterface(PDELabInterface): return pymbolic_reference_gradient(element, restriction, number) def pymbolic_trialfunction_gradient(self, element, restriction, component, visitor=None): - return pymbolic_trialfunction_gradient(element, restriction, component, visitor) + ret, indices = pymbolic_trialfunction_gradient(element, restriction, component, visitor) + visitor.indices = indices + return ret def pymbolic_trialfunction(self, element, restriction, component, visitor=None): - return pymbolic_trialfunction(element, restriction, component, visitor) + ret, indices = pymbolic_trialfunction(element, restriction, component, visitor) + visitor.indices = indices + return ret def quadrature_inames(self): return quadrature_inames() diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index f5f5ebdf613678254fa6d2ea9770f15976e94a1c..a7b3c4c98ac3c5b262ff1f7dee9b5acecba33aa9 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -75,7 +75,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): # Get the vectorization info. If this happens during the dry run, we get dummies from dune.perftool.sumfact.vectorization import get_vectorization_info - a_matrices, buffer, input, index = get_vectorization_info(a_matrices, 0) + a_matrices, buffer, input, index = get_vectorization_info(a_matrices, restriction) # Initialize the buffer for the sum fact kernel shape = (product(mat.cols for mat in a_matrices),) @@ -101,6 +101,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): 1, preferred_position=i, insn_dep=insn_dep, + restriction=restriction, ) buffers.append(var) @@ -109,9 +110,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): # with the position in the vector register. if index: assert len(visitor.indices) == 1 - indices = visitor.indices - visitor.indices = None - return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + indices) + return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + visitor.indices), None # TODO this should be quite conditional!!! for i, buf in enumerate(buffers): @@ -126,7 +125,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): forced_iname_deps_is_final=True, ) - return prim.Variable(name) + return prim.Variable(name), visitor.indices @kernel_cached @@ -140,7 +139,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor): # Get the vectorization info. If this happens during the dry run, we get dummies from dune.perftool.sumfact.vectorization import get_vectorization_info - a_matrices, buffer, input, index = get_vectorization_info(a_matrices, 0) + a_matrices, buffer, input, index = get_vectorization_info(a_matrices, restriction) # Flip flop buffers for sumfactorization shape = (product(mat.cols for mat in a_matrices),) @@ -164,6 +163,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor): preferred_position=None, insn_dep=frozenset({Writes(input)}), outshape=tuple(mat.rows for mat in a_matrices if mat.rows != 1), + restriction=restriction, ) if index: @@ -173,7 +173,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor): return prim.Subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + index - ) + ), visitor.indices @iname diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index b038e763fd7a7168d6aec101ecc397bc50639203..2b8dba96e176f107acdb8dcfc8a41c1d18544b67 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -128,7 +128,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): # Get the vectorization info. If this happens during the dry run, we get dummies from dune.perftool.sumfact.vectorization import get_vectorization_info - a_matrices, buffer, input, index = get_vectorization_info(a_matrices, restriction) + a_matrices, buffer, input, index = get_vectorization_info(a_matrices, accterm.argument.restriction) # Initialize a base storage for this buffer and get a temporay pointing to it shape = tuple(mat.cols for mat in a_matrices if mat.cols != 1) @@ -185,6 +185,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): insn_dep=insn_dep, additional_inames=frozenset(visitor.inames), preferred_position=pref_pos, + restriction=accterm.argument.restriction, ) inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices) @@ -240,7 +241,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): insn_dep = emit_sumfact_kernel(None, restriction, insn_dep) -@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s)) +@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s, kw.get("restriction", 0))) def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), additional_inames=frozenset({}), diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index d4942550ed39e7007c66bbec3066f3c8bd7ace10..37bc7f9fc30356740413dddcd56b8a8d673f240c 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -30,17 +30,16 @@ def get_vectorization_info(a_matrices, restriction): def no_vectorization(sumfacts): for sumf in sumfacts: - for res in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): - vectorization_info(sumf.a_matrices, - res, - sumf.a_matrices, - get_counted_variable("buffer"), - get_counted_variable(restricted_name("input", sumf.restriction)), - None) + vectorization_info(sumf.a_matrices, + sumf.restriction, + sumf.a_matrices, + get_counted_variable("buffer"), + get_counted_variable(restricted_name("input", sumf.restriction)), + None) -def decide_stage_vectorization_strategy(sumfacts, stage): - stage_sumfacts = frozenset([sf for sf in sumfacts if sf.stage == stage]) +def decide_stage_vectorization_strategy(sumfacts, stage, restriction): + stage_sumfacts = frozenset([sf for sf in sumfacts if sf.stage == stage and sf.restriction == restriction]) if len(stage_sumfacts) in (3, 4): # Map the sum factorization to their position in the joint kernel position_mapping = {} @@ -107,8 +106,9 @@ def decide_vectorization_strategy(): if not get_option("vectorize_grads"): no_vectorization(sumfacts) else: - decide_stage_vectorization_strategy(sumfacts, 1) - decide_stage_vectorization_strategy(sumfacts, 3) + for stage in (1, 3): + for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): + decide_stage_vectorization_strategy(sumfacts, stage, restriction) class HasSumfactMapper(lp.symbolic.CombineMapper):