diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index bdee930a9dd9b68183282d3905131d317241af8e..b492f1374b7b7f3057aea78cf2a28921a320c15a 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -189,7 +189,7 @@ class DuneASTBuilder(CASTBuilder): alignment = [] size = [] for t in temps: - if t.custom_base_storage == bs: + if isinstance(t, DuneTemporaryVariable) and t.custom_base_storage == bs: # TODO: Extract correct size alignment.append(8) from pytools import product diff --git a/python/dune/perftool/loopy/transformations/vectorize_quad.py b/python/dune/perftool/loopy/transformations/vectorize_quad.py index fa5b03c204b4d77f628a7566897d33dae3e67a7c..5022ebaf8d58875705777d5c107fae2e0bb34f15 100644 --- a/python/dune/perftool/loopy/transformations/vectorize_quad.py +++ b/python/dune/perftool/loopy/transformations/vectorize_quad.py @@ -7,8 +7,7 @@ from dune.perftool.generation import (function_mangler, ) from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size -from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view, - add_vector_view, +from dune.perftool.loopy.transformations.vectorview import (add_vector_view, get_vector_view_name, ) from dune.perftool.loopy.symbolic import substitute @@ -149,7 +148,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix): knl = knl.copy(temporary_variables=tmps) # Introduce a vector view of the precomputation result - knl = add_vector_view(knl, prec_quantity, flatview=True) + knl = add_vector_view(knl, prec_quantity) # # Construct a flat loop for the given instructions @@ -196,7 +195,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix): horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups()) # 1. Rotating the input data - knl = add_vector_view(knl, quantity, flatview=True) + knl = add_vector_view(knl, quantity) if horizontal > 1: new_insns.append(lp.CallInstruction((), # assignees prim.Call(TransposeReg(vertical=vertical, horizontal=horizontal), @@ -219,7 +218,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix): elif tag is not None and tag == 'sumfac': # Add a vector view to this quantity expr, = quantity_exprs - knl = add_vector_view(knl, quantity, flatview=True) + knl = add_vector_view(knl, quantity) replacemap[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), (vector_indices.get(1), prim.Variable(vec_iname)), ) @@ -243,7 +242,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix): for insn in insns: # Get a vector view of the lhs expression lhsname = get_pymbolic_basename(insn.assignee) - knl = add_vector_view(knl, lhsname, pad_to=vec_size, flatview=True) + knl = add_vector_view(knl, lhsname, pad_to=vec_size) lhsname = get_vector_view_name(lhsname) rotating = "gradvec" in insn.tags diff --git a/python/dune/perftool/loopy/transformations/vectorview.py b/python/dune/perftool/loopy/transformations/vectorview.py index 9b74e3d8ca67cc28584def2ec0055b5504febfc4..0ff267067ec6d7d94fc77995edb5c30fd4c00391 100644 --- a/python/dune/perftool/loopy/transformations/vectorview.py +++ b/python/dune/perftool/loopy/transformations/vectorview.py @@ -5,7 +5,9 @@ being a an array of SIMD vectors """ from dune.perftool.loopy.target import dtype_floatingpoint +from dune.perftool.loopy.temporary import DuneTemporaryVariable from dune.perftool.loopy.vcl import get_vcl_type_size +from dune.perftool.tools import round_to_multiple import loopy as lp import numpy as np @@ -17,86 +19,47 @@ def get_vector_view_name(tmpname): return tmpname + "_vec" -def add_vector_view(knl, tmpname, pad_to=None, flatview=False): - """ - Kernel transformation to add a vector view temporary - that interprets the same memory as another temporary - """ +def add_vector_view(knl, tmpname, pad_to=1): temporaries = knl.temporary_variables - assert tmpname in temporaries temp = temporaries[tmpname] - vecname = get_vector_view_name(tmpname) + vectemp = get_vector_view_name(tmpname) bsname = tmpname + "_base" + vecsize = get_vcl_type_size(temp.dtype) - if vecname in knl.temporary_variables: + # Enforce idempotency + if vectemp in temporaries: return knl - # Add base storage to the original temporary! - if not temp.base_storage: - temp = temp.copy(base_storage=bsname, - _base_storage_access_may_be_aliasing=True, - ) - temporaries[tmpname] = temp - else: - bsname = temp.base_storage - - # Determine the shape by dividing total size by vector size - # Also apply the padding we need for rotation - # TODO: *Only* apply this padding if really needed (a bit hard to figure out) - vecsize = get_vcl_type_size(temp.dtype) - if all(isinstance(s, int) for s in temp.shape): - size = pt.product(temp.shape) // vecsize - if size % vecsize != 0: - size = (size // vecsize + 1) * vecsize + # Modify the original temporary to use our custom base storage mechanism + if isinstance(temp, DuneTemporaryVariable): + if temp.custom_base_storage: + bsname = temp.custom_base_storage + else: + temp = temp.copy(custom_base_storage=bsname) + temporaries[tmpname] = temp else: - size = prim.FloorDiv(prim.Product(temp.shape), vecsize) - size = (size // vecsize + 1) * vecsize - - # Maybe do some padding. - if pad_to: - size = (size // pad_to + 1) * pad_to + temp = DuneTemporaryVariable(tmpname, + custom_base_storage=bsname, + **temp.get_copy_kwargs() + ) + temporaries[tmpname] = temp - # Some vectorview are intentionally flat! (e.g. the output buffers of - # sum factorization kernels - if flatview: - shape = (size, vecsize) - dim_tags = "c,vec" - else: - shape = temp.shape - # This works around a loopy weirdness (which might as well be a bug) - # TODO: investigate this! - if len(shape) == 1: - shape = (1, vecsize) - dim_tags = "c,vec" - else: - dim_tags = temp.dim_tags[:-1] + ("vec",) + size = round_to_multiple(pt.product(temp.shape), vecsize) // vecsize + size = round_to_multiple(size, pad_to) # Now add a vector view temporary - vecname = tmpname + "_vec" - temporaries[vecname] = lp.TemporaryVariable(vecname, - dim_tags=dim_tags, - shape=shape, - base_storage=bsname, - dtype=dtype_floatingpoint(), - scope=lp.temp_var_scope.PRIVATE, - _base_storage_access_may_be_aliasing=True, - ) - - # Avoid that any of these temporaries are eliminated - silenced = ['temp_to_write({})'.format(tmpname), - 'temp_to_write({})'.format(vecname), - 'read_no_write({})'.format(tmpname), - 'read_no_write({})'.format(vecname), + temporaries[vectemp] = DuneTemporaryVariable(vectemp, + dim_tags="c,vec", + shape=(size, vecsize), + custom_base_storage=bsname, + scope=lp.temp_var_scope.PRIVATE, + managed=True, + ) + + # Avoid that these temporaries are eliminated + silenced = ['temp_to_write({})'.format(vectemp), + 'read_no_write({})'.format(vectemp), ] return knl.copy(temporary_variables=temporaries, silenced_warnings=knl.silenced_warnings + silenced) - - -def add_temporary_with_vector_view(knl, name, *args, **kwargs): - temps = knl.temporary_variables - assert name not in temps - temps[name] = lp.TemporaryVariable(name, *args, **kwargs) - knl = knl.copy(temporary_variables=temps) - knl = add_vector_view(knl, name) - return knl