Skip to content
Snippets Groups Projects
Commit 32587c37 authored by Dominic Kempf
Browse files

Correctly precompute quantities

parent 4aafb27b
No related branches found
No related tags found
No related merge requests found
""" A kernel transformation that collects data until the vector size is reached """
""" A kernel transformation that collects data until the vector size is reached """
from dune.perftool.tools import get_pymbolic_basename
from pymbolic.mapper.dependency import DependencyMapper
from pymbolic.mapper.substitutor import substitute
import pymbolic.primitives as prim
import loopy as lp
import numpy as np
from pymbolic.mapper.dependency import DependencyMapper
def collect_vector_data(knl, insns, inames, vec_size=4):
# TODO: In theory this vec_size should be deduced from the types
......@@ -42,9 +46,13 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
# * Find all written quantities in the instructions from 1)
# * Find the instructions that write to these quantities
# * Filter only those that depend on the given inames
quantities = []
quantities = {}
depmapper = DependencyMapper()
for insn in insns:
quantities.extend(insn.read_dependency_names() - inames)
for expr in depmapper(insn.expression):
basename = get_pymbolic_basename(expr)
quantities.setdefault(basename, frozenset())
quantities[basename] = quantities[basename].union(frozenset([expr]))
write_match = lp.match.Or(tuple(lp.match.Writes(q) for q in quantities))
iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
......@@ -101,7 +109,7 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
prim.Sum((prim.Variable("total_index"), 1)), # expression
within_inames=common_inames.union(inames),
within_inames_is_final=True,
depends_on=frozenset({"assign_total_index"}),
depends_on=frozenset([i.id for i in write_insns]).union(frozenset({"assign_total_index"})),
depends_on_is_final=True,
id="update_total_index",
))
......@@ -122,22 +130,42 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
prim.Remainder(prim.Sum((prim.Variable("rotate_index"), 1)), vec_size), # expression
within_inames=common_inames.union(inames),
within_inames_is_final=True,
depends_on=frozenset({"assign_rotate_index"}),
depends_on=frozenset([i.id for i in write_insns]).union(frozenset({"assign_rotate_index"})),
depends_on_is_final=True,
id="update_rotate_index",
))
# Pre-evaluate all the needed quantities
for quantity in quantities:
name = quantity + '_buffered'
temporaries[name] = lp.TemporaryVariable(name,
dtype=np.float64,
shape=(vec_size,),
)
replacemap_arr = {}
replacemap_vec = {}
for quantity, quantity_exprs in quantities.items():
# TODO for now I only consider the case where an array occurs but once!
assert len(quantity_exprs) == 1
quantity_expr, = quantity_exprs
# Create the buffered temporaries: an array view and a vector view sharing one base storage
arrname = quantity + '_buffered_arr'
temporaries[arrname] = lp.TemporaryVariable(arrname,
dtype=np.float64,
shape=(vec_size,),
dim_tags="c",
base_storage=quantity + '_base_storage',
)
vecname = quantity + '_buffered_vec'
temporaries[vecname] = lp.TemporaryVariable(vecname,
dtype=np.float64,
shape=(vec_size,),
dim_tags="vec",
base_storage=quantity + '_base_storage',
)
replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
replacemap_vec[quantity_expr] = prim.Variable(vecname)
for insn in write_insns:
# Dummy
new_insns.append(insn.copy(assignee=prim.Subscript(prim.Variable(name), (prim.Variable('rotate_index'),)),
new_insns.append(insn.copy(assignee=replacemap_arr[get_pymbolic_basename(insn.assignee)],
)
)
......@@ -152,12 +180,12 @@ def collect_vector_data(knl, insns, inames, vec_size=4):
id="continue_stmt",
))
#
# TODO! Replace the precomputed quantities in the target instructions
#
for insn in insns:
new_insns.append(insn.copy(expression=insn.expression))
# TODO do something about the assignee!
new_insns.append(insn.copy(expression=substitute(insn.expression, variable_assignments=replacemap_vec),
depends_on=insn.depends_on.union(frozenset(["continue_stmt"]))
)
)
# Return a kernel
return knl.copy(instructions=dep_insns + other_insns + new_insns,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment