From e3b1849da7bcae62d56235140db95e620023ee45 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 6 Dec 2016 17:46:12 +0100 Subject: [PATCH] Combine vectorization strategies Only 6 orders of magnitude to go! --- python/dune/perftool/loopy/target.py | 1 + .../loopy/transformations/collect_rotate.py | 79 ++++++++++++++++--- python/dune/perftool/sumfact/sumfact.py | 5 +- python/dune/perftool/tools.py | 17 ++++ test/sumfact/poisson/poisson_3d_order1.mini | 1 - test/sumfact/poisson/poisson_3d_order2.mini | 1 - 6 files changed, 91 insertions(+), 13 deletions(-) diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index 055d06b4..3bb99b44 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -31,6 +31,7 @@ import cgen def _type_to_op_counter_type(name): return "oc::OpCounter<{}>".format(name) + @pt.memoize def numpy_to_cpp_dtype(key): _registry = {'float32': 'float', diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py index c4e901e4..b011bf8b 100644 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ b/python/dune/perftool/loopy/transformations/collect_rotate.py @@ -1,26 +1,39 @@ """ A kernel transformation that precomputes quantities until a vector register is filled and then does vector computations """ -from dune.perftool.loopy.vcl import get_vcl_type_size +from dune.perftool.generation import (function_mangler, + include_file, + ) +from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view, add_vector_view, get_vector_view_name, ) +from dune.perftool.loopy.symbolic import substitute from dune.perftool.sumfact.quadrature import quadrature_inames -from dune.perftool.tools import get_pymbolic_basename +from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag from loopy.kernel.creation import parse_domains from loopy.symbolic import pw_aff_to_expr from loopy.match import Tagged from pymbolic.mapper.dependency import DependencyMapper -from pymbolic.mapper.substitutor import substitute import pymbolic.primitives as prim import loopy as lp import numpy as np +@function_mangler +def rotate_function_mangler(knl, func, arg_dtypes): + if func == "transpose_reg": + # This is not 100% within the loopy philosoph, as we are + # passing the vector registers as references and have them + # changed. Loopy assumes this function to be read-only. + vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256)) + return lp.CallMangleInfo("transpose_reg", (), (vcl, vcl, vcl, vcl)) + + def collect_vector_data_rotate(knl): # # Process/Assert/Standardize the input @@ -48,6 +61,7 @@ def collect_vector_data_rotate(knl): new_insns = [] all_writers = [] + rotating = False # # Inspect the given instructions for dependent quantities @@ -123,11 +137,38 @@ def collect_vector_data_rotate(knl): else: raise NotImplementedError elif quantity in knl.temporary_variables: - # Add a vector view to this quantity - knl = add_vector_view(knl, quantity) - replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), - (prim.Variable("vec_index"), prim.Variable(new_iname)), - ) + if all(get_pymbolic_tag(expr) == 'vector' for expr in quantities[quantity]): + # + # There is a vector quantity to be vectorized! That requires register rotation! + # + + # 1. Rotating the input data + knl = add_vector_view(knl, quantity) + include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile") + new_insns.append(lp.CallInstruction((), # assignees + prim.Call(prim.Variable("transpose_reg"), + tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))), + depends_on=frozenset({'continue_stmt'}), + within_inames=common_inames.union(inames).union(frozenset({new_iname})), + within_inames_is_final=True, + id="{}_rotate".format(quantity), + )) + + # Add substitution rules + for expr in quantities[quantity]: + rotating = True + assert isinstance(expr, prim.Subscript) + last_index = expr.index[-1] + assert last_index in tuple(range(4)) + replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), + (prim.Variable("vec_index") + last_index, prim.Variable(new_iname)), + ) + else: + # Add a vector view to this quantity + knl = add_vector_view(knl, quantity) + replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), + (prim.Variable("vec_index"), prim.Variable(new_iname)), + ) other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]] @@ -149,7 +190,7 @@ def collect_vector_data_rotate(knl): id="assign_vec_index", )) new_insns.append(lp.Assignment(prim.Variable("vec_index"), # assignee - prim.Sum((prim.Variable("vec_index"), 1)), # expression + prim.Sum((prim.Variable("vec_index"), vec_size if rotating else 1)), # expression within_inames=common_inames.union(inames), within_inames_is_final=True, depends_on=frozenset({Tagged("vec_write"), "assign_vec_index"}), @@ -210,8 +251,15 @@ def collect_vector_data_rotate(knl): knl = add_vector_view(knl, lhsname) lhsname = get_vector_view_name(lhsname) + if rotating: + assert isinstance(insn.assignee, prim.Subscript) + last_index = insn.assignee.index[-1] + assert last_index in tuple(range(4)) + else: + last_index = 0 + new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), - (prim.Variable("vec_index"), prim.Variable(new_iname)), + (prim.Variable("vec_index") + last_index, prim.Variable(new_iname)), ), substitute(insn.expression, replacemap_vec), depends_on=frozenset({"continue_stmt"}), @@ -223,4 +271,15 @@ def collect_vector_data_rotate(knl): ) ) + # Rotate back! + if rotating: + new_insns.append(lp.CallInstruction((), # assignees + prim.Call(prim.Variable("transpose_reg"), + tuple(prim.Subscript(prim.Variable(lhsname), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))), + depends_on=frozenset({Tagged("vec_write")}), + within_inames=common_inames.union(inames).union(frozenset({new_iname})), + within_inames_is_final=True, + id="{}_rotateback".format(lhsname), + )) + return knl.copy(instructions=new_insns + other_insns) diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 57d25588..d6f11c77 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -344,4 +344,7 @@ def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), add ) silenced_warning('read_no_write({})'.format(out)) - return prim.Variable(out), insn_dep + if isinstance(next(iter(a_matrices)), LargeAMatrix): + return lp.TaggedVariable(out, "vector"), insn_dep + else: + return prim.Variable(out), insn_dep diff --git a/python/dune/perftool/tools.py b/python/dune/perftool/tools.py index e9619b83..e6259e44 100644 --- a/python/dune/perftool/tools.py +++ b/python/dune/perftool/tools.py @@ -1,5 +1,7 @@ """ Some grabbag tools """ +from __future__ import absolute_import +import loopy as lp import pymbolic.primitives as prim @@ -42,3 +44,18 @@ def maybe_wrap_subscript(expr, indices): return prim.Subscript(expr, indices) else: return expr + + +def get_pymbolic_tag(expr): + assert isinstance(expr, prim.Expression) + + if isinstance(expr, lp.TaggedVariable): + return expr.tag + + if isinstance(expr, prim.Variable): + return None + + if isinstance(expr, prim.Subscript): + return get_pymbolic_tag(expr.aggregate) + + raise NotImplementedError("Cannot determine tag on {}".format(expr)) diff --git a/test/sumfact/poisson/poisson_3d_order1.mini b/test/sumfact/poisson/poisson_3d_order1.mini index a636bda7..d794e0dc 100644 --- a/test/sumfact/poisson/poisson_3d_order1.mini +++ b/test/sumfact/poisson/poisson_3d_order1.mini @@ -4,7 +4,6 @@ __exec_suffix = {diff_suffix}_{vecq_suffix}_{vecg_suffix} diff_suffix = numdiff, symdiff | expand num vecq_suffix = quadvec, nonquadvec | expand vec_q vecg_suffix = gradvec, nongradvec | expand vec_g -{vecq_suffix} == quadvec and {vecg_suffix} == gradvec | exclude cells = 8 8 8 extension = 1. 1. 1. diff --git a/test/sumfact/poisson/poisson_3d_order2.mini b/test/sumfact/poisson/poisson_3d_order2.mini index 24ec010e..c81eae55 100644 --- a/test/sumfact/poisson/poisson_3d_order2.mini +++ b/test/sumfact/poisson/poisson_3d_order2.mini @@ -4,7 +4,6 @@ __exec_suffix = {diff_suffix}_{vecq_suffix}_{vecg_suffix} diff_suffix = numdiff, symdiff | expand num vecq_suffix = quadvec, nonquadvec | expand vec_q vecg_suffix = gradvec, nongradvec | expand vec_g -{vecq_suffix} == quadvec and {vecg_suffix} == gradvec | exclude cells = 8 8 8 extension = 1. 1. 1. -- GitLab