Skip to content
Snippets Groups Projects
Commit 9b39feaa authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Refactor direct input mechanism

for further extension
parent 4dfb8de5
No related branches found
No related tags found
No related merge requests found
......@@ -66,6 +66,9 @@ void transpose_reg(Vec8d& a0, Vec8d& a1, Vec8d& a2, Vec8d& a3)
a3 = blend8d<4,5,6,7,12,13,14,15>(b1, b3);
}
/** TODO: Is this transpose using blend8d superior to the swap_halves
* version below using get_low/get_high?
*/
void transpose_reg (Vec8d& a0, Vec8d& a1)
{
Vec8d b0, b1;
......
......@@ -62,8 +62,10 @@ def get_vcl_typename(nptype, register_size=None, vector_width=None):
class ExplicitVCLCast(lp.symbolic.FunctionIdentifier):
def __init__(self, nptype, vector_width):
def __init__(self, nptype, vector_width=None):
self.nptype = nptype
if vector_width is None:
vector_width = get_vcl_type_size(nptype)
self.vector_width = vector_width
def __getinitargs__(self):
......
......@@ -11,6 +11,7 @@ from dune.perftool.generation import (backend,
get_counted_variable,
get_counter,
get_global_context_value,
globalarg,
iname,
instruction,
kernel_cached,
......@@ -70,7 +71,11 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
def __str__(self):
return repr(self)
def realize(self, sf, index, insn_dep):
@property
def direct_input_is_possible(self):
return get_form_option("fastdg")
def realize(self, sf, insn_dep, index=0):
lfs = name_lfs(self.element, self.restriction, self.element_index)
basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
container = self.coeff_func(self.restriction)
......@@ -85,18 +90,30 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
assignee = prim.Subscript(prim.Variable(name),
(prim.Variable(basisiname),) + (index,))
instruction(assignee=assignee,
expression=coeff,
depends_on=sf.insn_dep.union(insn_dep),
tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
)
@property
def direct_input(self):
if get_form_option("fastdg"):
return self.coeff_func(self.restriction)
else:
return None
insn = instruction(assignee=assignee,
expression=coeff,
depends_on=sf.insn_dep.union(insn_dep),
tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
)
return insn_dep.union(frozenset({insn}))
def realize_direct(self, shape, inames):
arg = "{}_access{}".format(self.coeff_func(self.restriction),
"_comp{}".format(self.element_index) if self.element_index else ""
)
from dune.perftool.sumfact.realization import _dof_offset, alias_data_array
globalarg(arg,
shape=shape,
dim_tags=",".join("f" * len(shape)),
offset=_dof_offset(self.element, self.element_index),
)
func = self.coeff_func(self.restriction)
alias_data_array(arg, func)
return prim.Subscript(prim.Variable(arg), inames)
def _basis_functions_per_direction(element):
......
......@@ -75,22 +75,8 @@ def _realize_sum_factorization_kernel(sf):
),
}))
direct_input = sf.input.direct_input
# Set up the input for stage 1
if direct_input is None:
if sf.vectorized:
for i, inputsf in enumerate(sf.kernels):
inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep))
else:
sf.input.realize(sf, 0, insn_dep)
insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))}))
else:
if sf.input.element_index is None:
direct_input_arg = "{}_access".format(direct_input)
else:
direct_input_arg = "{}_access_comp{}".format(direct_input, sf.input.element_index)
if not sf.input.direct_input_is_possible:
insn_dep = insn_dep.union(sf.input.realize(sf, insn_dep))
# Prepare some dim_tags/shapes for later use
ftags = ",".join(["f"] * sf.length)
......@@ -143,24 +129,12 @@ def _realize_sum_factorization_kernel(sf):
# * a global data structure (if FastDGGridOperator is in use)
# * a value from a global data structure, broadcasted to a vector type (vectorized + FastDGGridOperator)
input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
if l == 0 and direct_input is not None:
if l == 0 and sf.input.direct_input_is_possible:
# See comment below
input_inames = permute_backward(input_inames, perm)
inp_shape = permute_backward(inp_shape, perm)
globalarg(direct_input_arg,
shape=inp_shape,
dim_tags=novec_ftags,
offset=_dof_offset(sf.input.element, sf.input.element_index),
)
alias_data_array(direct_input_arg, direct_input)
if matrix.vectorized:
input_summand = prim.Call(ExplicitVCLCast(dtype_floatingpoint(), vector_width=sf.vector_width),
(prim.Subscript(prim.Variable(direct_input_arg),
input_inames),))
else:
input_summand = prim.Subscript(prim.Variable(direct_input_arg),
input_inames + vec_iname)
input_summand = sf.input.realize_direct(inp_shape, input_inames)
else:
# If we did permute the order of a matrices above we also
# permuted the order of out_inames. Unfortunately the
......
......@@ -5,6 +5,8 @@ from dune.perftool.generation import get_counted_variable
from dune.perftool.pdelab.geometry import local_dimension, world_dimension
from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
from dune.perftool.loopy.target import dtype_floatingpoint
from dune.perftool.loopy.vcl import ExplicitVCLCast
from pytools import ImmutableRecord, product
......@@ -18,11 +20,52 @@ import inspect
class SumfactKernelInputBase(object):
@property
def direct_input(self):
return None
def direct_input_is_possible(self):
return False
def realize(self, sf, i, dep):
pass
def realize(self, sf, dep, index=0):
return frozenset()
def realize_direct(self, inames):
raise NotImplementedError
class VectorSumfactKernelInput(SumfactKernelInputBase):
def __init__(self, inputs):
assert(isinstance(inputs, tuple))
self.inputs = inputs
@property
def direct_input_is_possible(self):
return all(i.direct_input_is_possible for i in self.inputs)
def realize(self, sf, dep):
for i, inp in enumerate(self.inputs):
dep = dep.union(inp.realize(sf, dep, index=i))
return dep
def realize_direct(self, shape, inames):
# Check whether the input exhibits a favorable structure
# (whether we can broadcast scalar values into SIMD registers)
total = set(self.inputs)
lower = set(self.inputs[:len(self.inputs) // 2])
upper = set(self.inputs[len(self.inputs) // 2:])
if len(total) == 1:
# All input coefficients use the exact same input coefficient.
# We implement this by broadcasting it into a SIMD register
return prim.Call(ExplicitVCLCast(dtype_floatingpoint()),
(self.inputs[0].realize_direct(shape, inames),)
)
elif len(total) == 2 and len(lower) == 1 and len(upper) == 1:
# The lower and the upper part of the SIMD register use
# the same input coefficient, we combine the SIMD register
# from two shorter SIMD types
raise NotImplementedError("Lower/Upper half SIMD loads not implemented!")
else:
# The input does not exhibit a broadcastable structure, we
# need to load scalars into the SIMD vector.
raise NotImplementedError("SIMD loads from scalars not implemented!")
class SumfactKernelBase(object):
......@@ -457,11 +500,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def predicates(self):
return self.kernels[0].predicates
@property
def input(self):
assert len(set(k.input for k in self.kernels)) == 1
return self.kernels[0].input
@property
def accumvar(self):
assert len(set(k.accumvar for k in self.kernels)) == 1
......@@ -486,6 +524,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
# Define the same properties the normal SumfactKernel defines
#
@property
def input(self):
return VectorSumfactKernelInput(tuple(k.input for k in self.kernels))
@property
def cache_key(self):
return (tuple(k.cache_key for k in self.kernels), self.buffer)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment