diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py
index a26bae5bfc61717f06fe1bbd4d0abc99f28240cf..8ab14a7d2f060bf0898bc76d58c1d3efc1e953c1 100644
--- a/python/dune/codegen/sumfact/basis.py
+++ b/python/dune/codegen/sumfact/basis.py
@@ -193,7 +193,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
 
         return prim.Subscript(prim.Variable(inp), inames + vec_iname)
 
-    def realize_direct_input(self, shape, inames, which=0):
+    def realize_direct_input(self, inames, shape, which=0):
         # If the input comes directly from a global data structure inames are
         # ordered x,y,z,...
         #
diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py
index 010118b10a1c2f1d68109f16b8f5471d03da9224..a9bd0b8baddfbc7c269d92e360bb8ad3bd194910 100644
--- a/python/dune/codegen/sumfact/realization.py
+++ b/python/dune/codegen/sumfact/realization.py
@@ -192,7 +192,7 @@ def realize_sumfact_kernel_function(sf):
         input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
 
         if l == 0 and sf.stage == 1 and sf.interface.direct_is_possible:
-            input_summand = sf.interface.realize_direct_input(inp_shape, input_inames)
+            input_summand = sf.interface.realize_direct_input(input_inames, inp_shape)
         elif l == 0:
             # TODO: Simplify arguments!
             input_summand = sf.interface.realize_input(input_inames,
@@ -241,17 +241,30 @@ def realize_sumfact_kernel_function(sf):
                 insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name})
             insn_dep = sf.interface.realize_direct_output(matprod, output_inames, out_shape, **insn_args)
         elif l == len(matrix_sequence) - 1:
+            # Handle output of the last tensor contraction
+            #
+            # Stage 1: Reverse cost permutation, keep quadrature permutation
+            # Stage 3: Reverse cost and quadrature permutation
             output_shape = tuple(out_shape[1:]) + (out_shape[0],)
-            insn_dep = sf.interface.realize_output(matprod,
-                                                   output_inames,
-                                                   output_shape,
-                                                   vec_iname,
-                                                   vec_shape,
-                                                   buffer,
-                                                   ftags,
-                                                   l,
-                                                   **insn_args,
-                                                   )
+            output_inames = permute_backward(output_inames, sf.interface.cost_permutation)
+            output_shape = permute_backward(output_shape, sf.interface.cost_permutation)
+            if sf.interface.stage == 3:
+                output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation)
+                output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation)
+
+            out = buffer.get_temporary("buff_step{}_out".format(l),
+                                       shape=output_shape + vec_shape,
+                                       dim_tags=ftags,
+                                       )
+
+            # Issue the reduction instruction that implements the multiplication
+            # at the same time store the instruction ID for the next instruction to depend on
+            insn_dep = frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), output_inames + vec_iname),
+                                              expression=matprod,
+                                              **insn_args,
+                                              )
+                                  })
+
         else:
             output_shape = tuple(out_shape[1:]) + (out_shape[0],)
             out = buffer.get_temporary("buff_step{}_out".format(l),
diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py
index 96db55bb2f83744d56617cd1af028aaf98f41afd..3f4e91f4ce194a1d354dc98a09e4b56afffff270 100644
--- a/python/dune/codegen/sumfact/symbolic.py
+++ b/python/dune/codegen/sumfact/symbolic.py
@@ -47,7 +47,6 @@ class SumfactKernelInterfaceBase(object):
             Instructions this setup depends on.
         index : int
            Vectorization index, SIMD lane.
-
         """
         raise NotImplementedError
 
@@ -56,14 +55,16 @@ class SumfactKernelInterfaceBase(object):
 
         This happens inside the sumfact kernel function.
 
-        Stage 1: Input is already permuted the right way in setup_input.
+        Stage 1 : Input is already permuted the right way in setup_input.
 
-        Stage 3: TODO -> Check permutation in accumulation.py
+        Stage 3 : TODO -> Check permutation in accumulation.py
 
         Parameters
         ----------
         inames : tuple of pymbolic.primitives.Variable
+            Inames for accessing the input. Ordered according to the permuted matrix sequence.
         shape : tuple of int
+            Shape of the input. Ordered according to the permuted matrix sequence.
         vec_iname : tuple of pymbolic.primitives.Variable
             In case of vectorized kernel provide vectorization iname.
         vec_shape : tuple of int
@@ -75,66 +76,66 @@ class SumfactKernelInterfaceBase(object):
         """
         raise NotImplementedError
 
-    def realize_direct_input(self, shape, inames, which=0):
+    def realize_direct_input(self, inames, shape, which=0):
         """Interpret the input of sumfact kernel function in the right way (fastdg)
 
         This happens inside the sumfact kernel function.
 
-        TODO: Add note about permutation
-        TODO: Document input arguments
+        Stage 1: The input to the sum factorization kernel will be ordered x,
+        y, z,... The shape and inames from this method come from the cost
+        permuted matrix sequence. Make sure to permute them back when accessing
+        the input.
+
+        Parameters
+        ----------
+        inames : tuple of pymbolic.primitives.Variable
+            Inames for accessing the input. Ordered according to the permuted matrix sequence.
+        shape : tuple of int
+            Shape of the input. Ordered according to the permuted matrix sequence.
+        which : int
+            In case of a VectorizedSumfactKernel this might specify if the lower or upper
+            part of the SIMD register is for this input.
         """
         raise NotImplementedError
 
     def accumulate_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
         """Generate accumulate instruction after a stage 3 sumfact kernel function (non fastdg)
 
-        This happens after the function call.
+        This happens after the function call. After stage 2 the result should
+        be ordered x, y, z,..., no permutations necessary.
 
-        TODO: Add note about permutation
-        TODO: Document input arguments
+        Parameters
+        ----------
+        sf : SumfactKernel or VectorizedSumfactKernel
+        result : SumfactKernel or some pymbolic stuff
+            Result of a sum factorization
+        insn_dep : frozenset
+            Instructions this setup depends on.
+        inames : tuple of pymbolic.primitives.Variable
+        additional_inames : tuple of pymbolic.primitives.Variable
+            Additional inames the accumulation instruction depends on (e.g. loop over
+            ansatz functions for jacobians).
         """
         raise NotImplementedError
 
-    def realize_output(self, result, inames, shape, vec_iname, vec_shape, buffer, ftags, l, **args):
-        """Handle the output of the last tensor contraction in the sumfact kernel function the right way
-
-        This happens inside the sumfact kernel function.
-
-        Stage 1: Reverse cost permutation, output should only be quadrature
-        permuted.
-
-        Stage 3: Reverse cost and quadrature permutation. The output will be
-        sorted according to dof/residual vector.
-
-        TODO: Cleanup arguments
-        TODO: Document input arguments
-        """
-        inames = permute_backward(inames, self.cost_permutation)
-        shape = permute_backward(shape, self.cost_permutation)
-        if self.stage == 3:
-            inames = permute_backward(inames, self.quadrature_permutation)
-            shape = permute_backward(shape, self.quadrature_permutation)
-
-        out = buffer.get_temporary("buff_step{}_out".format(l),
-                                   shape=shape + vec_shape,
-                                   dim_tags=ftags,
-                                   )
-
-        # Issue the reduction instruction that implements the multiplication
-        # at the same time store the instruction ID for the next instruction to depend on
-        return frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), inames + vec_iname),
-                                      expression=result,
-                                      **args
-                                      )
-                          })
-
-    def realize_direct_output(self, result, iname, shape, which=0, **args):
+    def realize_direct_output(self, result, iname, shape, which=0, **kwargs):
         """Accumulate results directly in the sumfact kernel function (fastdg)
 
         This happens inside the sumfact kernel function.
 
         TODO: Add note about permutation
         TODO: Document input arguments
+
+        Parameters
+        ----------
+        result : pymbolic stuff
+            Result of the sum factorization
+        iname : tuple of pymbolic.primitives.Variable
+        shape : tuple of ints
+        which : int
+            TODO Doc me!
+        **kwargs :
+            Keyword arguments passed to the loopy instruction
         """
         raise NotImplementedError
 
@@ -265,7 +266,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
 
         return prim.Subscript(prim.Variable(inp), inames + vec_iname)
 
-    def realize_direct_input(self, shape, inames):
+    def realize_direct_input(self, inames, shape):
         # TODO: vector_cost_permutation not used!
 
         # Check whether the input exhibits a favorable structure
@@ -278,15 +279,15 @@
             # All input coefficients use the exact same input coefficient.
             # We implement this by broadcasting it into a SIMD register
             return prim.Call(ExplicitVCLCast(dtype_floatingpoint()),
-                             (self.interfaces[0].realize_direct_input(shape, inames),)
+                             (self.interfaces[0].realize_direct_input(inames, shape),)
                              )
         elif len(total) == 2 and len(lower) == 1 and len(upper) == 1:
             # The lower and the upper part of the SIMD register use
             # the same input coefficient, we combine the SIMD register
             # from two shorter SIMD types
             return prim.Call(VCLLowerUpperLoad(dtype_floatingpoint()),
-                             (self.interfaces[0].realize_direct_input(shape, inames),
-                              self.interfaces[len(self.interfaces) // 2].realize_direct_input(shape, inames, which=1),
+                             (self.interfaces[0].realize_direct_input(inames, shape),
+                              self.interfaces[len(self.interfaces) // 2].realize_direct_input(inames, shape, which=1),
                               )
                              )
         else:
@@ -1025,9 +1026,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
             # that derivatives are the same.
             from copy import deepcopy
             sf_interface = deepcopy(sf.interface)
-            sf_interface._cost_permutation=None
+            sf_interface._cost_permutation = None
             k_interface = deepcopy(k.interface)
-            k_interface._cost_permutation=None
+            k_interface._cost_permutation = None
             if repr(sf_interface) == repr(k_interface):
                 if tuple(mat.derivative for mat in sf.matrix_sequence_quadrature_permuted) == tuple(mat.derivative for mat in k.matrix_sequence_quadrature_permuted):
                     return i
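
Note on the permutation handling above: the code inlined into realize_sumfact_kernel_function undoes the cost permutation (and, for stage 3, additionally the quadrature permutation) so that the output buffer is indexed in x, y, z order again. The standalone sketch below only illustrates the inverse relationship this relies on; the permute_forward/permute_backward definitions and the example permutations are illustrative assumptions, not the dune-codegen implementations.

# Illustrative sketch only: toy stand-ins for the permutation helpers used in
# realization.py. The real helpers live in dune-codegen; these definitions are
# assumptions that merely capture the forward/backward inverse relationship.

def permute_forward(t, perm):
    # Position i of the result holds t[perm[i]].
    return tuple(t[i] for i in perm)


def permute_backward(t, perm):
    # Inverse of permute_forward: scatter t[i] back to position perm[i].
    out = [None] * len(t)
    for i, p in enumerate(perm):
        out[p] = t[i]
    return tuple(out)


if __name__ == "__main__":
    shape = (4, 5, 6)       # x, y, z ordering
    cost_perm = (2, 0, 1)   # hypothetical cost permutation
    quad_perm = (1, 0, 2)   # hypothetical quadrature permutation

    # A stage 3 output arrives quadrature- and cost-permuted ...
    permuted = permute_forward(permute_forward(shape, quad_perm), cost_perm)

    # ... so the kernel reverses the cost permutation first and then the
    # quadrature permutation, recovering the x, y, z ordering.
    restored = permute_backward(permute_backward(permuted, cost_perm), quad_perm)
    assert restored == shape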