Skip to content
Snippets Groups Projects
Commit 1c0046f5 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Remove harmful antipattern

parent 1797a0ec
No related branches found
No related tags found
No related merge requests found
...@@ -19,7 +19,7 @@ from loopy.kernel.creation import parse_domains ...@@ -19,7 +19,7 @@ from loopy.kernel.creation import parse_domains
from loopy.symbolic import pw_aff_to_expr from loopy.symbolic import pw_aff_to_expr
from loopy.match import Tagged from loopy.match import Tagged
from loopy.symbolic import DependencyMapper from loopy.symbolic import DependencyMapper, IdentityMapper
from pytools import product from pytools import product
import pymbolic.primitives as prim import pymbolic.primitives as prim
...@@ -65,6 +65,23 @@ class VectorIndices(object): ...@@ -65,6 +65,23 @@ class VectorIndices(object):
return prim.Variable(name) return prim.Variable(name)
class AntiPatternRemover(IdentityMapper):
def map_floor_div(self, expr):
""" (y + (x % n)) // n -> y // n """
num = expr.numerator
den = expr.denominator
if isinstance(num, prim.Remainder) and num.denominator == den:
return 0
if isinstance(num, prim.Sum) and len(num.children) == 2:
c0, c1 = num.children
if isinstance(c1, prim.Remainder) and c1.denominator == den:
return c0 // den
return IdentityMapper.map_floor_div(self, expr)
def collect_vector_data_precompute(knl): def collect_vector_data_precompute(knl):
# #
# Process/Assert/Standardize the input # Process/Assert/Standardize the input
...@@ -189,7 +206,7 @@ def collect_vector_data_precompute(knl): ...@@ -189,7 +206,7 @@ def collect_vector_data_precompute(knl):
# Add substitution rules # Add substitution rules
for expr in quantity_exprs: for expr in quantity_exprs:
assert isinstance(expr, prim.Subscript) assert isinstance(expr, prim.Subscript)
last_index = expr.index[-1] // vertical last_index = AntiPatternRemover()(expr.index[-1] // vertical)
replacemap[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), replacemap[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(vector_indices.get(horizontal) + last_index, prim.Variable(vec_iname)), (vector_indices.get(horizontal) + last_index, prim.Variable(vec_iname)),
) )
...@@ -230,7 +247,7 @@ def collect_vector_data_precompute(knl): ...@@ -230,7 +247,7 @@ def collect_vector_data_precompute(knl):
tag = get_pymbolic_tag(insn.assignee) tag = get_pymbolic_tag(insn.assignee)
horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups()) horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups())
if horizontal > 1: if horizontal > 1:
last_index = insn.assignee.index[-1] // vertical last_index = AntiPatternRemover()(insn.assignee.index[-1] // vertical)
else: else:
last_index = 0 last_index = 0
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment