Skip to content
Snippets Groups Projects
Commit 0c2f06d3 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Saving some work

parent 03675133
No related branches found
No related tags found
No related merge requests found
""" Export the interface interesting to the rest of the project """ """ Export the interface interesting to the rest of the project """
from dune.perftool.loopy.transformations.collect_precompute import collect_vector_data_precompute from dune.perftool.loopy.transformations.collect_precompute import collect_vector_data_precompute
from dune.perftool.loopy.transformations.collect_rotate import collect_vector_data_rotate
from dune.perftool.loopy.transformations.duplicate import heuristic_duplication from dune.perftool.loopy.transformations.duplicate import heuristic_duplication
""" A kernel transformation that precomputes quantities until a vector register
is filled and then does vector computations """
from dune.perftool.loopy.vcl import get_vcl_type_size
from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view,
add_vector_view,
get_vector_view_name,
)
from dune.perftool.tools import get_pymbolic_basename
from loopy.kernel.creation import parse_domains
from loopy.symbolic import pw_aff_to_expr
from pymbolic.mapper.dependency import DependencyMapper
from pymbolic.mapper.substitutor import substitute
import pymbolic.primitives as prim
import loopy as lp
import numpy as np
def collect_vector_data_rotate(knl, insns, inames):
    """Buffer the quantities read by *insns* in vector-sized rotating arrays.

    For every quantity the given instructions depend on, a temporary array of
    vector width (with a vector view) is added to the kernel, and the
    instructions writing that quantity inside *inames* are redirected to write
    into ``<quantity>_buffered_arr[rotate_index]``. Two integer counters,
    ``total_index`` and ``rotate_index``, are added together with the
    instructions that initialise and update them.

    :arg knl: The loopy kernel to transform.
    :arg insns: The instructions to vectorize. Either a loopy match
        expression, a comma-separated string of instruction ids, or an
        iterable of instruction ids.
    :arg inames: The inames the rotation runs over. Either a comma-separated
        string or an iterable of iname names.
    :returns: A copy of the kernel with the new temporaries and the rewritten
        instruction list. The input kernel's temporaries are not mutated by
        this function's own counter insertion.
    :raises NotImplementedError: If a writing instruction is neither an
        ``lp.Assignment`` nor an ``lp.CInstruction``.

    NOTE(review): in its current state the instructions given as *insns*, as
    well as any matching ``lp.CInstruction``, are removed from the resulting
    kernel and not re-inserted; the vectorized replacement is still only
    present as the commented-out draft at the end of this function.
    """
    #
    # Process/Assert/Standardize the input
    #

    # inames input -> tuple
    if isinstance(inames, str):
        inames = inames.split(",")
    inames = tuple(i.strip() for i in inames)

    # insns -> list of Instruction instances
    if isinstance(insns, lp.match.MatchExpressionBase):
        insns = lp.find_instructions(knl, insns)
    else:
        if isinstance(insns, str):
            insns = [i.strip() for i in insns.split(",")]
        insns = [knl.id_to_insn[i] for i in insns]

    # Collect the inames of the given instructions, excluding the rotated
    # ones. Those inames will also be iname dependencies of inserted
    # instructions.
    # NOTE(review): the original comment speaks of the inames the instructions
    # "all have in common", but the code takes the union, not the
    # intersection - confirm which one is intended.
    common_inames = frozenset().union(*(insn.within_inames for insn in insns)) - frozenset(inames)

    # Determine the vector lane width
    # TODO infer the numpy type here
    vec_size = get_vcl_type_size(np.float64)

    #
    # Inspect the given instructions for dependent quantities
    #

    # Maps the basename of a quantity to the set of expressions it occurs in
    quantities = {}
    for insn in insns:
        for expr in DependencyMapper()(insn.expression):
            quantities.setdefault(get_pymbolic_basename(expr), set()).add(expr)

    # The replacement maps below can only handle one access pattern per quantity
    assert all(len(q) == 1 for q in quantities.values())

    # Add vector size buffers for all these quantities
    replacemap_arr = {}
    # NOTE(review): replacemap_vec is filled but not consumed yet
    replacemap_vec = {}
    for quantity, exprs in quantities.items():
        expr, = exprs
        arrname = quantity + '_buffered_arr'
        knl = add_temporary_with_vector_view(knl,
                                             arrname,
                                             dtype=np.float64,
                                             shape=(vec_size,),
                                             dim_tags="c",
                                             base_storage=quantity + '_base_storage',
                                             )

        replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
        replacemap_vec[expr] = prim.Variable(get_vector_view_name(arrname))

    # Find the instructions that write one of the quantities within all inames
    write_match = lp.match.Or(tuple(lp.match.Writes(q) for q in quantities))
    iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
    match = lp.match.And((write_match, iname_match))
    write_insns = lp.find_instructions(knl, match)
    write_ids = frozenset(i.id for i in write_insns)

    # All instructions that are neither vectorized nor rewritten stay untouched
    handled_ids = write_ids.union(i.id for i in insns)
    other_insns = [i for i in knl.instructions if i.id not in handled_ids]

    new_insns = []
    # Work on a copy: the kernel's own mapping must not be modified in place
    temporaries = dict(knl.temporary_variables)

    # Redirect the writes into the rotating buffers
    for insn in write_insns:
        if isinstance(insn, lp.Assignment):
            basename = get_pymbolic_basename(insn.assignee)
            new_insns.append(insn.copy(assignee=replacemap_arr[basename]))
        elif isinstance(insn, lp.CInstruction):
            # Not handled yet: the instruction is dropped from the kernel
            pass
        else:
            raise NotImplementedError("Cannot redirect writes of instruction type {}".format(type(insn).__name__))

    #
    # Add two counter variables to the kernel
    #
    loop_inames = common_inames.union(inames)

    # Insert a flat consecutive counter 'total_index'
    temporaries['total_index'] = lp.TemporaryVariable('total_index',
                                                      dtype=np.int32,
                                                      )
    new_insns.append(lp.Assignment(prim.Variable("total_index"),
                                   0,
                                   within_inames=common_inames,
                                   within_inames_is_final=True,
                                   id="assign_total_index",
                                   ))
    new_insns.append(lp.Assignment(prim.Variable("total_index"),
                                   prim.Sum((prim.Variable("total_index"), 1)),
                                   within_inames=loop_inames,
                                   within_inames_is_final=True,
                                   depends_on=write_ids.union(frozenset({"assign_total_index"})),
                                   depends_on_is_final=True,
                                   id="update_total_index",
                                   ))

    # Insert a rotating index, that counts 0 , .. , vecsize - 1
    temporaries['rotate_index'] = lp.TemporaryVariable('rotate_index',
                                                       dtype=np.int32,
                                                       )
    new_insns.append(lp.Assignment(prim.Variable("rotate_index"),
                                   0,
                                   within_inames=common_inames,
                                   within_inames_is_final=True,
                                   id="assign_rotate_index",
                                   ))
    new_insns.append(lp.Assignment(prim.Variable("rotate_index"),
                                   prim.Remainder(prim.Sum((prim.Variable("rotate_index"), 1)), vec_size),
                                   within_inames=loop_inames,
                                   within_inames_is_final=True,
                                   depends_on=write_ids.union(frozenset({"assign_rotate_index"})),
                                   depends_on_is_final=True,
                                   id="update_rotate_index",
                                   ))

    #
    # Construct a flat loop for the given instructions
    #
    # NOTE: inactive draft, kept as the starting point for the next step.
    #
    # new_insns = []
    # other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns]]
    #
    # size = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames))
    # size = prim.FloorDiv(size, vec_size)
    #
    # temporaries = knl.temporary_variables
    # temporaries["flatsize"] = lp.TemporaryVariable("flatsize",
    #                                                dtype=np.int32,
    #                                                shape=(),
    #                                                )
    # new_insns.append(lp.Assignment(prim.Variable("flatsize"),
    #                                size,
    #                                )
    #                  )
    #
    # # Add an additional domain to the kernel
    # new_iname = "flat_{}".format("_".join(inames))
    # domain = "{{ [{0}] : 0<={0}<flatsize }}".format(new_iname, str(size))
    # domain = parse_domains(domain, {})
    # knl = knl.copy(domains=knl.domains + domain,
    #                temporary_variables=temporaries)
    #
    # # Split and tag the flat iname
    # knl = lp.split_iname(knl, new_iname, vec_size, inner_tag="vec")
    # new_inames = ("{}_outer".format(new_iname), "{}_inner".format(new_iname))
    # knl = lp.assume(knl, "flatsize mod {} = 0".format(vec_size))
    #
    # for insn in insns:
    #     # Get a vector view of the lhs expression
    #     lhsname = get_pymbolic_basename(insn.assignee)
    #     knl = add_vector_view(knl, lhsname)
    #     lhsname = get_vector_view_name(lhsname)
    #
    #     new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), tuple(prim.Variable(i) for i in new_inames)),
    #                                    prim.Subscript(prim.Variable(get_vector_view_name("wk_precomputed")), tuple(prim.Variable(i) for i in new_inames)),
    #                                    within_inames=frozenset(new_inames),
    #                                    within_inames_is_final=True,
    #                                    )
    #                      )

    # Hand the counters over explicitly instead of relying on dict aliasing
    return knl.copy(instructions=new_insns + other_insns,
                    temporary_variables=temporaries,
                    )
...@@ -46,3 +46,12 @@ def add_vector_view(knl, tmpname): ...@@ -46,3 +46,12 @@ def add_vector_view(knl, tmpname):
) )
return knl.copy(temporary_variables=temporaries) return knl.copy(temporary_variables=temporaries)
def add_temporary_with_vector_view(knl, name, *args, **kwargs):
    """Add a new temporary variable to the kernel together with a vector view.

    :arg knl: The loopy kernel to extend.
    :arg name: The name of the temporary to create. It must not name an
        existing temporary of the kernel.
    :arg args: Further positional arguments, forwarded to
        ``lp.TemporaryVariable``.
    :arg kwargs: Further keyword arguments, forwarded to
        ``lp.TemporaryVariable``.
    :returns: The kernel as returned by ``add_vector_view`` after the new
        temporary has been registered.
    """
    # Work on a copy: the input kernel's own mapping must stay untouched
    temps = dict(knl.temporary_variables)
    assert name not in temps

    temps[name] = lp.TemporaryVariable(name, *args, **kwargs)
    knl = knl.copy(temporary_variables=temps)

    return add_vector_view(knl, name)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment