Skip to content
Snippets Groups Projects
Commit e3b1849d authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Combine vectorization strategies

Only 6 orders of magnitude to go!
parent ab9547c7
No related branches found
No related tags found
No related merge requests found
......@@ -31,6 +31,7 @@ import cgen
def _type_to_op_counter_type(name):
    """Return the ``oc::OpCounter`` template instantiation for the type *name*.

    *name* is formatted into the template argument slot, so passing
    ``"double"`` yields ``"oc::OpCounter<double>"``.
    """
    template = "oc::OpCounter<{}>"
    return template.format(name)
@pt.memoize
def numpy_to_cpp_dtype(key):
_registry = {'float32': 'float',
......
""" A kernel transformation that precomputes quantities until a vector register
is filled and then does vector computations """
from dune.perftool.loopy.vcl import get_vcl_type_size
from dune.perftool.generation import (function_mangler,
include_file,
)
from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size
from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view,
add_vector_view,
get_vector_view_name,
)
from dune.perftool.loopy.symbolic import substitute
from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.tools import get_pymbolic_basename
from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag
from loopy.kernel.creation import parse_domains
from loopy.symbolic import pw_aff_to_expr
from loopy.match import Tagged
from pymbolic.mapper.dependency import DependencyMapper
from pymbolic.mapper.substitutor import substitute
import pymbolic.primitives as prim
import loopy as lp
import numpy as np
@function_mangler
def rotate_function_mangler(knl, func, arg_dtypes):
    """Function mangler for calls to ``transpose_reg``.

    Returns a ``lp.CallMangleInfo`` when *func* is ``"transpose_reg"`` and
    ``None`` for every other function name. *knl* and *arg_dtypes* are not
    inspected.

    The mangle info is built with an empty tuple and a 4-tuple of the
    float64 VCL type for a register size of 256 (presumably the result
    dtypes and the argument dtypes, respectively -- confirm against
    loopy's ``CallMangleInfo`` signature).
    """
    if func != "transpose_reg":
        return None

    # This is not 100% within the loopy philosophy, as we are
    # passing the vector registers as references and have them
    # changed. Loopy assumes this function to be read-only.
    register_type = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256))
    return lp.CallMangleInfo("transpose_reg", (), (register_type,) * 4)
def collect_vector_data_rotate(knl):
#
# Process/Assert/Standardize the input
......@@ -48,6 +61,7 @@ def collect_vector_data_rotate(knl):
new_insns = []
all_writers = []
rotating = False
#
# Inspect the given instructions for dependent quantities
......@@ -123,11 +137,38 @@ def collect_vector_data_rotate(knl):
else:
raise NotImplementedError
elif quantity in knl.temporary_variables:
# Add a vector view to this quantity
knl = add_vector_view(knl, quantity)
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(prim.Variable("vec_index"), prim.Variable(new_iname)),
)
if all(get_pymbolic_tag(expr) == 'vector' for expr in quantities[quantity]):
#
# There is a vector quantity to be vectorized! That requires register rotation!
#
# 1. Rotating the input data
knl = add_vector_view(knl, quantity)
include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile")
new_insns.append(lp.CallInstruction((), # assignees
prim.Call(prim.Variable("transpose_reg"),
tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))),
depends_on=frozenset({'continue_stmt'}),
within_inames=common_inames.union(inames).union(frozenset({new_iname})),
within_inames_is_final=True,
id="{}_rotate".format(quantity),
))
# Add substitution rules
for expr in quantities[quantity]:
rotating = True
assert isinstance(expr, prim.Subscript)
last_index = expr.index[-1]
assert last_index in tuple(range(4))
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(prim.Variable("vec_index") + last_index, prim.Variable(new_iname)),
)
else:
# Add a vector view to this quantity
knl = add_vector_view(knl, quantity)
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(prim.Variable("vec_index"), prim.Variable(new_iname)),
)
other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]]
......@@ -149,7 +190,7 @@ def collect_vector_data_rotate(knl):
id="assign_vec_index",
))
new_insns.append(lp.Assignment(prim.Variable("vec_index"), # assignee
prim.Sum((prim.Variable("vec_index"), 1)), # expression
prim.Sum((prim.Variable("vec_index"), vec_size if rotating else 1)), # expression
within_inames=common_inames.union(inames),
within_inames_is_final=True,
depends_on=frozenset({Tagged("vec_write"), "assign_vec_index"}),
......@@ -210,8 +251,15 @@ def collect_vector_data_rotate(knl):
knl = add_vector_view(knl, lhsname)
lhsname = get_vector_view_name(lhsname)
if rotating:
assert isinstance(insn.assignee, prim.Subscript)
last_index = insn.assignee.index[-1]
assert last_index in tuple(range(4))
else:
last_index = 0
new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname),
(prim.Variable("vec_index"), prim.Variable(new_iname)),
(prim.Variable("vec_index") + last_index, prim.Variable(new_iname)),
),
substitute(insn.expression, replacemap_vec),
depends_on=frozenset({"continue_stmt"}),
......@@ -223,4 +271,15 @@ def collect_vector_data_rotate(knl):
)
)
# Rotate back!
if rotating:
new_insns.append(lp.CallInstruction((), # assignees
prim.Call(prim.Variable("transpose_reg"),
tuple(prim.Subscript(prim.Variable(lhsname), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))),
depends_on=frozenset({Tagged("vec_write")}),
within_inames=common_inames.union(inames).union(frozenset({new_iname})),
within_inames_is_final=True,
id="{}_rotateback".format(lhsname),
))
return knl.copy(instructions=new_insns + other_insns)
......@@ -344,4 +344,7 @@ def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), add
)
silenced_warning('read_no_write({})'.format(out))
return prim.Variable(out), insn_dep
if isinstance(next(iter(a_matrices)), LargeAMatrix):
return lp.TaggedVariable(out, "vector"), insn_dep
else:
return prim.Variable(out), insn_dep
""" Some grabbag tools """
from __future__ import absolute_import
import loopy as lp
import pymbolic.primitives as prim
......@@ -42,3 +44,18 @@ def maybe_wrap_subscript(expr, indices):
return prim.Subscript(expr, indices)
else:
return expr
def get_pymbolic_tag(expr):
    """Return the tag attached to the variable underlying *expr*.

    Subscripts are unwrapped through their ``aggregate`` until a variable
    is reached. A ``lp.TaggedVariable`` yields its ``tag``; a plain
    ``prim.Variable`` yields ``None``.

    Raises ``NotImplementedError`` for any other kind of expression. Each
    node visited is asserted to be a ``prim.Expression``.
    """
    node = expr
    while True:
        assert isinstance(node, prim.Expression)

        # TaggedVariable must be tested before Variable so that tagged
        # variables are not reported as untagged.
        if isinstance(node, lp.TaggedVariable):
            return node.tag

        if isinstance(node, prim.Variable):
            return None

        if not isinstance(node, prim.Subscript):
            raise NotImplementedError("Cannot determine tag on {}".format(node))

        # Descend into the subscripted object and inspect that instead.
        node = node.aggregate
......@@ -4,7 +4,6 @@ __exec_suffix = {diff_suffix}_{vecq_suffix}_{vecg_suffix}
diff_suffix = numdiff, symdiff | expand num
vecq_suffix = quadvec, nonquadvec | expand vec_q
vecg_suffix = gradvec, nongradvec | expand vec_g
{vecq_suffix} == quadvec and {vecg_suffix} == gradvec | exclude
cells = 8 8 8
extension = 1. 1. 1.
......
......@@ -4,7 +4,6 @@ __exec_suffix = {diff_suffix}_{vecq_suffix}_{vecg_suffix}
diff_suffix = numdiff, symdiff | expand num
vecq_suffix = quadvec, nonquadvec | expand vec_q
vecg_suffix = gradvec, nongradvec | expand vec_g
{vecq_suffix} == quadvec and {vecg_suffix} == gradvec | exclude
cells = 8 8 8
extension = 1. 1. 1.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment