diff --git a/python/dune/codegen/loopy/transformations/remove_reductions.py b/python/dune/codegen/loopy/transformations/remove_reductions.py index c7848a315a4a2b1e48b5368ea3a6416f7e3ed13b..f53816126686fc8407c468503b4dd59694faab19 100644 --- a/python/dune/codegen/loopy/transformations/remove_reductions.py +++ b/python/dune/codegen/loopy/transformations/remove_reductions.py @@ -39,7 +39,7 @@ def remove_reduction(knl, match): within_inames=within_inames, id=id_accum, depends_on=frozenset((id_zero,) + tuple(depends_on)), - tags=('assignement',))) + tags=('assignment',))) knl = knl.copy(instructions=knl.instructions + instructions) diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py index ae4a763063f0dc25303fc517356b58b02551eebc..09160ea5b0fcdcbad23621ff6a8d8c66267df666 100644 --- a/python/dune/codegen/sumfact/accumulation.py +++ b/python/dune/codegen/sumfact/accumulation.py @@ -547,6 +547,9 @@ def generate_accumulation_instruction(expr, visitor): from dune.codegen.sumfact.vectorization import attach_vectorization_info vsf = attach_vectorization_info(sf) + from dune.codegen.sumfact.transformations import attach_transformations + vsf = attach_transformations(sf, vsf) + # Make sure we have a buffer that we can set up the input with buffer = vsf.buffer if buffer is None: diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py index c9b75eb445af01e9420e850650b12054dd0115f3..d3a2a190255ad2178fc44b5a691f49beb919b706 100644 --- a/python/dune/codegen/sumfact/basis.py +++ b/python/dune/codegen/sumfact/basis.py @@ -220,6 +220,9 @@ class SumfactBasisMixin(GenericBasisMixin): from dune.codegen.sumfact.vectorization import attach_vectorization_info vsf = attach_vectorization_info(sf) + from dune.codegen.sumfact.transformations import attach_transformations + vsf = attach_transformations(sf, vsf) + self.indices = None # If this sum factorization kernel was not used in the dry run we diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py index d37558f5c40f42355b5fa99ed9a4ea25ab7ade28..689c2b60c30c06b821972b250627f19bd608095c 100644 --- a/python/dune/codegen/sumfact/geometry.py +++ b/python/dune/codegen/sumfact/geometry.py @@ -224,6 +224,9 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin): from dune.codegen.sumfact.vectorization import attach_vectorization_info vsf = attach_vectorization_info(sf) + from dune.codegen.sumfact.transformations import attach_transformations + vsf = attach_transformations(sf, vsf) + # If this sum factorization kernel was not used in the dry run we # just return 0 if vsf == 0: @@ -541,6 +544,9 @@ def _name_jacobian(i, j, restriction, visitor): from dune.codegen.sumfact.vectorization import attach_vectorization_info vsf = attach_vectorization_info(sf) + from dune.codegen.sumfact.transformations import attach_transformations + vsf = attach_transformations(sf, vsf) + # If this sum factorization kernel was not used in the dry run we # just return 0 if vsf == 0: diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py index 5317268fdad3afb7dbcc406d17e7e145a24c41dd..3a628576bf5a130fe359bfb78967018a5f8a5211 100644 --- a/python/dune/codegen/sumfact/realization.py +++ b/python/dune/codegen/sumfact/realization.py @@ -30,7 +30,6 @@ from dune.codegen.sumfact.quadrature import quadrature_points_per_direction from dune.codegen.sumfact.symbolic import (SumfactKernel, VectorizedSumfactKernel, ) -from dune.codegen.sumfact.vectorization import attach_vectorization_info from dune.codegen.sumfact.accumulation import sumfact_iname from dune.codegen.loopy.target import dtype_floatingpoint from dune.codegen.loopy.vcl import ExplicitVCLCast @@ -279,8 +278,10 @@ def realize_sumfact_kernel_function(sf): }) # Register kernel transformations - from dune.codegen.sumfact.transformations import reorder_loops_in_tensor_contraction - transform(reorder_loops_in_tensor_contraction, 'lkji') + for trafo in sf.transformations: + transform(trafo.kernel_transformation()[0], + trafo.kernel_transformation()[1], + **trafo.kernel_transformation()[2]) # Construct a loopy kernel object from dune.codegen.pdelab.localoperator import extract_kernel_from_cache diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 60641b0ce4af4ddef525a4c03a438c4ebe277c0c..3fddff6297e42cf807fe1bb0017b9b3dce7aa668 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -473,6 +473,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): insn_dep=frozenset(), interface=SumfactKernelInterfaceBase(), predicates=frozenset(), + transformations=(), ): """Create a sum factorization kernel @@ -503,6 +504,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): the transformation brings the next reduction index in the fastest position. + In einsum notation from numpy this can be written as three contractions + of the form: 'ij,jkl->kli' + It can make sense to permute the order of directions. If you have a small m_l (e.g. stage 1 on faces) it is better to do direction l first. This can be done by: @@ -533,6 +537,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): other. interface: An SumfactKernelInterfaceBase instance describing the input (stage 1) or output (stage 3) of the kernel + """ # Assert the inputs! assert isinstance(matrix_sequence, tuple) @@ -595,6 +600,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): name_quad_perm = "_qpperm_{}".format("".join(str(a) for a in self.interface.quadrature_permutation)) name = name + name_quad_perm + # Change name for applied transformations + for t in self.transformations: + name = name + '_' + t.name_appendix() + return name @property @@ -828,6 +837,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) vertical_width=1, buffer=None, insn_dep=frozenset(), + transformations=(), ): # Assert the input data structure assert isinstance(kernels, tuple) @@ -854,6 +864,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) buffer=buffer, insn_dep=insn_dep, vertical_width=vertical_width, + transformations=transformations, ) prim.Variable.__init__(self, "VecSUMFAC") @@ -880,9 +891,15 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) # @property def function_name(self): - return "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence_quadrature_permuted), + name = "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence_quadrature_permuted), self.interface.function_name_suffix) + # Change name for applied transformations + for t in self.transformations: + name = name + '_' + t.name_appendix() + + return name + @property def cache_key(self): """ The cache key that can be used in generation magic diff --git a/python/dune/codegen/sumfact/transformations.py b/python/dune/codegen/sumfact/transformations.py index 713d7fd991f23bee7d05d788db525a492de610a9..76cf241c90d6a85d6276b29d0bd670a63bedfd82 100644 --- a/python/dune/codegen/sumfact/transformations.py +++ b/python/dune/codegen/sumfact/transformations.py @@ -2,10 +2,14 @@ import loopy as lp import pymbolic.primitives as prim import islpy as isl +from dune.codegen.generation import get_global_context_value from dune.codegen.loopy.transformations.remove_reductions import remove_all_reductions from dune.codegen.pdelab.geometry import world_dimension def move_zero_assignment_up(knl, move_up_inames): + if len(move_up_inames) == 0: + return knl + # Find the instruction we want to move around cond = lp.match.Tagged('set_zero') instructions = lp.find_instructions(knl, cond) @@ -33,7 +37,9 @@ def move_zero_assignment_up(knl, move_up_inames): # index = dom.get_var_names(isl.dim_type.set).index(iname) # TODO: Noch unklar wie man die Loop bound aus isl rausbekommt. - todo_begin = str(dom).find(iname + ' <=') + len(iname) + 4 + todo_begin = str(dom).find(iname + ' =') + len(iname) + 3 + if todo_begin == len(iname) + 3 - 1: + todo_begin = str(dom).find(iname + ' <=') + len(iname) + 4 todo_end = todo_begin + str(dom)[todo_begin:].find(' ') loop_bound = int(str(dom)[todo_begin:todo_end]) + 1 break @@ -134,13 +140,16 @@ def reorder_loops_in_tensor_contraction(knl, iname_order): knl = move_zero_assignment_up(knl, current_move_up_inames) - # TODO - # - # Finde the number appended to the inames of this contraction by taking - # all the number starting from the last '_'. There is definitely a more - # elegant way to find that ;). - sf_iname_index = int(inames[0][len(inames[0]) - inames[0][::-1].find('_'):]) - reduction_iname = 'sf_red_{}'.format(sf_iname_index) + # TODO: There should be a better method than searching the string for + # 'sf_red'. Unfortunately there are sometimes Call instructions due to + # broadcasts. That makes different ways difficult. + import re + regex = re.compile('sf_red_([0-9]*)') + reduction_index = set(regex.findall(str(instr))) + assert len(reduction_index) == 1 + reduction_index = reduction_index.pop() + reduction_iname = 'sf_red_{}'.format(reduction_index) + prefered_iname_order = [] for i in inames: @@ -153,3 +162,46 @@ def reorder_loops_in_tensor_contraction(knl, iname_order): knl = lp.prioritize_loops(knl, prefered_iname_order) return knl + + +class SumfactKernelFunctionTransformation(object): + def kernel_transformation(self): + """Transformation that will be applied to sumfact kernel function""" + raise NotImplementedError + + def name_appendix(self): + """Name will be appended to name of sumfact kernel function""" + raise NotImplementedError + + +class LoopOrderTransformation(SumfactKernelFunctionTransformation): + def __init__(self, order): + self.order = order + + def kernel_transformation(self): + return (reorder_loops_in_tensor_contraction, self.order, {}) + + def name_appendix(self): + return 'looporder{}'.format(self.order) + + +def attach_transformations(sf, vsf): + if vsf == 0: + return 0 + + if get_global_context_value("dry_run") == None: + + # Example and test of such a transformation: + + # # Store transformation in sumfact kernel + # from dune.codegen.sumfact.transformations import LoopOrderTransformation + # trafo = LoopOrderTransformation('kjli') + # vsf = vsf.copy(transformations=vsf.transformations + (trafo,)) + + # # Map the original kernel to the new one + # from dune.codegen.sumfact.vectorization import _cache_vectorization_info + # _cache_vectorization_info(sf, vsf) + + return vsf + + return sf