Skip to content
Snippets Groups Projects
Commit ee14ba98 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Reimplement add_vector_view using the new custom base storage mechanism

parent 1d8c8aa1
No related branches found
No related tags found
No related merge requests found
......@@ -189,7 +189,7 @@ class DuneASTBuilder(CASTBuilder):
alignment = []
size = []
for t in temps:
if t.custom_base_storage == bs:
if isinstance(t, DuneTemporaryVariable) and t.custom_base_storage == bs:
# TODO: Extract correct size
alignment.append(8)
from pytools import product
......
......@@ -7,8 +7,7 @@ from dune.perftool.generation import (function_mangler,
)
from dune.perftool.loopy.target import dtype_floatingpoint
from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size
from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view,
add_vector_view,
from dune.perftool.loopy.transformations.vectorview import (add_vector_view,
get_vector_view_name,
)
from dune.perftool.loopy.symbolic import substitute
......@@ -149,7 +148,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
knl = knl.copy(temporary_variables=tmps)
# Introduce a vector view of the precomputation result
knl = add_vector_view(knl, prec_quantity, flatview=True)
knl = add_vector_view(knl, prec_quantity)
#
# Construct a flat loop for the given instructions
......@@ -196,7 +195,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups())
# 1. Rotating the input data
knl = add_vector_view(knl, quantity, flatview=True)
knl = add_vector_view(knl, quantity)
if horizontal > 1:
new_insns.append(lp.CallInstruction((), # assignees
prim.Call(TransposeReg(vertical=vertical, horizontal=horizontal),
......@@ -219,7 +218,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
elif tag is not None and tag == 'sumfac':
# Add a vector view to this quantity
expr, = quantity_exprs
knl = add_vector_view(knl, quantity, flatview=True)
knl = add_vector_view(knl, quantity)
replacemap[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(vector_indices.get(1), prim.Variable(vec_iname)),
)
......@@ -243,7 +242,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
for insn in insns:
# Get a vector view of the lhs expression
lhsname = get_pymbolic_basename(insn.assignee)
knl = add_vector_view(knl, lhsname, pad_to=vec_size, flatview=True)
knl = add_vector_view(knl, lhsname, pad_to=vec_size)
lhsname = get_vector_view_name(lhsname)
rotating = "gradvec" in insn.tags
......
......@@ -5,7 +5,9 @@ being a an array of SIMD vectors
"""
from dune.perftool.loopy.target import dtype_floatingpoint
from dune.perftool.loopy.temporary import DuneTemporaryVariable
from dune.perftool.loopy.vcl import get_vcl_type_size
from dune.perftool.tools import round_to_multiple
import loopy as lp
import numpy as np
......@@ -17,86 +19,47 @@ def get_vector_view_name(tmpname):
return tmpname + "_vec"
def add_vector_view(knl, tmpname, pad_to=None, flatview=False):
"""
Kernel transformation to add a vector view temporary
that interprets the same memory as another temporary
"""
def add_vector_view(knl, tmpname, pad_to=1):
temporaries = knl.temporary_variables
assert tmpname in temporaries
temp = temporaries[tmpname]
vecname = get_vector_view_name(tmpname)
vectemp = get_vector_view_name(tmpname)
bsname = tmpname + "_base"
vecsize = get_vcl_type_size(temp.dtype)
if vecname in knl.temporary_variables:
# Enforce idempotency
if vectemp in temporaries:
return knl
# Add base storage to the original temporary!
if not temp.base_storage:
temp = temp.copy(base_storage=bsname,
_base_storage_access_may_be_aliasing=True,
)
temporaries[tmpname] = temp
else:
bsname = temp.base_storage
# Determine the shape by dividing total size by vector size
# Also apply the padding we need for rotation
# TODO: *Only* apply this padding if really needed (a bit hard to figure out)
vecsize = get_vcl_type_size(temp.dtype)
if all(isinstance(s, int) for s in temp.shape):
size = pt.product(temp.shape) // vecsize
if size % vecsize != 0:
size = (size // vecsize + 1) * vecsize
# Modify the original temporary to use our custom base storage mechanism
if isinstance(temp, DuneTemporaryVariable):
if temp.custom_base_storage:
bsname = temp.custom_base_storage
else:
temp = temp.copy(custom_base_storage=bsname)
temporaries[tmpname] = temp
else:
size = prim.FloorDiv(prim.Product(temp.shape), vecsize)
size = (size // vecsize + 1) * vecsize
# Maybe do some padding.
if pad_to:
size = (size // pad_to + 1) * pad_to
temp = DuneTemporaryVariable(tmpname,
custom_base_storage=bsname,
**temp.get_copy_kwargs()
)
temporaries[tmpname] = temp
# Some vectorview are intentionally flat! (e.g. the output buffers of
# sum factorization kernels
if flatview:
shape = (size, vecsize)
dim_tags = "c,vec"
else:
shape = temp.shape
# This works around a loopy weirdness (which might as well be a bug)
# TODO: investigate this!
if len(shape) == 1:
shape = (1, vecsize)
dim_tags = "c,vec"
else:
dim_tags = temp.dim_tags[:-1] + ("vec",)
size = round_to_multiple(pt.product(temp.shape), vecsize) // vecsize
size = round_to_multiple(size, pad_to)
# Now add a vector view temporary
vecname = tmpname + "_vec"
temporaries[vecname] = lp.TemporaryVariable(vecname,
dim_tags=dim_tags,
shape=shape,
base_storage=bsname,
dtype=dtype_floatingpoint(),
scope=lp.temp_var_scope.PRIVATE,
_base_storage_access_may_be_aliasing=True,
)
# Avoid that any of these temporaries are eliminated
silenced = ['temp_to_write({})'.format(tmpname),
'temp_to_write({})'.format(vecname),
'read_no_write({})'.format(tmpname),
'read_no_write({})'.format(vecname),
temporaries[vectemp] = DuneTemporaryVariable(vectemp,
dim_tags="c,vec",
shape=(size, vecsize),
custom_base_storage=bsname,
scope=lp.temp_var_scope.PRIVATE,
managed=True,
)
# Avoid that these temporaries are eliminated
silenced = ['temp_to_write({})'.format(vectemp),
'read_no_write({})'.format(vectemp),
]
return knl.copy(temporary_variables=temporaries,
silenced_warnings=knl.silenced_warnings + silenced)
def add_temporary_with_vector_view(knl, name, *args, **kwargs):
temps = knl.temporary_variables
assert name not in temps
temps[name] = lp.TemporaryVariable(name, *args, **kwargs)
knl = knl.copy(temporary_variables=temps)
knl = add_vector_view(knl, name)
return knl
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment