diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py deleted file mode 100644 index 167de058f85b773978d4d6cef101101df72bf840..0000000000000000000000000000000000000000 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ /dev/null @@ -1,399 +0,0 @@ -""" A kernel transformation that precomputes quantities until a vector register -is filled and then does vector computations """ - -from dune.perftool.generation import (function_mangler, - include_file, - loopy_class_member, - ) -from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size -from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view, - add_vector_view, - get_vector_view_name, - ) -from dune.perftool.loopy.symbolic import substitute -from dune.perftool.sumfact.quadrature import quadrature_inames -from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag, ceildiv -from dune.perftool.options import get_option - -from loopy.kernel.creation import parse_domains -from loopy.symbolic import pw_aff_to_expr -from loopy.match import Tagged - -from loopy.symbolic import DependencyMapper -from pytools import product - -import pymbolic.primitives as prim -import loopy as lp -import numpy as np -import re - - -class TransposeReg(lp.symbolic.FunctionIdentifier): - def __init__(self, - horizontal=1, - vertical=1, - ): - self.horizontal = horizontal - self.vertical = vertical - - def __getinitargs__(self): - return (self.horizontal, self.vertical) - - @property - def name(self): - return "transpose_reg" - - -@function_mangler -def rotate_function_mangler(knl, func, arg_dtypes): - if isinstance(func, TransposeReg): - # This is not 100% within the loopy philosophy, as we are - # passing the vector registers as references and have them - # changed. Loopy assumes this function to be read-only. - include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile") - vcl = lp.types.NumpyType(get_vcl_type(np.float64, vector_width=func.horizontal * func.vertical)) - return lp.CallMangleInfo(func.name, (), (vcl,) * func.horizontal) - - -class VectorIndices(object): - def __init__(self): - self.needed = set() - - def get(self, increment): - name = "vec_index_inc{}".format(increment) - self.needed.add((name, increment)) - return prim.Variable(name) - - -def collect_vector_data_rotate(knl): - # - # Process/Assert/Standardize the input - # - insns = [i for i in lp.find_instructions(knl, lp.match.Tagged("quadvec"))] - if not insns: - return knl - inames = quadrature_inames() - - # Analyse the inames of the given instructions and identify inames - # that they all have in common. Those inames will also be iname dependencies - # of inserted instructions. - common_inames = frozenset([]).union(*(insn.within_inames for insn in insns)) - frozenset(inames) - - # Determine the vector lane width - # TODO infer the numpy type here - vec_size = get_vcl_type_size(np.float64) - vector_indices = VectorIndices() - - # Add an iname to the kernel which will be used for vectorization - new_iname = "quad_vec_{}".format("_".join(inames)) - domain = "{{ [{0}] : 0<={0}<{1} }}".format(new_iname, str(vec_size)) - domain = parse_domains(domain, {}) - knl = knl.copy(domains=knl.domains + domain) - knl = lp.tag_inames(knl, [(new_iname, "vec")]) - - new_insns = [] - all_writers = [] - - # - # Inspect the given instructions for dependent quantities - # - - quantities = {} - for insn in insns: - for expr in DependencyMapper()(insn.expression): - basename = get_pymbolic_basename(expr) - quantities.setdefault(basename, frozenset()) - quantities[basename] = quantities[basename].union(frozenset([expr])) - - # Add vector size buffers for all these quantities - replacemap_vec = {} - replacemap_arr = {} - for quantity in quantities: - quantity_exprs = quantities[quantity] - - # Check whether there is an instruction that writes this quantity within - # the given inames. If so, we need a buffer array. - iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames)) - write_match = lp.match.Writes(quantity) - match = lp.match.And((iname_match, write_match)) - write_insns = lp.find_instructions(knl, match) - all_writers.extend([i.id for i in write_insns]) - - if write_insns: - # Determine the shape of the quantity - shape = knl.temporary_variables[quantity].shape - - arrname = quantity + '_buffered_arr' - knl = add_temporary_with_vector_view(knl, - arrname, - dtype=np.float64, - shape=shape + (vec_size,), - dim_tags=",".join("c" for i in range(len(shape) + 1)), - base_storage=quantity + '_base_storage', - scope=lp.temp_var_scope.PRIVATE, - ) - - def get_quantity_subscripts(e, zero=False): - if isinstance(e, prim.Subscript): - index = e.index - if isinstance(index, tuple): - return index - else: - return (index,) - else: - if zero: - return (0,) - else: - return () - - for expr in quantity_exprs: - replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), get_quantity_subscripts(expr, zero=True) + (prim.Variable(new_iname),)) - - while write_insns: - insn = write_insns.pop() - - if isinstance(insn, lp.Assignment): - assignee = prim.Subscript(prim.Variable(arrname), get_quantity_subscripts(insn.assignee) + (prim.Variable('rotate_index'),)) - new_insns.append(insn.copy(assignee=assignee, - depends_on=insn.depends_on.union(frozenset({lp.match.Tagged("sumfact_stage1")})), - ) - ) - for e in quantity_exprs: - replacemap_arr[e] = prim.Subscript(prim.Variable(arrname), get_quantity_subscripts(e) + (prim.Variable('rotate_index'),)) - - elif isinstance(insn, lp.CInstruction): - # This entire code path should go away as we either - # * switch CInstructions to implicit iname assignments (see https://github.com/inducer/loopy/issues/55) - # * switch to doing geometry stuff for sum factorization ourselves - if len(shape) == 0: - # Rip apart the code and change the assignee - assignee, expression = insn.code.split("=") - assignee = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),)) - code = "{} ={}".format(str(assignee), expression) - new_insns.append(insn.copy(code=code, - depends_on_is_final=True, - )) - else: - # This is a *very* unfortunate code path - - # Get inames to assign to the vector buffer - cinsn_inames = tuple("{}_assign_{}".format(quantity, i) for i in range(len(shape))) - domains = frozenset("{{ [{0}] : 0<={0}<{1} }}".format(iname, shape[i]) for i, iname in enumerate(cinsn_inames)) - for dom in domains: - domain = parse_domains(dom, {}) - knl = knl.copy(domains=knl.domains + domain) - - # We keep the old writing instructions - new_insns.append(insn) - - # and write a new one - cinsn_id = "{}_assign_id".format(quantity) - new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(arrname), tuple(prim.Variable(i) for i in cinsn_inames) + (prim.Variable('rotate_index'),)), - prim.Subscript(prim.Variable(quantity), tuple(prim.Variable(i) for i in cinsn_inames)), - within_inames=common_inames.union(inames).union(frozenset(cinsn_inames)), - within_inames_is_final=True, - depends_on=frozenset({lp.match.Writes(quantity)}), - id=cinsn_id, - )) - all_writers.append(cinsn_id) - else: - raise NotImplementedError - elif quantity in knl.temporary_variables: - tag, = set(get_pymbolic_tag(expr) for expr in quantity_exprs) - if tag is not None and tag.startswith('vecsumfac'): - # Extract information from the tag - horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups()) - - # - # There is a vector quantity to be vectorized! That requires register rotation! - # - - # 1. Rotating the input data - knl = add_vector_view(knl, quantity, flatview=True) - if horizontal > 1: - new_insns.append(lp.CallInstruction((), # assignees - prim.Call(TransposeReg(vertical=vertical, horizontal=horizontal), - tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)), - (vector_indices.get(horizontal) + i, prim.Variable(new_iname))) - for i in range(horizontal))), - depends_on=frozenset({'continue_stmt'}), - within_inames=common_inames.union(inames).union(frozenset({new_iname})), - within_inames_is_final=True, - id="{}_rotate".format(quantity), - )) - - # Add substitution rules - for expr in quantity_exprs: - assert isinstance(expr, prim.Subscript) - last_index = expr.index[-1] // vertical - replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), - (vector_indices.get(horizontal) + last_index, prim.Variable(new_iname)), - ) - elif tag is not None and tag == 'sumfac': - # Add a vector view to this quantity - expr, = quantity_exprs - knl = add_vector_view(knl, quantity, flatview=True) - replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), - (vector_indices.get(1), prim.Variable(new_iname)), - ) - elif quantity in [a.name for a in knl.args]: - arg, = [a for a in knl.args if a.name == quantity] - tags = set(get_pymbolic_tag(expr) for expr in quantity_exprs) - if tags and tags.pop() == "operator_precomputed": - expr, = quantity_exprs - shape=(ceildiv(product(s for s in arg.shape), vec_size), vec_size) - name = loopy_class_member(quantity, - shape=shape, - dim_tags="f,vec", - potentially_vectorized=True, - classtag="operator", - dtype=np.float64, - ) - knl = knl.copy(args=knl.args + [lp.GlobalArg(name, shape=shape, dim_tags="c,vec", dtype=np.float64)]) - replacemap_vec[expr] = prim.Subscript(prim.Variable(name), - (vector_indices.get(1), prim.Variable(new_iname)), - ) - - new_insns = [i.copy(expression=substitute(i.expression, replacemap_arr)) for i in new_insns] - - other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]] - - # - # Add three counter variables to the kernel - # - temporaries = knl.temporary_variables - - temporaries['total_index'] = lp.TemporaryVariable('total_index', - dtype=np.int32, - scope=lp.temp_var_scope.PRIVATE, - ) - new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee - 0, # expression - within_inames=common_inames, - within_inames_is_final=True, - id="assign_total_index", - )) - new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee - prim.Sum((prim.Variable("total_index"), 1)), # expression - within_inames=common_inames.union(inames), - within_inames_is_final=True, - depends_on=frozenset(all_writers).union(frozenset({"assign_total_index"})), - depends_on_is_final=True, - id="update_total_index", - )) - - # Insert a rotating index, that counts 0 , .. , vecsize - 1 - temporaries['rotate_index'] = lp.TemporaryVariable('rotate_index', # name - dtype=np.int32, - scope=lp.temp_var_scope.PRIVATE, - ) - new_insns.append(lp.Assignment(prim.Variable("rotate_index"), # assignee - 0, # expression - within_inames=common_inames, - within_inames_is_final=True, - id="assign_rotate_index", - )) - new_insns.append(lp.Assignment(prim.Variable("rotate_index"), # assignee - prim.Remainder(prim.Sum((prim.Variable("rotate_index"), 1)), vec_size), # expression - within_inames=common_inames.union(inames), - within_inames_is_final=True, - depends_on=frozenset(all_writers).union(frozenset({"assign_rotate_index"})), - depends_on_is_final=True, - id="update_rotate_index", - )) - - knl = knl.copy(temporary_variables=temporaries) - - # - # Add a continue statement depending on the rotate index - # - - # Determine the condition for the continue statement - upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames)) - total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound) - rotate_check = prim.Comparison(prim.Variable("rotate_index"), "!=", 0) - check = prim.LogicalAnd((rotate_check, total_check)) - - # Insert the 'continue' statement - new_insns.append(lp.CInstruction((), # iname exprs that the code needs access to - "continue;", # the code - predicates=frozenset({check}), - depends_on=frozenset({"update_rotate_index", "update_total_index"}).union(frozenset(all_writers)), - depends_on_is_final=True, - within_inames=common_inames.union(inames), - within_inames_is_final=True, - id="continue_stmt", - )) - - # - # Reconstruct the compute instructions - # - - for insn in insns: - # Get a vector view of the lhs expression - lhsname = get_pymbolic_basename(insn.assignee) - knl = add_vector_view(knl, lhsname, pad_to=vec_size, flatview=True) - lhsname = get_vector_view_name(lhsname) - rotating = "gradvec" in insn.tags - - if rotating: - assert isinstance(insn.assignee, prim.Subscript) - tag = get_pymbolic_tag(insn.assignee) - horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups()) - if horizontal > 1: - last_index = insn.assignee.index[-1] // vertical - else: - last_index = 0 - else: - last_index = 0 - horizontal = 1 - - new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), - (vector_indices.get(horizontal) + last_index, prim.Variable(new_iname)), - ), - substitute(insn.expression, replacemap_vec), - depends_on=frozenset({"continue_stmt"}), - depends_on_is_final=True, - within_inames=common_inames.union(frozenset(inames + (new_iname,))), - within_inames_is_final=True, - id=insn.id, - tags=frozenset({"vec_write"}) - ) - ) - - # Rotate back! - if rotating and "{}_rotateback".format(lhsname) not in [i.id for i in new_insns] and horizontal > 1: - new_insns.append(lp.CallInstruction((), # assignees - prim.Call(TransposeReg(horizontal=horizontal, vertical=vertical), - tuple(prim.Subscript(prim.Variable(lhsname), - (vector_indices.get(horizontal) + i, prim.Variable(new_iname))) - for i in range(horizontal))), - depends_on=frozenset({Tagged("vec_write")}), - within_inames=common_inames.union(inames).union(frozenset({new_iname})), - within_inames_is_final=True, - id="{}_rotateback".format(lhsname), - )) - - # Add the necessary vector indices - for name, increment in vector_indices.needed: - temporaries[name] = lp.TemporaryVariable(name, # name - dtype=np.int32, - scope=lp.temp_var_scope.PRIVATE, - ) - new_insns.append(lp.Assignment(prim.Variable(name), # assignee - 0, # expression - within_inames=common_inames, - within_inames_is_final=True, - id="assign_{}".format(name), - )) - new_insns.append(lp.Assignment(prim.Variable(name), # assignee - prim.Sum((prim.Variable(name), increment)), # expression - within_inames=common_inames.union(inames), - within_inames_is_final=True, - depends_on=frozenset({Tagged("vec_write"), "assign_{}".format(name)}), - depends_on_is_final=True, - id="update_{}".format(name), - )) - - from loopy.kernel.creation import resolve_dependencies - return resolve_dependencies(knl.copy(instructions=new_insns + other_insns)) diff --git a/python/dune/perftool/loopy/transformations/collect_precompute.py b/python/dune/perftool/loopy/transformations/vectorize_quad.py similarity index 99% rename from python/dune/perftool/loopy/transformations/collect_precompute.py rename to python/dune/perftool/loopy/transformations/vectorize_quad.py index 7cdf72bac12b1071a24958745fd6325fc27cc1f9..83eb5c98b1579cbd2e38ae6d95baf684280e6284 100644 --- a/python/dune/perftool/loopy/transformations/collect_precompute.py +++ b/python/dune/perftool/loopy/transformations/vectorize_quad.py @@ -82,7 +82,7 @@ class AntiPatternRemover(IdentityMapper): return IdentityMapper.map_floor_div(self, expr) -def collect_vector_data_precompute(knl): +def vectorize_quadrature_loop(knl): # # Process/Assert/Standardize the input # diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index d0b1d450d52a6ac6aad3ee892afdce4bdb64a899..3f24963354cedab260f00abb6af6ea53533cc0c2 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -502,10 +502,8 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True): # Maybe apply vectorization strategies if get_option("vectorize_quad"): if get_option("sumfact"): - from dune.perftool.loopy.transformations.collect_rotate import collect_vector_data_rotate - from dune.perftool.loopy.transformations.collect_precompute import collect_vector_data_precompute - kernel = collect_vector_data_precompute(kernel) -# kernel = collect_vector_data_rotate(kernel) + from dune.perftool.loopy.transformations.vectorize_quad import vectorize_quadrature_loop + kernel = vectorize_quadrature_loop(kernel) else: raise NotImplementedError("Only vectorizing sumfactorized code right now!")