From 444ce6cabd93a297934eea435be9f2a16e5e84d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Sun, 10 Feb 2019 21:39:40 +0100 Subject: [PATCH] [skip ci][WIP] Tensor contraction reordering transformation This commit includes some ugly code that needs to be cleaned up! --- .../dune/codegen/sumfact/transformations.py | 155 ++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 python/dune/codegen/sumfact/transformations.py diff --git a/python/dune/codegen/sumfact/transformations.py b/python/dune/codegen/sumfact/transformations.py new file mode 100644 index 00000000..713d7fd9 --- /dev/null +++ b/python/dune/codegen/sumfact/transformations.py @@ -0,0 +1,155 @@ +import loopy as lp +import pymbolic.primitives as prim +import islpy as isl + +from dune.codegen.loopy.transformations.remove_reductions import remove_all_reductions +from dune.codegen.pdelab.geometry import world_dimension + +def move_zero_assignment_up(knl, move_up_inames): + # Find the instruction we want to move around + cond = lp.match.Tagged('set_zero') + instructions = lp.find_instructions(knl, cond) + move_iname_set = set(map(lambda x: prim.Variable(x), move_up_inames)) + instr = None + for i in instructions: + instr_iname_set = set(i.assignee.index_tuple) + if move_iname_set.issubset(instr_iname_set): + # There should be only one matching instruction + assert (instr==None) + instr = i + assert (instr!=None) + + # Remove it + knl = lp.remove_instructions(knl, set([instr.id])) + + # Create loop domains: In order to move it upwards we need to create + # additional loops + iname_appendix = '_move_up' + domains = knl.domains + for iname in move_up_inames: + # Find loop bound for this iname + for dom in domains: + if iname in dom.get_var_names(isl.dim_type.set): + # index = dom.get_var_names(isl.dim_type.set).index(iname) + + # TODO: Noch unklar wie man die Loop bound aus isl rausbekommt. + todo_begin = str(dom).find(iname + ' <=') + len(iname) + 4 + todo_end = todo_begin + str(dom)[todo_begin:].find(' ') + loop_bound = int(str(dom)[todo_begin:todo_end]) + 1 + break + + domain = "{{ [{0}] : 0<={0}<{1} }}".format(iname + iname_appendix, loop_bound) + domain = lp.kernel.creation.parse_domains(domain, {}) + domains = domains + domain + + # Create tuple of correct inames for usage in subscript below + indices = list(instr.assignee.index_tuple) + for i in range(len(indices)): + if indices[i].name in move_up_inames: + indices[i] = prim.Variable(indices[i].name + iname_appendix) + indices = tuple(indices) + + # The new instructions needs to lie within those inames + within_inames = [] + for i in indices: + within_inames.append(i.name) + + # Create new instruction + assignee = prim.Subscript(instr.assignee.aggregate, indices) + instructions = [] + instructions.append(instr.copy(assignee=assignee, + within_inames=frozenset(within_inames))) + knl = knl.copy(instructions=knl.instructions + instructions, + domains=domains) + + # Add dependency to inner assignment instructions + cond = lp.match.Tagged('assignment') + assignment_instructions = lp.find_instructions(knl, cond) + instr = None + for i in assignment_instructions: + instr_iname_set = set(i.assignee.index_tuple) + if move_iname_set.issubset(instr_iname_set): + # There should be only one matching instruction + assert (instr==None) + instr = i + + id_zero = instructions[0].id + cond = lp.match.Id(instr.id) + knl = lp.add_dependency(knl, cond, id_zero) + + return knl + + +def reorder_loops_in_tensor_contraction(knl, iname_order): + """Reorder the loop nest of the tensor contractions + + iname_order is a string that specifies the loop order. We use the following convention: + + Each contraction in the sum factorization kernel has the form 'ij,jkl->kli' + using einsum notation from numpy. iname_order should be a string like + 'iklj' if the loops should be done in order i, k, l, j. + + In the sum factorization kernel itself those inames are called: + + sf_out_inames_2_* : l + sf_out_inames_1_* : k + sf_out_inames_0_* : i + red_* : j + + where * represents the current direction (0,1,2 for 3D problems). + + TODO: Maybe also support a list of the inames above? + + TODO: Different order for different direction? Could make sense when we use + fastdg and a broadcast since the first contraction has a smaller input + matrix. + """ + dim = world_dimension() + # TODO: In principle there is no need to be dimension dependent. I'm just + # not sure how to pass the iname_order in the general case. This probably + # needs a rework anyway so I just do the 3D case first. + assert dim==3 + + knl = remove_all_reductions(knl) + + # TODO: Doc after rewrite + reduction_iname = 'j' + iname_dict = { 'l' : 'sf_out_inames_2', + 'k' : 'sf_out_inames_1', + 'i' : 'sf_out_inames_0', + 'j' : 'sf_red'} + reduction_index = iname_order.index(reduction_iname) + move_up_inames = list(map(lambda x: iname_dict[x], iname_order[reduction_index+1:])) + + # cond = lp.match.Tagged('set_zero') + cond = lp.match.Tagged('assignment') + instructions = lp.find_instructions(knl, cond) + for instr in instructions: + inames = tuple(map(lambda x: x.name, instr.assignee.index_tuple)) + current_move_up_inames = [] + for i in inames: + for j in move_up_inames: + if i.find(j) >= 0: + current_move_up_inames.append(i) + + knl = move_zero_assignment_up(knl, current_move_up_inames) + + # TODO + # + # Finde the number appended to the inames of this contraction by taking + # all the number starting from the last '_'. There is definitely a more + # elegant way to find that ;). + sf_iname_index = int(inames[0][len(inames[0]) - inames[0][::-1].find('_'):]) + reduction_iname = 'sf_red_{}'.format(sf_iname_index) + + prefered_iname_order = [] + for i in inames: + if i not in current_move_up_inames and i.find('vec') == -1: + prefered_iname_order.append(i) + prefered_iname_order.append(reduction_iname) + for i in current_move_up_inames: + prefered_iname_order.append(i) + prefered_iname_order = tuple(prefered_iname_order) + knl = lp.prioritize_loops(knl, prefered_iname_order) + + return knl -- GitLab