Skip to content
Snippets Groups Projects
Commit 137bedb8 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fix interplay of quadrature loop vectorization with vertical vectorization

parent c55424f5
No related branches found
No related tags found
No related merge requests found
...@@ -28,20 +28,14 @@ import re ...@@ -28,20 +28,14 @@ import re
class TransposeReg(lp.symbolic.FunctionIdentifier): class TransposeReg(lp.symbolic.FunctionIdentifier):
def __init__(self, def __init__(self,
vector_width=None, horizontal=1,
components=1, vertical=1,
): ):
if vector_width is None: self.horizontal = horizontal
vector_width = get_vcl_type_size(np.float64) self.vertical = vertical
# Non-quadratic transposes are not yet implemented
assert components == vector_width
self.vector_width = vector_width
self.components = components
def __getinitargs__(self): def __getinitargs__(self):
return (self.vector_width, self.components) return (self.horizontal, self.vertical)
@property @property
def name(self): def name(self):
...@@ -55,8 +49,8 @@ def rotate_function_mangler(knl, func, arg_dtypes): ...@@ -55,8 +49,8 @@ def rotate_function_mangler(knl, func, arg_dtypes):
# passing the vector registers as references and have them # passing the vector registers as references and have them
# changed. Loopy assumes this function to be read-only. # changed. Loopy assumes this function to be read-only.
include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile") include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile")
vcl = lp.types.NumpyType(get_vcl_type(np.float64, vector_width=func.vector_width)) vcl = lp.types.NumpyType(get_vcl_type(np.float64, vector_width=func.horizontal * func.vertical))
return lp.CallMangleInfo(func.name, (), (vcl,) * func.components) return lp.CallMangleInfo(func.name, (), (vcl,) * func.horizontal)
class VectorIndices(object): class VectorIndices(object):
...@@ -206,14 +200,17 @@ def collect_vector_data_rotate(knl): ...@@ -206,14 +200,17 @@ def collect_vector_data_rotate(knl):
# 1. Rotating the input data # 1. Rotating the input data
knl = add_vector_view(knl, quantity, flatview=True) knl = add_vector_view(knl, quantity, flatview=True)
new_insns.append(lp.CallInstruction((), # assignees if horizontal > 1:
prim.Call(TransposeReg(vector_width=horizontal*vertical, components=horizontal), new_insns.append(lp.CallInstruction((), # assignees
tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)), (vector_indices.get(horizontal) + i, prim.Variable(new_iname))) for i in range(horizontal))), prim.Call(TransposeReg(vertical=vertical, horizontal=horizontal),
depends_on=frozenset({'continue_stmt'}), tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
within_inames=common_inames.union(inames).union(frozenset({new_iname})), (vector_indices.get(horizontal) + i, prim.Variable(new_iname)))
within_inames_is_final=True, for i in range(horizontal))),
id="{}_rotate".format(quantity), depends_on=frozenset({'continue_stmt'}),
)) within_inames=common_inames.union(inames).union(frozenset({new_iname})),
within_inames_is_final=True,
id="{}_rotate".format(quantity),
))
# Add substitution rules # Add substitution rules
for expr in quantities[quantity]: for expr in quantities[quantity]:
...@@ -311,13 +308,16 @@ def collect_vector_data_rotate(knl): ...@@ -311,13 +308,16 @@ def collect_vector_data_rotate(knl):
rotating = "gradvec" in insn.tags rotating = "gradvec" in insn.tags
if rotating: if rotating:
# from pudb import set_trace; set_trace()
assert isinstance(insn.assignee, prim.Subscript) assert isinstance(insn.assignee, prim.Subscript)
tag = get_pymbolic_tag(insn.assignee) tag = get_pymbolic_tag(insn.assignee)
if tag is None: if tag is None:
print insn.assignee print insn.assignee
horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups()) horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups())
last_index = insn.assignee.index[-1] if horizontal > 1:
assert last_index in tuple(range(horizontal * vertical)) last_index = insn.assignee.index[-1]
else:
last_index = 0
else: else:
last_index = 0 last_index = 0
horizontal = 1 horizontal = 1
...@@ -336,10 +336,12 @@ def collect_vector_data_rotate(knl): ...@@ -336,10 +336,12 @@ def collect_vector_data_rotate(knl):
) )
# Rotate back! # Rotate back!
if rotating and "{}_rotateback".format(lhsname) not in [i.id for i in new_insns]: if rotating and "{}_rotateback".format(lhsname) not in [i.id for i in new_insns] and horizontal > 1:
new_insns.append(lp.CallInstruction((), # assignees new_insns.append(lp.CallInstruction((), # assignees
prim.Call(TransposeReg(components=horizontal, vector_width=horizontal*vertical), prim.Call(TransposeReg(horizontal=horizontal, vertical=vertical),
tuple(prim.Subscript(prim.Variable(lhsname), (vector_indices.get(horizontal) + i, prim.Variable(new_iname))) for i in range(horizontal))), tuple(prim.Subscript(prim.Variable(lhsname),
(vector_indices.get(horizontal) + i, prim.Variable(new_iname)))
for i in range(horizontal))),
depends_on=frozenset({Tagged("vec_write")}), depends_on=frozenset({Tagged("vec_write")}),
within_inames=common_inames.union(inames).union(frozenset({new_iname})), within_inames=common_inames.union(inames).union(frozenset({new_iname})),
within_inames_is_final=True, within_inames_is_final=True,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment