Skip to content
Snippets Groups Projects
Commit c13b3424 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Have index handling information be covered by memoization

parent 4ed824c8
No related branches found
No related tags found
No related merge requests found
......@@ -289,16 +289,16 @@ def collect_vector_data_rotate(knl):
)
)
# Rotate back!
if rotating:
new_insns.append(lp.CallInstruction((), # assignees
prim.Call(prim.Variable("transpose_reg"),
tuple(prim.Subscript(prim.Variable(lhsname), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))),
depends_on=frozenset({Tagged("vec_write")}),
within_inames=common_inames.union(inames).union(frozenset({new_iname})),
within_inames_is_final=True,
id="{}_rotateback".format(lhsname),
))
# Rotate back!
if rotating and "{}_rotateback".format(lhsname) not in [i.id for i in new_insns]:
new_insns.append(lp.CallInstruction((), # assignees
prim.Call(prim.Variable("transpose_reg"),
tuple(prim.Subscript(prim.Variable(lhsname), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))),
depends_on=frozenset({Tagged("vec_write")}),
within_inames=common_inames.union(inames).union(frozenset({new_iname})),
within_inames_is_final=True,
id="{}_rotateback".format(lhsname),
))
from loopy.kernel.creation import resolve_dependencies
return resolve_dependencies(knl.copy(instructions=new_insns + other_insns))
......@@ -24,10 +24,14 @@ class SumFactInterface(PDELabInterface):
return pymbolic_reference_gradient(element, restriction, number)
def pymbolic_trialfunction_gradient(self, element, restriction, component, visitor=None):
return pymbolic_trialfunction_gradient(element, restriction, component, visitor)
ret, indices = pymbolic_trialfunction_gradient(element, restriction, component, visitor)
visitor.indices = indices
return ret
def pymbolic_trialfunction(self, element, restriction, component, visitor=None):
return pymbolic_trialfunction(element, restriction, component, visitor)
ret, indices = pymbolic_trialfunction(element, restriction, component, visitor)
visitor.indices = indices
return ret
def quadrature_inames(self):
return quadrature_inames()
......
......@@ -75,7 +75,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
# Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info
a_matrices, buffer, input, index = get_vectorization_info(a_matrices, 0)
a_matrices, buffer, input, index = get_vectorization_info(a_matrices, restriction)
# Initialize the buffer for the sum fact kernel
shape = (product(mat.cols for mat in a_matrices),)
......@@ -101,6 +101,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
1,
preferred_position=i,
insn_dep=insn_dep,
restriction=restriction,
)
buffers.append(var)
......@@ -109,9 +110,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
# with the position in the vector register.
if index:
assert len(visitor.indices) == 1
indices = visitor.indices
visitor.indices = None
return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + indices)
return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + visitor.indices), None
# TODO this should be quite conditional!!!
for i, buf in enumerate(buffers):
......@@ -126,7 +125,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
forced_iname_deps_is_final=True,
)
return prim.Variable(name)
return prim.Variable(name), visitor.indices
@kernel_cached
......@@ -140,7 +139,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
# Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info
a_matrices, buffer, input, index = get_vectorization_info(a_matrices, 0)
a_matrices, buffer, input, index = get_vectorization_info(a_matrices, restriction)
# Flip flop buffers for sumfactorization
shape = (product(mat.cols for mat in a_matrices),)
......@@ -164,6 +163,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
preferred_position=None,
insn_dep=frozenset({Writes(input)}),
outshape=tuple(mat.rows for mat in a_matrices if mat.rows != 1),
restriction=restriction,
)
if index:
......@@ -173,7 +173,7 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
return prim.Subscript(var,
tuple(prim.Variable(i) for i in quadrature_inames()) + index
)
), visitor.indices
@iname
......
......@@ -128,7 +128,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info
a_matrices, buffer, input, index = get_vectorization_info(a_matrices, restriction)
a_matrices, buffer, input, index = get_vectorization_info(a_matrices, accterm.argument.restriction)
# Initialize a base storage for this buffer and get a temporay pointing to it
shape = tuple(mat.cols for mat in a_matrices if mat.cols != 1)
......@@ -185,6 +185,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
insn_dep=insn_dep,
additional_inames=frozenset(visitor.inames),
preferred_position=pref_pos,
restriction=accterm.argument.restriction,
)
inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices)
......@@ -240,7 +241,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
insn_dep = emit_sumfact_kernel(None, restriction, insn_dep)
@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s))
@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s, kw.get("restriction", 0)))
def sum_factorization_kernel(a_matrices, buf, stage,
insn_dep=frozenset({}),
additional_inames=frozenset({}),
......
......@@ -30,17 +30,16 @@ def get_vectorization_info(a_matrices, restriction):
def no_vectorization(sumfacts):
for sumf in sumfacts:
for res in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
vectorization_info(sumf.a_matrices,
res,
sumf.a_matrices,
get_counted_variable("buffer"),
get_counted_variable(restricted_name("input", sumf.restriction)),
None)
vectorization_info(sumf.a_matrices,
sumf.restriction,
sumf.a_matrices,
get_counted_variable("buffer"),
get_counted_variable(restricted_name("input", sumf.restriction)),
None)
def decide_stage_vectorization_strategy(sumfacts, stage):
stage_sumfacts = frozenset([sf for sf in sumfacts if sf.stage == stage])
def decide_stage_vectorization_strategy(sumfacts, stage, restriction):
stage_sumfacts = frozenset([sf for sf in sumfacts if sf.stage == stage and sf.restriction == restriction])
if len(stage_sumfacts) in (3, 4):
# Map the sum factorization to their position in the joint kernel
position_mapping = {}
......@@ -107,8 +106,9 @@ def decide_vectorization_strategy():
if not get_option("vectorize_grads"):
no_vectorization(sumfacts)
else:
decide_stage_vectorization_strategy(sumfacts, 1)
decide_stage_vectorization_strategy(sumfacts, 3)
for stage in (1, 3):
for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
decide_stage_vectorization_strategy(sumfacts, stage, restriction)
class HasSumfactMapper(lp.symbolic.CombineMapper):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment