From e16b2619341e1505e0746f9553be73e2ad0dc01b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Thu, 15 Nov 2018 21:07:51 +0100 Subject: [PATCH] [skip ci][wip] Move input permutation to interface classes So far only for fastdg. This should also happen in the non-fastdg case. --- python/dune/codegen/sumfact/basis.py | 14 ++++++++- python/dune/codegen/sumfact/realization.py | 24 +++++++--------- python/dune/codegen/sumfact/symbolic.py | 33 +++++++++++++++++++--- 3 files changed, 52 insertions(+), 19 deletions(-) diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py index 3ce2fb24..ccc34406 100644 --- a/python/dune/codegen/sumfact/basis.py +++ b/python/dune/codegen/sumfact/basis.py @@ -24,7 +24,8 @@ from dune.codegen.sumfact.tabulation import (basis_functions_per_direction, name_polynomials, polynomial_degree, ) -from dune.codegen.sumfact.permutation import (permute_forward, +from dune.codegen.sumfact.permutation import (permute_backward, + permute_forward, sumfact_cost_permutation_strategy, sumfact_quadrature_permutation_strategy, ) @@ -144,6 +145,17 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): return prim.Subscript(prim.Variable(arg), inames) + def realize_input(self, shape, inames, which=0): + if self.direct_is_possible: + shape = permute_backward(shape, self.cost_permutation) + shape = permute_backward(shape, self.quadrature_permutation) + inames = permute_backward(inames, self.cost_permutation) + inames = permute_backward(inames, self.quadrature_permutation) + + return self.realize_direct(shape, inames) + else: + raise NotImplementedError("TODO") + @property def function_name_suffix(self): if get_form_option("fastdg"): diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py index 1d951eca..b9d209c3 100644 --- a/python/dune/codegen/sumfact/realization.py +++ b/python/dune/codegen/sumfact/realization.py @@ -160,8 +160,12 @@ def realize_sumfact_kernel_function(sf): for l, matrix in enumerate(matrix_sequence): # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication # and get inames that realize the product. - inp_shape = (matrix.cols,) + tuple(mat.cols for mat in matrix_sequence[l + 1:]) + tuple(mat.rows for mat in matrix_sequence[:l]) - out_shape = (matrix.rows,) + tuple(mat.cols for mat in matrix_sequence[l + 1:]) + tuple(mat.rows for mat in matrix_sequence[:l]) + inp_shape = (matrix.cols,) \ + + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ + + tuple(mat.rows for mat in matrix_sequence[:l]) + out_shape = (matrix.rows,) \ + + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \ + + tuple(mat.rows for mat in matrix_sequence[:l]) out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape)) vec_iname = () if matrix.vectorized: @@ -173,11 +177,11 @@ def realize_sumfact_kernel_function(sf): # a code generation corner case producing way too complicated code. This # could be fixed upstream, but the loopy code realizing reductions is not # trivial and the priority is kind of low. - if matrix.cols != 1: + if matrix.cols == 1: + k_expr = 0 + else: k = sumfact_iname(matrix.cols, "red") k_expr = prim.Variable(k) - else: - k_expr = 0 # Setup the input of the sum factorization kernel. In the # first matrix multiplication this can be taken from @@ -187,15 +191,7 @@ def realize_sumfact_kernel_function(sf): # (vectorized + FastDGGridOperator) input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:]) if l == 0 and sf.stage == 1 and sf.interface.direct_is_possible: - # One permutation for cost reduction, see comment below - inp_shape = permute_backward(inp_shape, sf.cost_permutation) - input_inames = permute_backward(input_inames, sf.cost_permutation) - - # And one more for permuted quadrature points, see comment below - inp_shape = permute_backward(inp_shape, sf.interface.quadrature_permutation) - input_inames = permute_backward(input_inames, sf.interface.quadrature_permutation) - - input_summand = sf.interface.realize_direct(inp_shape, input_inames) + input_summand = sf.interface.realize_input(inp_shape, input_inames) else: # If we did permute the order of a matrices above we also # permuted the order of out_inames. Unfortunately the diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 82f7e027..0cafec86 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -40,11 +40,11 @@ class SumfactKernelInterfaceBase(object): @property def quadrature_permutation(self): - return () + raise NotImplementedError @property def cost_permutation(self): - return () + raise NotImplementedError @property def combined_permutation(self): @@ -63,12 +63,12 @@ class SumfactKernelInterfaceBase(object): def permute_forward_cost(self, shape, inames): shape = permute_forward(shape, self.cost_permutation) inames = permute_forward(inames, self.cost_permutation) - return shape_inames + return shape, inames def permute_forward_quadrature(self, shape, inames): shape = permute_forward(shape, self.quadrature_permutation) inames = permute_forward(inames, self.quadrature_permutation) - return shape_inames + return shape, inames @property def within_inames(self): @@ -113,6 +113,20 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation return self.interfaces[0].quadrature_permutation + @property + def cost_permutation(self): + # The cost_permutation of the underlying scalar SumfactKernel can be + # different for each kernel. + + # TODO! + cost_permutation = self.interfaces[0].cost_permutation + for i in self.interfaces: + assert i.cost_permutation == cost_permutation + + return cost_permutation + + # raise RuntimeError("cost_permutation should not be called on VectorSumfactKernelInput") + @property def stage(self): return 1 @@ -153,6 +167,17 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): # need to load scalars into the SIMD vector. raise NotImplementedError("SIMD loads from scalars not implemented!") + def realize_input(self, shape, inames): + if self.direct_is_possible: + shape = permute_backward(shape, self.cost_permutation) + shape = permute_backward(shape, self.quadrature_permutation) + inames = permute_backward(inames, self.cost_permutation) + inames = permute_backward(inames, self.quadrature_permutation) + + return self.realize_direct(shape, inames) + else: + raise NotImplementedError("TODO") + @property def function_args(self): return sum((i.function_args for i in remove_duplicates(self.interfaces)), ()) -- GitLab