diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py index b9d209c31ff244f3aaa51e74b2dcb47e6d168d82..7207dd3829a4fc4e192278b2b43b5c9cbeaae8f0 100644 --- a/python/dune/codegen/sumfact/realization.py +++ b/python/dune/codegen/sumfact/realization.py @@ -161,11 +161,11 @@ def realize_sumfact_kernel_function(sf): # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication # and get inames that realize the product. inp_shape = (matrix.cols,) \ - + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ - + tuple(mat.rows for mat in matrix_sequence[:l]) + + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ + + tuple(mat.rows for mat in matrix_sequence[:l]) out_shape = (matrix.rows,) \ - + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ - + tuple(mat.rows for mat in matrix_sequence[:l]) + + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ + + tuple(mat.rows for mat in matrix_sequence[:l]) out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape)) vec_iname = () if matrix.vectorized: @@ -225,26 +225,6 @@ def realize_sumfact_kernel_function(sf): buffer.switch() - # Get a temporary that interprets the base storage of the output. - # - # Note: In this step the reordering of the fastest directions - # is happening. The new direction (out_inames[0]) and the - # corresponding shape (out_shape[0]) goes to the end (slowest - # direction) and everything stays column major (ftags->fortran - # style). - # - # If we are in the last step we reverse the permutation. - output_shape = tuple(out_shape[1:]) + (out_shape[0],) - if l == len(matrix_sequence) - 1: - output_shape = permute_backward(output_shape, sf.cost_permutation) - if sf.stage == 3: - output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation) - - out = buffer.get_temporary("buff_step{}_out".format(l), - shape=output_shape + vec_shape, - dim_tags=ftags, - ) - # Write the matrix-matrix multiplication expression matprod = prim.Product((matrix.pymbolic((prim.Variable(out_inames[0]), k_expr) + vec_iname), input_summand)) @@ -271,6 +251,26 @@ def realize_sumfact_kernel_function(sf): insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name}) insn_dep = sf.interface.realize_direct(matprod, output_inames, out_shape, **insn_args) else: + # Get a temporary that interprets the base storage of the output. + # + # Note: In this step the reordering of the fastest directions + # is happening. The new direction (out_inames[0]) and the + # corresponding shape (out_shape[0]) goes to the end (slowest + # direction) and everything stays column major (ftags->fortran + # style). + # + # If we are in the last step we reverse the permutation. + output_shape = tuple(out_shape[1:]) + (out_shape[0],) + if l == len(matrix_sequence) - 1: + output_shape = permute_backward(output_shape, sf.cost_permutation) + if sf.stage == 3: + output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation) + + out = buffer.get_temporary("buff_step{}_out".format(l), + shape=output_shape + vec_shape, + dim_tags=ftags, + ) + # Issue the reduction instruction that implements the multiplication # at the same time store the instruction ID for the next instruction to depend on insn_dep = frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), output_inames + vec_iname), diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 0cafec8644a8f0a45e6ce08926feb85a73ce2ac6..0b303c640952a98d40278e3b76fe6958e44a5d91 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -117,6 +117,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): def cost_permutation(self): # The cost_permutation of the underlying scalar SumfactKernel can be # different for each kernel. + # raise RuntimeError("cost_permutation should not be called on VectorSumfactKernelInput") # TODO! cost_permutation = self.interfaces[0].cost_permutation @@ -125,8 +126,6 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): return cost_permutation - # raise RuntimeError("cost_permutation should not be called on VectorSumfactKernelInput") - @property def stage(self): return 1