From 10b3bbde507400ecd1a72833163059ad15969ba4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Thu, 15 Nov 2018 21:25:03 +0100 Subject: [PATCH] [skip ci][wip] Move around things in realize_sumfact_kernel_function --- python/dune/codegen/sumfact/realization.py | 48 +++++++++++----------- python/dune/codegen/sumfact/symbolic.py | 3 +- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py index b9d209c3..7207dd38 100644 --- a/python/dune/codegen/sumfact/realization.py +++ b/python/dune/codegen/sumfact/realization.py @@ -161,11 +161,11 @@ def realize_sumfact_kernel_function(sf): # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication # and get inames that realize the product. inp_shape = (matrix.cols,) \ - + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ - + tuple(mat.rows for mat in matrix_sequence[:l]) + + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ + + tuple(mat.rows for mat in matrix_sequence[:l]) out_shape = (matrix.rows,) \ - + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ - + tuple(mat.rows for mat in matrix_sequence[:l]) + + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ + + tuple(mat.rows for mat in matrix_sequence[:l]) out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape)) vec_iname = () if matrix.vectorized: @@ -225,26 +225,6 @@ def realize_sumfact_kernel_function(sf): buffer.switch() - # Get a temporary that interprets the base storage of the output. - # - # Note: In this step the reordering of the fastest directions - # is happening. The new direction (out_inames[0]) and the - # corresponding shape (out_shape[0]) goes to the end (slowest - # direction) and everything stays column major (ftags->fortran - # style). - # - # If we are in the last step we reverse the permutation. - output_shape = tuple(out_shape[1:]) + (out_shape[0],) - if l == len(matrix_sequence) - 1: - output_shape = permute_backward(output_shape, sf.cost_permutation) - if sf.stage == 3: - output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation) - - out = buffer.get_temporary("buff_step{}_out".format(l), - shape=output_shape + vec_shape, - dim_tags=ftags, - ) - # Write the matrix-matrix multiplication expression matprod = prim.Product((matrix.pymbolic((prim.Variable(out_inames[0]), k_expr) + vec_iname), input_summand)) @@ -271,6 +251,26 @@ def realize_sumfact_kernel_function(sf): insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name}) insn_dep = sf.interface.realize_direct(matprod, output_inames, out_shape, **insn_args) else: + # Get a temporary that interprets the base storage of the output. + # + # Note: In this step the reordering of the fastest directions + # is happening. The new direction (out_inames[0]) and the + # corresponding shape (out_shape[0]) goes to the end (slowest + # direction) and everything stays column major (ftags->fortran + # style). + # + # If we are in the last step we reverse the permutation. + output_shape = tuple(out_shape[1:]) + (out_shape[0],) + if l == len(matrix_sequence) - 1: + output_shape = permute_backward(output_shape, sf.cost_permutation) + if sf.stage == 3: + output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation) + + out = buffer.get_temporary("buff_step{}_out".format(l), + shape=output_shape + vec_shape, + dim_tags=ftags, + ) + # Issue the reduction instruction that implements the multiplication # at the same time store the instruction ID for the next instruction to depend on insn_dep = frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), output_inames + vec_iname), diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 0cafec86..0b303c64 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -117,6 +117,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): def cost_permutation(self): # The cost_permutation of the underlying scalar SumfactKernel can be # different for each kernel. + # raise RuntimeError("cost_permutation should not be called on VectorSumfactKernelInput") # TODO! cost_permutation = self.interfaces[0].cost_permutation @@ -125,8 +126,6 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): return cost_permutation - # raise RuntimeError("cost_permutation should not be called on VectorSumfactKernelInput") - @property def stage(self): return 1 -- GitLab