Skip to content
Snippets Groups Projects
Commit fbb97441 authored by Marcel Koch's avatar Marcel Koch
Browse files

use lp.symbolic.FunctionIdentifier for vcl stores, loads, permutes

parent 5e0b69af
No related branches found
No related tags found
No related merge requests found
import loopy as lp
import numpy as np
import pymbolic.primitives as prim
from dune.perftool.loopy.target import dtype_floatingpoint
from dune.perftool.loopy.temporary import DuneTemporaryVariable
from dune.perftool.loopy.symbolic import substitute
from dune.perftool.loopy.vcl import get_vcl_type_size
from dune.perftool.loopy.vcl import get_vcl_type_size, VCLPermute, VCLLoad, VCLStore
from dune.perftool.options import get_option
from dune.perftool.pdelab.argument import PDELabAccumulationFunction
from dune.perftool.pdelab.geometry import world_dimension
......@@ -26,6 +27,9 @@ def add_vcl_temporaries(knl):
def add_vcl_accum_insns(knl, iname_inner, iname_outer):
nptype = dtype_floatingpoint()
vcl_size = get_vcl_type_size(np.float64)
from loopy.match import Tagged
accum_insns = set(lp.find_instructions(knl, Tagged('accum')))
......@@ -33,7 +37,6 @@ def add_vcl_accum_insns(knl, iname_inner, iname_outer):
vng = knl.get_var_name_generator()
idg = knl.get_instruction_id_generator()
new_vec_temporaries = dict()
vcl_size = get_vcl_type_size(np.float64)
for insn in knl.instructions:
# somehow CInstructions are not hashable....
if isinstance(insn, lp.MultiAssignmentBase) and insn in accum_insns:
......@@ -104,11 +107,8 @@ def add_vcl_accum_insns(knl, iname_inner, iname_outer):
# r+=a[iy]
id_accum = idg('insn_mod_accum')
expr_accum = prim.Sum((var_left,
prim.Call(
prim.Variable('permute{}d<-1,{}>'.format(vcl_size,
','.join(map(str, range(vcl_size - 1))))
),
(var_right,)),
prim.Call(VCLPermute(nptype, vcl_size, (-1,) + tuple(range(vcl_size - 1))),
(var_right,)),
substitute(insn.assignee, {iname_ix: 0})))
new_insns.append(lp.Assignment(assignee=substitute(insn.assignee, {iname_ix: 0}),
expression=expr_accum,
......@@ -119,10 +119,7 @@ def add_vcl_accum_insns(knl, iname_inner, iname_outer):
tags=frozenset({'accum'})))
# a[iy] = permute
id_permute = idg('insn_permute')
expr_permute = prim.Call(prim.Variable('permute{}d<3,{}>'.format(vcl_size,
','.join(['-1'] * (vcl_size - 1)))
),
(var_right,))
expr_permute = prim.Call(VCLPermute(nptype, vcl_size, (3,) + (-1,) * (vcl_size - 1)), (var_right,))
new_insns.append(lp.Assignment(assignee=var_left,
expression=expr_permute,
id=id_permute,
......@@ -170,6 +167,7 @@ def add_vcl_access(knl, iname_inner):
map_variable = map_constant
map_function_symbol = map_constant
map_loopy_function_identifier = map_constant
def map_subscript(self, expr):
if expr.aggregate.name.endswith('alias'):
......@@ -198,7 +196,7 @@ def add_vcl_access(knl, iname_inner):
# add load instruction
load_id = idg('insn_' + name_vec + '_load')
call_load = prim.Call(prim.Variable(name_vec + '.load'), (prim.Sum((prim.Variable(name_alias), index)),))
call_load = prim.Call(VCLLoad(name_vec), (prim.Sum((prim.Variable(name_alias), index)),))
load_insns.append(lp.CallInstruction(assignees=(), expression=call_load,
id=load_id, within_inames=insn.within_inames | insn.reduction_inames(),))
read_dependencies.setdefault(id, set())
......@@ -220,7 +218,7 @@ def add_vcl_access(knl, iname_inner):
# add store instruction
store_id = idg('insn_' + name_vec + '_store')
call_store = prim.Call(prim.Variable(name_vec + '.store'), (prim.Sum((prim.Variable(name_alias), index)),))
call_store = prim.Call(VCLStore(name_vec), (prim.Sum((prim.Variable(name_alias), index)),))
store_insns.append(lp.CallInstruction(assignees=(), expression=call_store,
id=store_id, within_inames=insn.within_inames,
depends_on=insn.depends_on | frozenset({id}) | read_dependencies[id]))
......
......@@ -80,6 +80,21 @@ def vcl_cast_mangler(knl, func, arg_dtypes):
return lp.CallMangleInfo(func.name, (lp.types.NumpyType(func.nptype),), (arg_dtypes[0],))
class VCLPermute(lp.symbolic.FunctionIdentifier):
def __init__(self, nptype, vector_width, permutation):
self.nptype = nptype
self.vector_width = vector_width
self.permutation = permutation
def __getinitargs__(self):
return (self.nptype, self.vector_width,self.permutation)
@property
def name(self):
return "permute{}<{}>".format(get_vcl_typename(self.nptype, vector_width=self.vector_width)[-2:],
','.join(map(str, self.permutation)))
@function_mangler
def vcl_function_mangler(knl, func, arg_dtypes):
if func == "mul_add":
......@@ -97,13 +112,40 @@ def vcl_function_mangler(knl, func, arg_dtypes):
vcl = lp.types.NumpyType(get_vcl_type(dtype))
return lp.CallMangleInfo("horizontal_add", (lp.types.NumpyType(dtype.dtype),), (vcl,))
if isinstance(func, str) and func.startswith('permute'):
if isinstance(func, VCLPermute):
dtype = arg_dtypes[0]
vcl = lp.types.NumpyType(get_vcl_type(dtype))
return lp.CallMangleInfo(func, (vcl,), (vcl,))
return lp.CallMangleInfo(func.name, (vcl,), (vcl,))
class VCLLoad(lp.symbolic.FunctionIdentifier):
def __init__(self, vec):
self.vec = vec
def __getinitargs__(self):
return (self.vec,)
if isinstance(func, str) and func.endswith('.load'):
return lp.CallMangleInfo(func, (), (lp.types.NumpyType(np.int32),))
@property
def name(self):
return "{}.load".format(self.vec)
class VCLStore(lp.symbolic.FunctionIdentifier):
def __init__(self, vec):
self.vec = vec
def __getinitargs__(self):
return (self.vec,)
@property
def name(self):
return "{}.store".format(self.vec)
@function_mangler
def vcl_store_and_load_mangler(knl, func, arg_dtypes):
if isinstance(func, VCLLoad):
return lp.CallMangleInfo(func.name, (), (lp.types.NumpyType(np.int32),))
if isinstance(func, str) and func.endswith('.store'):
return lp.CallMangleInfo(func, (), (lp.types.NumpyType(np.int32),))
if isinstance(func, VCLStore):
return lp.CallMangleInfo(func.name, (), (lp.types.NumpyType(np.int32),))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment