diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index 055d06b4ba21252fc1aaaa99b5246be3f9bfd850..3bb99b44f53321086da442f94bcaf754cb1980ac 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -31,6 +31,7 @@ import cgen def _type_to_op_counter_type(name): return "oc::OpCounter<{}>".format(name) + @pt.memoize def numpy_to_cpp_dtype(key): _registry = {'float32': 'float', diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py index c4e901e4b02901e597182a486b4a7551d0938288..b011bf8bb72c021f8da80880cf3e4991d1207141 100644 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ b/python/dune/perftool/loopy/transformations/collect_rotate.py @@ -1,26 +1,39 @@ """ A kernel transformation that precomputes quantities until a vector register is filled and then does vector computations """ -from dune.perftool.loopy.vcl import get_vcl_type_size +from dune.perftool.generation import (function_mangler, + include_file, + ) +from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view, add_vector_view, get_vector_view_name, ) +from dune.perftool.loopy.symbolic import substitute from dune.perftool.sumfact.quadrature import quadrature_inames -from dune.perftool.tools import get_pymbolic_basename +from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag from loopy.kernel.creation import parse_domains from loopy.symbolic import pw_aff_to_expr from loopy.match import Tagged from pymbolic.mapper.dependency import DependencyMapper -from pymbolic.mapper.substitutor import substitute import pymbolic.primitives as prim import loopy as lp import numpy as np +@function_mangler +def rotate_function_mangler(knl, func, arg_dtypes): + if func == "transpose_reg": + # This is not 100% within the loopy philosoph, as we are + # passing the vector registers as references and have them + # changed. Loopy assumes this function to be read-only. + vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256)) + return lp.CallMangleInfo("transpose_reg", (), (vcl, vcl, vcl, vcl)) + + def collect_vector_data_rotate(knl): # # Process/Assert/Standardize the input @@ -48,6 +61,7 @@ def collect_vector_data_rotate(knl): new_insns = [] all_writers = [] + rotating = False # # Inspect the given instructions for dependent quantities @@ -123,11 +137,38 @@ def collect_vector_data_rotate(knl): else: raise NotImplementedError elif quantity in knl.temporary_variables: - # Add a vector view to this quantity - knl = add_vector_view(knl, quantity) - replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), - (prim.Variable("vec_index"), prim.Variable(new_iname)), - ) + if all(get_pymbolic_tag(expr) == 'vector' for expr in quantities[quantity]): + # + # There is a vector quantity to be vectorized! That requires register rotation! + # + + # 1. Rotating the input data + knl = add_vector_view(knl, quantity) + include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile") + new_insns.append(lp.CallInstruction((), # assignees + prim.Call(prim.Variable("transpose_reg"), + tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))), + depends_on=frozenset({'continue_stmt'}), + within_inames=common_inames.union(inames).union(frozenset({new_iname})), + within_inames_is_final=True, + id="{}_rotate".format(quantity), + )) + + # Add substitution rules + for expr in quantities[quantity]: + rotating = True + assert isinstance(expr, prim.Subscript) + last_index = expr.index[-1] + assert last_index in tuple(range(4)) + replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), + (prim.Variable("vec_index") + last_index, prim.Variable(new_iname)), + ) + else: + # Add a vector view to this quantity + knl = add_vector_view(knl, quantity) + replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), + (prim.Variable("vec_index"), prim.Variable(new_iname)), + ) other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]] @@ -149,7 +190,7 @@ def collect_vector_data_rotate(knl): id="assign_vec_index", )) new_insns.append(lp.Assignment(prim.Variable("vec_index"), # assignee - prim.Sum((prim.Variable("vec_index"), 1)), # expression + prim.Sum((prim.Variable("vec_index"), vec_size if rotating else 1)), # expression within_inames=common_inames.union(inames), within_inames_is_final=True, depends_on=frozenset({Tagged("vec_write"), "assign_vec_index"}), @@ -210,8 +251,15 @@ def collect_vector_data_rotate(knl): knl = add_vector_view(knl, lhsname) lhsname = get_vector_view_name(lhsname) + if rotating: + assert isinstance(insn.assignee, prim.Subscript) + last_index = insn.assignee.index[-1] + assert last_index in tuple(range(4)) + else: + last_index = 0 + new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), - (prim.Variable("vec_index"), prim.Variable(new_iname)), + (prim.Variable("vec_index") + last_index, prim.Variable(new_iname)), ), substitute(insn.expression, replacemap_vec), depends_on=frozenset({"continue_stmt"}), @@ -223,4 +271,15 @@ def collect_vector_data_rotate(knl): ) ) + # Rotate back! + if rotating: + new_insns.append(lp.CallInstruction((), # assignees + prim.Call(prim.Variable("transpose_reg"), + tuple(prim.Subscript(prim.Variable(lhsname), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))), + depends_on=frozenset({Tagged("vec_write")}), + within_inames=common_inames.union(inames).union(frozenset({new_iname})), + within_inames_is_final=True, + id="{}_rotateback".format(lhsname), + )) + return knl.copy(instructions=new_insns + other_insns) diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 57d255883b9ea164f07c8df7e689a5a6960538d4..d6f11c77b5055022f868c6b423921aa11d9cf1f5 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -344,4 +344,7 @@ def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), add ) silenced_warning('read_no_write({})'.format(out)) - return prim.Variable(out), insn_dep + if isinstance(next(iter(a_matrices)), LargeAMatrix): + return lp.TaggedVariable(out, "vector"), insn_dep + else: + return prim.Variable(out), insn_dep diff --git a/python/dune/perftool/tools.py b/python/dune/perftool/tools.py index e9619b83288b92a49294fcd9e99ae1b069c9ceb1..e6259e44b0cd4d030587741f8361e804b8fd3e55 100644 --- a/python/dune/perftool/tools.py +++ b/python/dune/perftool/tools.py @@ -1,5 +1,7 @@ """ Some grabbag tools """ +from __future__ import absolute_import +import loopy as lp import pymbolic.primitives as prim @@ -42,3 +44,18 @@ def maybe_wrap_subscript(expr, indices): return prim.Subscript(expr, indices) else: return expr + + +def get_pymbolic_tag(expr): + assert isinstance(expr, prim.Expression) + + if isinstance(expr, lp.TaggedVariable): + return expr.tag + + if isinstance(expr, prim.Variable): + return None + + if isinstance(expr, prim.Subscript): + return get_pymbolic_tag(expr.aggregate) + + raise NotImplementedError("Cannot determine tag on {}".format(expr)) diff --git a/test/sumfact/poisson/poisson_3d_order1.mini b/test/sumfact/poisson/poisson_3d_order1.mini index a636bda7fe69f91ebfc99154ecedf2b90ab89d07..d794e0dca0f7f6b2147b22893a98339e33e3ee52 100644 --- a/test/sumfact/poisson/poisson_3d_order1.mini +++ b/test/sumfact/poisson/poisson_3d_order1.mini @@ -4,7 +4,6 @@ __exec_suffix = {diff_suffix}_{vecq_suffix}_{vecg_suffix} diff_suffix = numdiff, symdiff | expand num vecq_suffix = quadvec, nonquadvec | expand vec_q vecg_suffix = gradvec, nongradvec | expand vec_g -{vecq_suffix} == quadvec and {vecg_suffix} == gradvec | exclude cells = 8 8 8 extension = 1. 1. 1. diff --git a/test/sumfact/poisson/poisson_3d_order2.mini b/test/sumfact/poisson/poisson_3d_order2.mini index 24ec010e7beef46dea74bffd318411118a2c727d..c81eae557caa00dbd1268b54292fee936cd29213 100644 --- a/test/sumfact/poisson/poisson_3d_order2.mini +++ b/test/sumfact/poisson/poisson_3d_order2.mini @@ -4,7 +4,6 @@ __exec_suffix = {diff_suffix}_{vecq_suffix}_{vecg_suffix} diff_suffix = numdiff, symdiff | expand num vecq_suffix = quadvec, nonquadvec | expand vec_q vecg_suffix = gradvec, nongradvec | expand vec_g -{vecq_suffix} == quadvec and {vecg_suffix} == gradvec | exclude cells = 8 8 8 extension = 1. 1. 1.