Skip to content
Snippets Groups Projects
Commit c7f2b4cb authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Cleanup code

parent 1c0046f5
No related branches found
No related tags found
No related merge requests found
""" A kernel transformation that precomputes quantities until a vector register
is filled and then does vector computations """
from dune.perftool.generation import (function_mangler,
include_file,
loopy_class_member,
)
from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size
from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view,
add_vector_view,
get_vector_view_name,
)
from dune.perftool.loopy.symbolic import substitute
from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag, ceildiv
from dune.perftool.options import get_option
from loopy.kernel.creation import parse_domains
from loopy.symbolic import pw_aff_to_expr
from loopy.match import Tagged
from loopy.symbolic import DependencyMapper
from pytools import product
import pymbolic.primitives as prim
import loopy as lp
import numpy as np
import re
class TransposeReg(lp.symbolic.FunctionIdentifier):
def __init__(self,
horizontal=1,
vertical=1,
):
self.horizontal = horizontal
self.vertical = vertical
def __getinitargs__(self):
return (self.horizontal, self.vertical)
@property
def name(self):
return "transpose_reg"
@function_mangler
def rotate_function_mangler(knl, func, arg_dtypes):
if isinstance(func, TransposeReg):
# This is not 100% within the loopy philosophy, as we are
# passing the vector registers as references and have them
# changed. Loopy assumes this function to be read-only.
include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile")
vcl = lp.types.NumpyType(get_vcl_type(np.float64, vector_width=func.horizontal * func.vertical))
return lp.CallMangleInfo(func.name, (), (vcl,) * func.horizontal)
class VectorIndices(object):
def __init__(self):
self.needed = set()
def get(self, increment):
name = "vec_index_inc{}".format(increment)
self.needed.add((name, increment))
return prim.Variable(name)
def collect_vector_data_rotate(knl):
#
# Process/Assert/Standardize the input
#
insns = [i for i in lp.find_instructions(knl, lp.match.Tagged("quadvec"))]
if not insns:
return knl
inames = quadrature_inames()
# Analyse the inames of the given instructions and identify inames
# that they all have in common. Those inames will also be iname dependencies
# of inserted instructions.
common_inames = frozenset([]).union(*(insn.within_inames for insn in insns)) - frozenset(inames)
# Determine the vector lane width
# TODO infer the numpy type here
vec_size = get_vcl_type_size(np.float64)
vector_indices = VectorIndices()
# Add an iname to the kernel which will be used for vectorization
new_iname = "quad_vec_{}".format("_".join(inames))
domain = "{{ [{0}] : 0<={0}<{1} }}".format(new_iname, str(vec_size))
domain = parse_domains(domain, {})
knl = knl.copy(domains=knl.domains + domain)
knl = lp.tag_inames(knl, [(new_iname, "vec")])
new_insns = []
all_writers = []
#
# Inspect the given instructions for dependent quantities
#
quantities = {}
for insn in insns:
for expr in DependencyMapper()(insn.expression):
basename = get_pymbolic_basename(expr)
quantities.setdefault(basename, frozenset())
quantities[basename] = quantities[basename].union(frozenset([expr]))
# Add vector size buffers for all these quantities
replacemap_vec = {}
replacemap_arr = {}
for quantity in quantities:
quantity_exprs = quantities[quantity]
# Check whether there is an instruction that writes this quantity within
# the given inames. If so, we need a buffer array.
iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
write_match = lp.match.Writes(quantity)
match = lp.match.And((iname_match, write_match))
write_insns = lp.find_instructions(knl, match)
all_writers.extend([i.id for i in write_insns])
if write_insns:
# Determine the shape of the quantity
shape = knl.temporary_variables[quantity].shape
arrname = quantity + '_buffered_arr'
knl = add_temporary_with_vector_view(knl,
arrname,
dtype=np.float64,
shape=shape + (vec_size,),
dim_tags=",".join("c" for i in range(len(shape) + 1)),
base_storage=quantity + '_base_storage',
scope=lp.temp_var_scope.PRIVATE,
)
def get_quantity_subscripts(e, zero=False):
if isinstance(e, prim.Subscript):
index = e.index
if isinstance(index, tuple):
return index
else:
return (index,)
else:
if zero:
return (0,)
else:
return ()
for expr in quantity_exprs:
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), get_quantity_subscripts(expr, zero=True) + (prim.Variable(new_iname),))
while write_insns:
insn = write_insns.pop()
if isinstance(insn, lp.Assignment):
assignee = prim.Subscript(prim.Variable(arrname), get_quantity_subscripts(insn.assignee) + (prim.Variable('rotate_index'),))
new_insns.append(insn.copy(assignee=assignee,
depends_on=insn.depends_on.union(frozenset({lp.match.Tagged("sumfact_stage1")})),
)
)
for e in quantity_exprs:
replacemap_arr[e] = prim.Subscript(prim.Variable(arrname), get_quantity_subscripts(e) + (prim.Variable('rotate_index'),))
elif isinstance(insn, lp.CInstruction):
# This entire code path should go away as we either
# * switch CInstructions to implicit iname assignments (see https://github.com/inducer/loopy/issues/55)
# * switch to doing geometry stuff for sum factorization ourselves
if len(shape) == 0:
# Rip apart the code and change the assignee
assignee, expression = insn.code.split("=")
assignee = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
code = "{} ={}".format(str(assignee), expression)
new_insns.append(insn.copy(code=code,
depends_on_is_final=True,
))
else:
# This is a *very* unfortunate code path
# Get inames to assign to the vector buffer
cinsn_inames = tuple("{}_assign_{}".format(quantity, i) for i in range(len(shape)))
domains = frozenset("{{ [{0}] : 0<={0}<{1} }}".format(iname, shape[i]) for i, iname in enumerate(cinsn_inames))
for dom in domains:
domain = parse_domains(dom, {})
knl = knl.copy(domains=knl.domains + domain)
# We keep the old writing instructions
new_insns.append(insn)
# and write a new one
cinsn_id = "{}_assign_id".format(quantity)
new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(arrname), tuple(prim.Variable(i) for i in cinsn_inames) + (prim.Variable('rotate_index'),)),
prim.Subscript(prim.Variable(quantity), tuple(prim.Variable(i) for i in cinsn_inames)),
within_inames=common_inames.union(inames).union(frozenset(cinsn_inames)),
within_inames_is_final=True,
depends_on=frozenset({lp.match.Writes(quantity)}),
id=cinsn_id,
))
all_writers.append(cinsn_id)
else:
raise NotImplementedError
elif quantity in knl.temporary_variables:
tag, = set(get_pymbolic_tag(expr) for expr in quantity_exprs)
if tag is not None and tag.startswith('vecsumfac'):
# Extract information from the tag
horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups())
#
# There is a vector quantity to be vectorized! That requires register rotation!
#
# 1. Rotating the input data
knl = add_vector_view(knl, quantity, flatview=True)
if horizontal > 1:
new_insns.append(lp.CallInstruction((), # assignees
prim.Call(TransposeReg(vertical=vertical, horizontal=horizontal),
tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(vector_indices.get(horizontal) + i, prim.Variable(new_iname)))
for i in range(horizontal))),
depends_on=frozenset({'continue_stmt'}),
within_inames=common_inames.union(inames).union(frozenset({new_iname})),
within_inames_is_final=True,
id="{}_rotate".format(quantity),
))
# Add substitution rules
for expr in quantity_exprs:
assert isinstance(expr, prim.Subscript)
last_index = expr.index[-1] // vertical
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(vector_indices.get(horizontal) + last_index, prim.Variable(new_iname)),
)
elif tag is not None and tag == 'sumfac':
# Add a vector view to this quantity
expr, = quantity_exprs
knl = add_vector_view(knl, quantity, flatview=True)
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(vector_indices.get(1), prim.Variable(new_iname)),
)
elif quantity in [a.name for a in knl.args]:
arg, = [a for a in knl.args if a.name == quantity]
tags = set(get_pymbolic_tag(expr) for expr in quantity_exprs)
if tags and tags.pop() == "operator_precomputed":
expr, = quantity_exprs
shape=(ceildiv(product(s for s in arg.shape), vec_size), vec_size)
name = loopy_class_member(quantity,
shape=shape,
dim_tags="f,vec",
potentially_vectorized=True,
classtag="operator",
dtype=np.float64,
)
knl = knl.copy(args=knl.args + [lp.GlobalArg(name, shape=shape, dim_tags="c,vec", dtype=np.float64)])
replacemap_vec[expr] = prim.Subscript(prim.Variable(name),
(vector_indices.get(1), prim.Variable(new_iname)),
)
new_insns = [i.copy(expression=substitute(i.expression, replacemap_arr)) for i in new_insns]
other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]]
#
# Add three counter variables to the kernel
#
temporaries = knl.temporary_variables
temporaries['total_index'] = lp.TemporaryVariable('total_index',
dtype=np.int32,
scope=lp.temp_var_scope.PRIVATE,
)
new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee
0, # expression
within_inames=common_inames,
within_inames_is_final=True,
id="assign_total_index",
))
new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee
prim.Sum((prim.Variable("total_index"), 1)), # expression
within_inames=common_inames.union(inames),
within_inames_is_final=True,
depends_on=frozenset(all_writers).union(frozenset({"assign_total_index"})),
depends_on_is_final=True,
id="update_total_index",
))
# Insert a rotating index, that counts 0 , .. , vecsize - 1
temporaries['rotate_index'] = lp.TemporaryVariable('rotate_index', # name
dtype=np.int32,
scope=lp.temp_var_scope.PRIVATE,
)
new_insns.append(lp.Assignment(prim.Variable("rotate_index"), # assignee
0, # expression
within_inames=common_inames,
within_inames_is_final=True,
id="assign_rotate_index",
))
new_insns.append(lp.Assignment(prim.Variable("rotate_index"), # assignee
prim.Remainder(prim.Sum((prim.Variable("rotate_index"), 1)), vec_size), # expression
within_inames=common_inames.union(inames),
within_inames_is_final=True,
depends_on=frozenset(all_writers).union(frozenset({"assign_rotate_index"})),
depends_on_is_final=True,
id="update_rotate_index",
))
knl = knl.copy(temporary_variables=temporaries)
#
# Add a continue statement depending on the rotate index
#
# Determine the condition for the continue statement
upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames))
total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound)
rotate_check = prim.Comparison(prim.Variable("rotate_index"), "!=", 0)
check = prim.LogicalAnd((rotate_check, total_check))
# Insert the 'continue' statement
new_insns.append(lp.CInstruction((), # iname exprs that the code needs access to
"continue;", # the code
predicates=frozenset({check}),
depends_on=frozenset({"update_rotate_index", "update_total_index"}).union(frozenset(all_writers)),
depends_on_is_final=True,
within_inames=common_inames.union(inames),
within_inames_is_final=True,
id="continue_stmt",
))
#
# Reconstruct the compute instructions
#
for insn in insns:
# Get a vector view of the lhs expression
lhsname = get_pymbolic_basename(insn.assignee)
knl = add_vector_view(knl, lhsname, pad_to=vec_size, flatview=True)
lhsname = get_vector_view_name(lhsname)
rotating = "gradvec" in insn.tags
if rotating:
assert isinstance(insn.assignee, prim.Subscript)
tag = get_pymbolic_tag(insn.assignee)
horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups())
if horizontal > 1:
last_index = insn.assignee.index[-1] // vertical
else:
last_index = 0
else:
last_index = 0
horizontal = 1
new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname),
(vector_indices.get(horizontal) + last_index, prim.Variable(new_iname)),
),
substitute(insn.expression, replacemap_vec),
depends_on=frozenset({"continue_stmt"}),
depends_on_is_final=True,
within_inames=common_inames.union(frozenset(inames + (new_iname,))),
within_inames_is_final=True,
id=insn.id,
tags=frozenset({"vec_write"})
)
)
# Rotate back!
if rotating and "{}_rotateback".format(lhsname) not in [i.id for i in new_insns] and horizontal > 1:
new_insns.append(lp.CallInstruction((), # assignees
prim.Call(TransposeReg(horizontal=horizontal, vertical=vertical),
tuple(prim.Subscript(prim.Variable(lhsname),
(vector_indices.get(horizontal) + i, prim.Variable(new_iname)))
for i in range(horizontal))),
depends_on=frozenset({Tagged("vec_write")}),
within_inames=common_inames.union(inames).union(frozenset({new_iname})),
within_inames_is_final=True,
id="{}_rotateback".format(lhsname),
))
# Add the necessary vector indices
for name, increment in vector_indices.needed:
temporaries[name] = lp.TemporaryVariable(name, # name
dtype=np.int32,
scope=lp.temp_var_scope.PRIVATE,
)
new_insns.append(lp.Assignment(prim.Variable(name), # assignee
0, # expression
within_inames=common_inames,
within_inames_is_final=True,
id="assign_{}".format(name),
))
new_insns.append(lp.Assignment(prim.Variable(name), # assignee
prim.Sum((prim.Variable(name), increment)), # expression
within_inames=common_inames.union(inames),
within_inames_is_final=True,
depends_on=frozenset({Tagged("vec_write"), "assign_{}".format(name)}),
depends_on_is_final=True,
id="update_{}".format(name),
))
from loopy.kernel.creation import resolve_dependencies
return resolve_dependencies(knl.copy(instructions=new_insns + other_insns))
...@@ -82,7 +82,7 @@ class AntiPatternRemover(IdentityMapper): ...@@ -82,7 +82,7 @@ class AntiPatternRemover(IdentityMapper):
return IdentityMapper.map_floor_div(self, expr) return IdentityMapper.map_floor_div(self, expr)
def collect_vector_data_precompute(knl): def vectorize_quadrature_loop(knl):
# #
# Process/Assert/Standardize the input # Process/Assert/Standardize the input
# #
......
...@@ -502,10 +502,8 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True): ...@@ -502,10 +502,8 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
# Maybe apply vectorization strategies # Maybe apply vectorization strategies
if get_option("vectorize_quad"): if get_option("vectorize_quad"):
if get_option("sumfact"): if get_option("sumfact"):
from dune.perftool.loopy.transformations.collect_rotate import collect_vector_data_rotate from dune.perftool.loopy.transformations.vectorize_quad import vectorize_quadrature_loop
from dune.perftool.loopy.transformations.collect_precompute import collect_vector_data_precompute kernel = vectorize_quadrature_loop(kernel)
kernel = collect_vector_data_precompute(kernel)
# kernel = collect_vector_data_rotate(kernel)
else: else:
raise NotImplementedError("Only vectorizing sumfactorized code right now!") raise NotImplementedError("Only vectorizing sumfactorized code right now!")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment