Skip to content
Snippets Groups Projects
Commit 73efbc9a authored by René Heß
Browse files

Make permuted_matrix_sequence a property

parent ba4f2609
No related branches found
No related tags found
No related merge requests found
......@@ -126,11 +126,11 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
def realize(self, sf, result, insn_dep, inames=None, additional_inames=()):
trial_leaf_element = get_leaf(self.trial_element, self.trial_element_index) if self.trial_element is not None else None
basis_size = tuple(mat.basis_size for mat in sf.matrix_sequence)
basis_size = tuple(mat.basis_size for mat in sf.permuted_matrix_sequence)
if inames is None:
inames = tuple(accum_iname(trial_leaf_element, mat.rows, i)
for i, mat in enumerate(sf.matrix_sequence))
for i, mat in enumerate(sf.permuted_matrix_sequence))
# Determine the expression to accumulate with. This depends on the vectorization strategy!
from dune.perftool.tools import maybe_wrap_subscript
......
......@@ -89,7 +89,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
from dune.perftool.sumfact.realization import name_buffer_storage
name = "input_{}".format(sf.buffer)
temporary_variable(name,
shape=(product(mat.basis_size for mat in sf.matrix_sequence), sf.vector_width),
shape=(product(mat.basis_size for mat in sf.permuted_matrix_sequence), sf.vector_width),
custom_base_storage=name_buffer_storage(sf.buffer, 0),
managed=True,
)
......
......@@ -52,7 +52,7 @@ def sumfact_permutation_strategy(sf):
heuristic is used to pick one.
"""
# Extract information from the SumfactKernel object
matrix_sequence = sf.matrix_sequence
matrix_sequence = sf.permuted_matrix_sequence
stage = sf.stage
# Combine permutation and matrix_sequence
......
......@@ -76,8 +76,8 @@ def _realize_sum_factorization_kernel(sf):
for buf in buffers:
# Determine the necessary size of the buffer. We assume that we do not
# underintegrate the form!!!
size = max(product(m.quadrature_size for m in sf.matrix_sequence) * sf.vector_width,
product(m.basis_size for m in sf.matrix_sequence) * sf.vector_width)
size = max(product(m.quadrature_size for m in sf.permuted_matrix_sequence) * sf.vector_width,
product(m.basis_size for m in sf.permuted_matrix_sequence) * sf.vector_width)
temporary_variable("{}_dummy".format(buf),
shape=(size,),
custom_base_storage=buf,
......@@ -161,7 +161,7 @@ def realize_sumfact_kernel_function(sf):
perm = sumfact_permutation_strategy(sf)
# Permute matrix sequence
matrix_sequence = permute_forward(sf.matrix_sequence, perm)
matrix_sequence = permute_forward(sf.permuted_matrix_sequence, perm)
# Product of all matrices
for l, matrix in enumerate(matrix_sequence):
......@@ -212,7 +212,7 @@ def realize_sumfact_kernel_function(sf):
inp_shape = permute_backward(inp_shape, perm)
input_inames = permute_backward(input_inames, perm)
if sf.stage == 1:
# In the unstructured case the sf.matrix_sequence could
# In the unstructured case the sf.permuted_matrix_sequence could
# already be permuted according to
# sf.quadrature_permutation. We also need to reverse this
# permutation to get the input from 0 to d-1.
......
......@@ -159,7 +159,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
from dune.perftool.sumfact.accumulation import accum_iname
element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None
inames = tuple(accum_iname(element, mat.rows, i)
for i, mat in enumerate(sf.matrix_sequence))
for i, mat in enumerate(sf.permuted_matrix_sequence))
veciname = accum_iname(element, sf.vector_width // len(outputs), "vec")
transform(lp.tag_inames, [(veciname, "vec")])
......@@ -299,10 +299,12 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
for a in SumfactKernel.init_arg_names:
defaultdict[a] = eval(a)
dim = len(matrix_sequence)
# Not sure if this whole permuting would make sense if we would do sum
# factorized evaluation of intersections where len(matrix_sequence)
# would not be equal to world dim.
dim = len(matrix_sequence)
assert dim == world_dimension()
# Get restriction for this sum factorization kernel. Note: For
# accumulation output we have a restriction for the test (index 0) and
......@@ -314,20 +316,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
assert len(restriction) is 2
restriction = restriction[0]
perm = sumfact_quadrature_permutation_strategy(dim, restriction)
permuted_matrix_sequence = []
for i in perm:
for mat in matrix_sequence:
if mat.direction == i:
permuted_matrix_sequence.append(mat)
permuted_matrix_sequence = tuple(permuted_matrix_sequence)
from dune.perftool.sumfact.switch import get_facedir, get_facemod
facedir = get_facedir(restriction)
facemod = get_facemod(restriction)
defaultdict['matrix_sequence'] = permuted_matrix_sequence
defaultdict['quadrature_permutation'] = perm
# Store correct quadrature_permutation
quadrature_permuation = sumfact_quadrature_permutation_strategy(dim, restriction)
defaultdict['quadrature_permutation'] = quadrature_permuation
# Call the base class constructors
ImmutableRecord.__init__(self, **defaultdict)
......@@ -341,13 +332,14 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
return tuple(getattr(self, arg) for arg in SumfactKernel.init_arg_names)
def stringifier(self):
# Uses __str__ below
return lp.symbolic.StringifyMapper
def __str__(self):
# Above stringifier just calls back into this
# Return permuted_matrix_sequence
return "SF{}:[{}]->[{}]".format(self.stage,
str(self.interface),
", ".join(str(m) for m in self.matrix_sequence))
", ".join(str(m) for m in self.permuted_matrix_sequence))
mapper_method = "map_sumfact_kernel"
......@@ -358,21 +350,25 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
@property
def function_name(self):
""" The name of the function that implements this kernel """
name = "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence),
# Use permuted_matrix_sequence here since this is more consistent with
# the vectorized case
name = "sfimpl_{}{}".format("_".join(str(m) for m in self.permuted_matrix_sequence),
self.interface.function_name_suffix)
# On unstructured we need different permutation of the input to realize
# different permuation of quadrature points on self and neighbor. Mange
# different permuation of quadrature points on self and neighbor. Mangle
# the permutation of the quadrature points into the name to generate
# sperate functions.
if self.quadrature_permutation != tuple(range(len(self.matrix_sequence))):
name_quad_perm = "_qpperm_{}".format("_".join(str(a) for a in self.quadrature_permutation))
name_quad_perm = "_qpperm_{}".format("".join(str(a) for a in self.quadrature_permutation))
name = name + name_quad_perm
return name
@property
def parallel_key(self):
""" A key that identifies parallellizable kernels. """
return tuple(m.basis_size for m in self.matrix_sequence) + (self.stage, self.buffer)
return tuple(m.basis_size for m in self.permuted_matrix_sequence) + (self.stage, self.buffer)
@property
def cache_key(self):
......@@ -436,13 +432,26 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
"""
return 0
@property
def permuted_matrix_sequence(self):
    """Matrix sequence ordered according to the desired quadrature point order.

    The result is self.matrix_sequence passed through permute_forward
    together with self.quadrature_permutation.

    Except for face integrals on 3D unstructured grids this will just be
    the matrix sequence. In this special case it might be the reverse order
    to ensure that quadrature points are visited in the same order on self
    and neighbor.
    """
    order = self.quadrature_permutation
    return permute_forward(self.matrix_sequence, order)
@property
def quadrature_shape(self):
""" The shape of a temporary for the quadrature points
Takes into account the lower dimensionality of faces and vectorization.
"""
return tuple(mat.quadrature_size for mat in self.matrix_sequence)
return tuple(mat.quadrature_size for mat in self.permuted_matrix_sequence)
def quadrature_index(self, sf, visitor):
if visitor.current_info[1] is None:
......@@ -455,14 +464,14 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
element = element.extract_component(element_index)[1]
quad_inames = quadrature_inames(element)
if len(self.matrix_sequence) == local_dimension():
if len(self.permuted_matrix_sequence) == local_dimension():
return tuple(prim.Variable(i) for i in quad_inames)
# Traverse all the quadrature inames and map them to their correct direction
index = []
i = 0
for d in range(world_dimension()):
if self.matrix_sequence[d].face is None:
if self.permuted_matrix_sequence[d].face is None:
index.append(prim.Variable(quad_inames[i]))
i = i + 1
else:
......@@ -485,7 +494,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
Takes into account vectorization.
"""
return tuple(mat.basis_size for mat in self.matrix_sequence)
return tuple(mat.basis_size for mat in self.permuted_matrix_sequence)
@property
def dof_dimtags(self):
......@@ -560,7 +569,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
""" The total number of floating point operations for the kernel
to be carried out """
from dune.perftool.sumfact.permutation import flop_cost
return flop_cost(self.matrix_sequence)
return flop_cost(self.permuted_matrix_sequence)
# Extract the argument list and store it on the class. This needs to be done
......@@ -576,7 +585,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
vertical_width=1,
buffer=None,
insn_dep=frozenset(),
quadrature_permutation=None,
):
# Assert the input data structure
assert isinstance(kernels, tuple)
......@@ -590,18 +598,17 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
# Assert properties of the matrix sequence of the underlying kernels
for i in range(kernels[0].length):
assert len(set(tuple(k.matrix_sequence[i].rows for k in kernels))) == 1
assert len(set(tuple(k.matrix_sequence[i].cols for k in kernels))) == 1
assert len(set(tuple(k.matrix_sequence[i].direction for k in kernels))) == 1
assert len(set(tuple(k.matrix_sequence[i].transpose for k in kernels))) == 1
assert len(set(tuple(k.permuted_matrix_sequence[i].rows for k in kernels))) == 1
assert len(set(tuple(k.permuted_matrix_sequence[i].cols for k in kernels))) == 1
assert len(set(tuple(k.permuted_matrix_sequence[i].direction for k in kernels))) == 1
assert len(set(tuple(k.permuted_matrix_sequence[i].transpose for k in kernels))) == 1
# Join the instruction dependencies of all subkernels
insn_dep = insn_dep.union(k.insn_dep for k in kernels)
# Get order of quadrature points
quadrature_permutation = kernels[0].quadrature_permutation
# Assert that quadrature permutation is the same for all kernels
for k in kernels:
assert k.quadrature_permutation == quadrature_permutation
assert k.quadrature_permutation == kernels[0].quadrature_permutation
# We currently assume that all subkernels are consecutive, 0-based within the vector
assert None not in kernels
......@@ -612,7 +619,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
buffer=buffer,
insn_dep=insn_dep,
vertical_width=vertical_width,
quadrature_permutation=quadrature_permutation,
)
prim.Variable.__init__(self, "VecSUMFAC")
......@@ -653,14 +659,17 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
#
# Deduce all data fields of normal sum factorization kernels from the underlying kernels
#
@property
def matrix_sequence(self):
return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence[i] for k in self.kernels),
return tuple(BasisTabulationMatrixArray(tuple(k.permuted_matrix_sequence[i] for k in self.kernels),
width=self.vector_width,
)
for i in range(self.length))
@property
def permuted_matrix_sequence(self):
    """Permuted matrix sequence of the vectorized kernel.

    Forwards to self.matrix_sequence without applying any further
    permutation.

    NOTE(review): matrix_sequence appears to be assembled from the
    subkernels' permuted_matrix_sequence already, which would be why no
    extra permutation happens here - confirm.
    """
    sequence = self.matrix_sequence
    return sequence
@property
def stage(self):
return self.kernels[0].stage
......@@ -669,6 +678,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def restriction(self):
return self.kernels[0].restriction
@property
def quadrature_permutation(self):
    """Quadrature permutation of the vectorized kernel.

    The value is read from the first subkernel in self.kernels.

    NOTE(review): the constructor is expected to assert that all subkernels
    share the same quadrature permutation, so the first one is
    representative - confirm.
    """
    first_kernel = self.kernels[0]
    return first_kernel.quadrature_permutation
@property
def within_inames(self):
return self.kernels[0].within_inames
......@@ -726,7 +739,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def horizontal_index(self, sf):
for i, k in enumerate(self.kernels):
if sf.inout_key == k.inout_key:
if tuple(mat.derivative for mat in sf.matrix_sequence) == tuple(mat.derivative for mat in k.matrix_sequence):
if tuple(mat.derivative for mat in sf.permuted_matrix_sequence) == tuple(mat.derivative for mat in k.permuted_matrix_sequence):
return i
return 0
......@@ -781,7 +794,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
quad_inames = quadrature_inames(element)
sliced = 0
if len(sf.matrix_sequence) == local_dimension():
if len(sf.permuted_matrix_sequence) == local_dimension():
for d in range(local_dimension()):
if self.matrix_sequence[d].slice_size:
sliced = prim.Variable(quad_inames[d])
......@@ -851,7 +864,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
to be executed - neglecting the existence of caches of course
"""
dofs = product(mat.basis_size for mat in self.matrix_sequence)
matrices = sum(mat.memory_traffic for mat in set(matrix_sequence))
matrices = sum(mat.memory_traffic for mat in set(self.matrix_sequence))
fbytes = get_option("precision_bits") / 8
return (dofs + matrices) * fbytes
......
......@@ -16,8 +16,7 @@ from dune.perftool.generation import (backend,
from dune.perftool.pdelab.restriction import (Restriction,
restricted_name,
)
from dune.perftool.sumfact.tabulation import (BasisTabulationMatrixArray,
quadrature_points_per_direction,
from dune.perftool.sumfact.tabulation import (quadrature_points_per_direction,
set_quadrature_points,
)
from dune.perftool.error import PerftoolVectorizationError
......@@ -254,7 +253,7 @@ def level1_optimal_vectorization_strategy(sumfacts, width):
sf = next(iter(sumfacts))
depth = 1
while depth <= width:
i = 0 if sf.matrix_sequence[0].face is None else 1
i = 0 if sf.permuted_matrix_sequence[0].face is None else 1
quad = list(quadrature_points_per_direction())
quad[i] = round_to_multiple(quad[i], depth)
quad_points.append(tuple(quad))
......@@ -390,14 +389,14 @@ def get_vectorization_dict(sumfacts, vertical, horizontal, qp):
continue
# Determine the slicing direction
slice_direction = 0 if sf.matrix_sequence[0].face is None else 1
slice_direction = 0 if sf.permuted_matrix_sequence[0].face is None else 1
if qp[slice_direction] % vertical != 0:
return None
# Split the basis tabulation matrices
oldtab = sf.matrix_sequence[slice_direction]
oldtab = sf.permuted_matrix_sequence[slice_direction]
for i in range(vertical):
seq = list(sf.matrix_sequence)
seq = list(sf.permuted_matrix_sequence)
seq[slice_direction] = oldtab.copy(slice_size=vertical,
slice_index=i)
kernels.append(sf.copy(matrix_sequence=tuple(seq)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment