Reviewer note: this patch was received with its line breaks and indentation
collapsed onto two lines. The line structure below is reconstructed; the
indentation depth of unchanged context lines and the position of blank
context lines are inferred and must be confirmed against the target file
(`git apply --ignore-whitespace` tolerates indentation differences only).
The AntiPatternRemover hunk was amended during review (class docstring,
recursion into the surviving operand, remainder accepted on either side of
the sum), so the post-image blob id on the `index` line is stale.

diff --git a/python/dune/perftool/loopy/transformations/collect_precompute.py b/python/dune/perftool/loopy/transformations/collect_precompute.py
index 73f13619a14e9bf14df4ad808239f6f834a40dbe..7cdf72bac12b1071a24958745fd6325fc27cc1f9 100644
--- a/python/dune/perftool/loopy/transformations/collect_precompute.py
+++ b/python/dune/perftool/loopy/transformations/collect_precompute.py
@@ -19,7 +19,7 @@
 from loopy.kernel.creation import parse_domains
 from loopy.symbolic import pw_aff_to_expr
 from loopy.match import Tagged
-from loopy.symbolic import DependencyMapper
+from loopy.symbolic import DependencyMapper, IdentityMapper
 from pytools import product
 
 import pymbolic.primitives as prim
@@ -65,6 +65,39 @@ class VectorIndices(object):
         return prim.Variable(name)
 
 
+class AntiPatternRemover(IdentityMapper):
+    """Simplify floor divisions whose numerator contains a remainder by the same divisor.
+
+    Rewrites
+        (x % n) // n        ->  0
+        (y + (x % n)) // n  ->  y // n
+        ((x % n) + y) // n  ->  y // n
+
+    NOTE: The two sum rules are exact for all x only if y is a multiple of n.
+    Presumably the index expressions built in this module satisfy that - TODO confirm.
+    """
+    def map_floor_div(self, expr):
+        num = expr.numerator
+        den = expr.denominator
+
+        def is_remainder_by_den(e):
+            return isinstance(e, prim.Remainder) and e.denominator == den
+
+        if is_remainder_by_den(num):
+            return 0
+
+        if isinstance(num, prim.Sum) and len(num.children) == 2:
+            c0, c1 = num.children
+            # Recurse into the surviving operand, so that nested occurrences
+            # of the pattern are simplified just like on the fall-through path.
+            if is_remainder_by_den(c1):
+                return self.rec(c0) // den
+            if is_remainder_by_den(c0):
+                return self.rec(c1) // den
+
+        return IdentityMapper.map_floor_div(self, expr)
+
+
 def collect_vector_data_precompute(knl):
     #
     # Process/Assert/Standardize the input
@@ -189,7 +222,7 @@ def collect_vector_data_precompute(knl):
         # Add substitution rules
         for expr in quantity_exprs:
             assert isinstance(expr, prim.Subscript)
-            last_index = expr.index[-1] // vertical
+            last_index = AntiPatternRemover()(expr.index[-1] // vertical)
             replacemap[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
                                               (vector_indices.get(horizontal) + last_index, prim.Variable(vec_iname)),
                                               )
@@ -230,7 +263,7 @@ def collect_vector_data_precompute(knl):
             tag = get_pymbolic_tag(insn.assignee)
             horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups())
             if horizontal > 1:
-                last_index = insn.assignee.index[-1] // vertical
+                last_index = AntiPatternRemover()(insn.assignee.index[-1] // vertical)
             else:
                 last_index = 0
         else: