diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py index 112ee20e8637c4f7c4ebbdcd6b6ef8bebb8e6a70..306f6c756a77febb290b53fdfe220b4441dee7e1 100644 --- a/python/dune/codegen/sumfact/realization.py +++ b/python/dune/codegen/sumfact/realization.py @@ -229,14 +229,8 @@ def realize_sumfact_kernel_function(sf): if matrix.cols != 1: matprod = lp.Reduction("sum", k, matprod) - # Here we also move the new direction (out_inames[0]) to the - # end and reverse permutation + # Here we also move the new direction (out_inames[0]) to the end output_inames = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) - if l == len(matrix_sequence) - 1: - # TODO: Move permutations to interface! - output_inames = permute_backward(output_inames, sf.cost_permutation) - if sf.stage == 3: - output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation) # Collect the key word arguments for the loopy instruction insn_args = {"depends_on": insn_dep} @@ -244,15 +238,22 @@ def realize_sumfact_kernel_function(sf): # In case of direct output we directly accumulate the result # of the Sumfactorization into some global data structure. if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3: + # TODO: Move permutations to interface! + output_inames = permute_backward(output_inames, sf.cost_permutation) + output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation) + if sf.vectorized: insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name}) insn_dep = sf.interface.realize_direct(matprod, output_inames, out_shape, **insn_args) elif l == len(matrix_sequence) - 1: + # TODO: Move permutations to interface! + output_inames = permute_backward(output_inames, sf.cost_permutation) output_shape = tuple(out_shape[1:]) + (out_shape[0],) # TODO: Move permutations to interface output_shape = permute_backward(output_shape, sf.cost_permutation) if sf.stage == 3: + output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation) output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation) out = buffer.get_temporary("buff_step{}_out".format(l),