Skip to content
Snippets Groups Projects
Commit 93b61c97 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

[bugfix] fix occurence of expression both in vector and in array statement

parent 640224a2
No related branches found
No related tags found
No related merge requests found
...@@ -105,7 +105,10 @@ def collect_vector_data_rotate(knl): ...@@ -105,7 +105,10 @@ def collect_vector_data_rotate(knl):
# Add vector size buffers for all these quantities # Add vector size buffers for all these quantities
replacemap_vec = {} replacemap_vec = {}
replacemap_arr = {}
for quantity in quantities: for quantity in quantities:
quantity_exprs = quantities[quantity]
# Check whether there is an instruction that writes this quantity within # Check whether there is an instruction that writes this quantity within
# the given inames. If so, we need a buffer array. # the given inames. If so, we need a buffer array.
iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames)) iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
...@@ -141,16 +144,22 @@ def collect_vector_data_rotate(knl): ...@@ -141,16 +144,22 @@ def collect_vector_data_rotate(knl):
else: else:
return () return ()
for expr in quantities[quantity]: for expr in quantity_exprs:
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), get_quantity_subscripts(expr, zero=True) + (prim.Variable(new_iname),)) replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), get_quantity_subscripts(expr, zero=True) + (prim.Variable(new_iname),))
for insn in write_insns: while write_insns:
insn = write_insns.pop()
if isinstance(insn, lp.Assignment): if isinstance(insn, lp.Assignment):
assignee = prim.Subscript(prim.Variable(arrname), get_quantity_subscripts(insn.assignee) + (prim.Variable('rotate_index'),)) assignee = prim.Subscript(prim.Variable(arrname), get_quantity_subscripts(insn.assignee) + (prim.Variable('rotate_index'),))
new_insns.append(insn.copy(assignee=assignee, new_insns.append(insn.copy(assignee=assignee,
expression=substitute(insn.expression, replacemap_arr),
depends_on_is_final=True, depends_on_is_final=True,
) )
) )
for e in quantity_exprs:
replacemap_arr[e] = prim.Subscript(prim.Variable(arrname), get_quantity_subscripts(e) + (prim.Variable('rotate_index'),))
elif isinstance(insn, lp.CInstruction): elif isinstance(insn, lp.CInstruction):
# This entire code path should go away as we either # This entire code path should go away as we either
# * switch CInstructions to implicit iname assignments (see https://github.com/inducer/loopy/issues/55) # * switch CInstructions to implicit iname assignments (see https://github.com/inducer/loopy/issues/55)
...@@ -189,7 +198,7 @@ def collect_vector_data_rotate(knl): ...@@ -189,7 +198,7 @@ def collect_vector_data_rotate(knl):
else: else:
raise NotImplementedError raise NotImplementedError
elif quantity in knl.temporary_variables: elif quantity in knl.temporary_variables:
tag, = set(get_pymbolic_tag(expr) for expr in quantities[quantity]) tag, = set(get_pymbolic_tag(expr) for expr in quantity_exprs)
if tag is not None and tag.startswith('vecsumfac'): if tag is not None and tag.startswith('vecsumfac'):
# Extract information from the tag # Extract information from the tag
horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups()) horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups())
...@@ -213,7 +222,7 @@ def collect_vector_data_rotate(knl): ...@@ -213,7 +222,7 @@ def collect_vector_data_rotate(knl):
)) ))
# Add substitution rules # Add substitution rules
for expr in quantities[quantity]: for expr in quantity_exprs:
assert isinstance(expr, prim.Subscript) assert isinstance(expr, prim.Subscript)
last_index = expr.index[-1] // vertical last_index = expr.index[-1] // vertical
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
...@@ -221,7 +230,7 @@ def collect_vector_data_rotate(knl): ...@@ -221,7 +230,7 @@ def collect_vector_data_rotate(knl):
) )
elif tag is not None and tag == 'sumfac': elif tag is not None and tag == 'sumfac':
# Add a vector view to this quantity # Add a vector view to this quantity
expr, = quantities[quantity] expr, = quantity_exprs
knl = add_vector_view(knl, quantity, flatview=True) knl = add_vector_view(knl, quantity, flatview=True)
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(vector_indices.get(1), prim.Variable(new_iname)), (vector_indices.get(1), prim.Variable(new_iname)),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment