Skip to content
Snippets Groups Projects
Commit 46f74b0b authored by René Heß's avatar René Heß
Browse files

Cleanup permutations for vectorized case

parent 76161b7c
No related branches found
No related tags found
No related merge requests found
......@@ -135,7 +135,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
return ImmutableRecord.__repr__(self)
def get_keyword_arguments(self):
"""Get dictionary of keyword arguments needed to initialize this classIFIERS
"""Get dictionary of keyword arguments needed to initialize this class
Extract keyword arguments from the ImmutableRecord and modify
accordingly. You need to set the correct matrix sequence before using
......@@ -172,7 +172,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction)
def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# TODO: This should happen in stage 2 and not in stage 3
# The result of stage 2 has the correct quadrature permutation but no
# cost permutation is applied. The inames for this method are
# quadrature and cost permuted. This means we need to reverse the cost
# permutation to access the result of stage 2 in the correct way.
shape = permute_backward(shape, self.cost_permutation)
inames = permute_backward(inames, self.cost_permutation)
......@@ -258,9 +261,8 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
return frozenset({dep})
def realize_direct_output(self, result, inames, shape, which=0, reverse_cost_permutation=True, **args):
if reverse_cost_permutation:
inames = permute_backward(inames, self.cost_permutation)
def realize_direct_output(self, result, inames, shape, which=0, **args):
inames = permute_backward(inames, self.cost_permutation)
inames = permute_backward(inames, self.quadrature_permutation)
direct_output = "fastdg{}".format(which)
......
......@@ -95,7 +95,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
return repr(self)
def get_keyword_arguments(self):
"""Get dictionary of keyword arguments needed to initialize this classIFIERS
"""Get dictionary of keyword arguments needed to initialize this class
Extract keyword arguments from the ImmutableRecord and modify
accordingly. You need to set the correct matrix sequence before using
......
......@@ -104,7 +104,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
return repr(self)
def get_keyword_arguments(self):
"""Get dictionary of keyword arguments needed to initialize this classIFIERS
"""Get dictionary of keyword arguments needed to initialize this class
Extract keyword arguments from the ImmutableRecord and modify
accordingly. You need to set the correct matrix sequence before using
......
......@@ -194,7 +194,6 @@ def realize_sumfact_kernel_function(sf):
if l == 0 and sf.stage == 1 and sf.interface.direct_is_possible:
input_summand = sf.interface.realize_direct_input(input_inames, inp_shape)
elif l == 0:
# TODO: Simplify arguments!
input_summand = sf.interface.realize_input(input_inames,
inp_shape,
vec_iname,
......
......@@ -31,8 +31,19 @@ import inspect
class SumfactKernelInterfaceBase(object):
""" A base class for the input/output of a sum factorization kernel
"""A base class for the input/output of a sum factorization kernel
In stage 1, this represents the input object, in stage 3 the output object.
Notes about permutations:
- setup_input: handle cost and quadrature permutation
- realize_input stage 1: no permutations
- realize_input stage 3: only cost permutation
- realize_direct_input: cost and quadrature permutation
- accumulate_output: no permutation
- realize_direct_output: cost and quadrature permutation
In the vectorized case most permutation handling is forwarded to the scalar
kernels.
"""
def setup_input(self, sf, insn_dep, index=0):
"""Create and fill an input array for a stage 1 sumfact kernel function (non fastdg)
......@@ -57,7 +68,9 @@ class SumfactKernelInterfaceBase(object):
Stage 1 : Input is already permuted the right way in setup_input.
Stage 3 : TODO -> Check permutation in accumulation.py
Stage 3 : The inames are cost and quadrature permuted but the input is
only quadrature permuted. This means we need to reverse the cost
permutation on the inames.
Parameters
----------
......@@ -118,12 +131,12 @@ class SumfactKernelInterfaceBase(object):
"""
raise NotImplementedError
def realize_direct_output(self, result, iname, shape, which=0, reverse_cost_permutation=True, **kwargs):
def realize_direct_output(self, result, iname, shape, which=0, **kwargs):
"""Accumulate results directly in the sumfact kernel function (fastdg)
This happens inside the sumfact kernel function.
TODO: Add note about permutation
Needs to handle cost and quadrature permutation.
Parameters
----------
......@@ -133,8 +146,6 @@ class SumfactKernelInterfaceBase(object):
shape : tuple of ints
which : int
TODO Doc me!
reverse_cost_permutation : tuple of ints
TODO Doc me!
**kwargs :
Key word arguments passed to loopy instruction
"""
......@@ -209,33 +220,30 @@ class SumfactKernelInterfaceBase(object):
class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
def __init__(self, interfaces, perm):
def __init__(self, interfaces):
assert(isinstance(interfaces, tuple))
self.interfaces = interfaces
self.vector_cost_permutation = perm
def __repr__(self):
return "_".join(repr(i) for i in self.interfaces)
@property
def quadrature_permutation(self):
# TODO: For now we assure that all kerneles have the same quadrature_permutation
# TODO: For now we only vectorize sumfact kernels with the same
# quadrature permutation. This should be extended.
for i in self.interfaces:
assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation
return self.interfaces[0].quadrature_permutation
@property
def cost_permutation(self):
# The cost_permutation of the underlying scalar SumfactKernel can be
# different for each kernel.
# raise RuntimeError("cost_permutation should not be called on VectorSumfactKernelInput")
# TODO!
cost_permutation = self.interfaces[0].cost_permutation
# This should hold true due to the choice of quadrature
# permutation. For both structured and unstructured grids the order of
# the global directions should be the same leading to the same cost
# permutation for all those sum factorization kernels.
for i in self.interfaces:
assert i.cost_permutation == cost_permutation
return self.vector_cost_permutation
assert i.cost_permutation == self.interfaces[0].cost_permutation
return self.interfaces[0].cost_permutation
@property
def stage(self):
......@@ -251,8 +259,6 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
return dep
def realize_input(self, inames, shape, vec_iname, vec_shape, buf, ftags):
# TODO: vector_cost_permutation not used!
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the matrix loop
# this reinterprets the output of the previous iteration.
......@@ -268,8 +274,6 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
return prim.Subscript(prim.Variable(inp), inames + vec_iname)
def realize_direct_input(self, inames, shape):
# TODO: vector_cost_permutation not used!
# Check whether the input exhibits a favorable structure
# (whether we can broadcast scalar values into SIMD registers)
total = set(self.interfaces)
......@@ -317,25 +321,26 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
def __init__(self, interfaces, perm):
def __init__(self, interfaces):
self.interfaces = interfaces
self._cost_permutation = perm
def __repr__(self):
return "_".join(repr(o) for o in self.interfaces)
@property
def cost_permutation(self):
return self._cost_permutation
# This should hold true due to the choice of quadrature
# permutation. For both structured and unstructured grids the order of
# the global directions should be the same leading to the same cost
# permutation for all those sum factorization kernels.
for i in self.interfaces:
assert i.cost_permutation == self.interfaces[0].cost_permutation
return self.interfaces[0].cost_permutation
@property
def quadrature_permutation(self):
# TODO: This should be turned into an error.
#
#
# The quadrature permutation could be different for different scalar kernels!
# raise RuntimeError('quadrature_permutation should not be called on VectorSumfactKernelOutput')
#
# TODO: For now we only vectorize sumfact kernels with the same
# quadrature permutation. This should be extended .
for i in self.interfaces:
assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation
return self.interfaces[0].quadrature_permutation
......@@ -381,10 +386,6 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
return prim.Subscript(prim.Variable(inp), inames + vec_iname)
def realize_direct_output(self, result, inames, shape, **args):
# TODO: Find out what needs to happen here
inames = permute_backward(inames, self.cost_permutation)
shape = permute_backward(shape, self.cost_permutation)
outputs = set(self.interfaces)
# If multiple horizontal_add's are to be performed with 'result'
......@@ -403,7 +404,6 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
inames,
shape,
which=which,
reverse_cost_permutation=False,
**args))
return deps
......@@ -610,7 +610,6 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
# quadrature_permutation. This should be handled like upper/lower
# vectorization
return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible, self.interface.quadrature_permutation)
# return tuple(m.quadrature_size for m in self.matrix_sequence_quadrature_permuted) + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + (self.interface.direct_is_possible,)
@property
def cache_key(self):
......@@ -951,7 +950,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property
def cost_permutation(self):
return self.kernels[0].cost_permutation
raise RuntimeError("cost_permutation should not be used on VectorizedSumfactKernel.")
@property
def stage(self):
......@@ -1001,11 +1000,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property
def interface(self):
perm = self.cost_permutation
if self.stage == 1:
return VectorSumfactKernelInput(tuple(k.interface for k in self.kernels), perm)
return VectorSumfactKernelInput(tuple(k.interface for k in self.kernels))
else:
return VectorSumfactKernelOutput(tuple(k.interface for k in self.kernels), perm)
return VectorSumfactKernelOutput(tuple(k.interface for k in self.kernels))
@property
def cache_key(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment