Skip to content
Snippets Groups Projects
Commit d920c3e1 authored by René Heß
Browse files

[skip ci][WIP] Code cleanup

parent 3709281b
No related branches found
No related tags found
No related merge requests found
...@@ -93,7 +93,7 @@ def generate_standalone_code(sf, filename): ...@@ -93,7 +93,7 @@ def generate_standalone_code(sf, filename):
f.write(" using DF = {};\n".format(real)) f.write(" using DF = {};\n".format(real))
from dune.codegen.sumfact.tabulation import name_polynomials from dune.codegen.sumfact.tabulation import name_polynomials
degs = tuple(m.basis_size - 1 for m in sf.matrix_sequence) degs = tuple(m.basis_size - 1 for m in sf.matrix_sequence_quadrature_permuted)
for deg in set(degs): for deg in set(degs):
f.write(" Dune::QkStuff::EquidistantLagrangePolynomials<DF, RF, {}> {};\n".format(deg, name_polynomials(deg))) f.write(" Dune::QkStuff::EquidistantLagrangePolynomials<DF, RF, {}> {};\n".format(deg, name_polynomials(deg)))
...@@ -105,8 +105,8 @@ def generate_standalone_code(sf, filename): ...@@ -105,8 +105,8 @@ def generate_standalone_code(sf, filename):
constructor_knl = lp.get_one_scheduled_kernel(constructor_knl) constructor_knl = lp.get_one_scheduled_kernel(constructor_knl)
# Allocate buffers # Allocate buffers
size = max(product(m.quadrature_size for m in sf.matrix_sequence) * sf.vector_width, size = max(product(m.quadrature_size for m in sf.matrix_sequence_quadrature_permuted) * sf.vector_width,
product(m.basis_size for m in sf.matrix_sequence) * sf.vector_width) product(m.basis_size for m in sf.matrix_sequence_quadrature_permuted) * sf.vector_width)
size = int(size * (get_option("precision_bits") / 8)) size = int(size * (get_option("precision_bits") / 8))
f.writelines([" char buffer0[{}] __attribute__ ((aligned (32)));\n".format(size), f.writelines([" char buffer0[{}] __attribute__ ((aligned (32)));\n".format(size),
" char buffer1[{}] __attribute__ ((aligned (32)));\n".format(size), " char buffer1[{}] __attribute__ ((aligned (32)));\n".format(size),
......
...@@ -303,7 +303,12 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): ...@@ -303,7 +303,12 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
@property @property
def quadrature_permutation(self): def quadrature_permutation(self):
# TODO: For now we ensure that all kernels have the same quadrature_permutation # TODO: This should be turned into an error.
#
#
# The quadrature permutation could be different for different scalar kernels!
# raise RuntimeError('quadrature_permutation should not be called on VectorSumfactKernelOutput')
#
for i in self.interfaces: for i in self.interfaces:
assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation
return self.interfaces[0].quadrature_permutation return self.interfaces[0].quadrature_permutation
...@@ -328,13 +333,15 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): ...@@ -328,13 +333,15 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
return prim.Call(prim.Variable(hadd_function), (result,)) return prim.Call(prim.Variable(hadd_function), (result,))
def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
# TODO: Include permutations of scalar kernels as soon as they could be different # The input for stage 3 is quadrature permuted. The inames and shape
# passed to this method are quadrature and cost permuted. This means we
# need to take the cost permutation back to get the right inames and
# shape for interpreting the input!
shape = permute_backward(shape, self.cost_permutation) shape = permute_backward(shape, self.cost_permutation)
inames = permute_backward(inames, self.cost_permutation) inames = permute_backward(inames, self.cost_permutation)
# Get a temporary that interprets the base storage of the input # Get a temporary that interprets the base storage of the input as a
# as a column-major matrix. In later iterations of the matrix loop # column-major matrix.
# this reinterprets the output of the previous iteration.
inp = buffer.get_temporary("buff_step{}_in".format(l), inp = buffer.get_temporary("buff_step{}_in".format(l),
shape=shape + vec_shape, shape=shape + vec_shape,
dim_tags=ftags, dim_tags=ftags,
...@@ -370,16 +377,26 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): ...@@ -370,16 +377,26 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
return deps return deps
def accumulate_output(self, sf, result, insn_dep): def accumulate_output(self, sf, result, insn_dep):
# TODO: vector_cost_permutation not used!
outputs = set(self.interfaces) outputs = set(self.interfaces)
# Note: Using matrix_sequence_quadrature_permuted is ok in this place since:
#
# - If the grid is unstructured we assume that the polynomial degree
# for each direction is the same.
#
# - If the grid is structured the quadrature permuted matrix sequence
# is the same as the original one. We still need to call this one
# since VectorizedSumfactKernels do not have the matrix_sequence
# attribute.
basis_size = tuple(mat.basis_size for mat in sf.matrix_sequence_quadrature_permuted)
if get_option('grid_unstructured'):
assert len(set(basis_size)) == 1
trial_element, = set(o.trial_element for o in self.interfaces) trial_element, = set(o.trial_element for o in self.interfaces)
trial_element_index = set(o.trial_element_index for o in self.interfaces).pop() trial_element_index = set(o.trial_element_index for o in self.interfaces).pop()
from dune.codegen.sumfact.accumulation import accum_iname from dune.codegen.sumfact.accumulation import accum_iname
element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None
inames = tuple(accum_iname(element, mat.rows, i) inames = tuple(accum_iname(element, size, i) for i, size in enumerate(basis_size))
for i, mat in enumerate(sf.matrix_sequence_quadrature_permuted))
veciname = accum_iname(element, sf.vector_width // len(outputs), "vec") veciname = accum_iname(element, sf.vector_width // len(outputs), "vec")
transform(lp.tag_inames, [(veciname, "vec")]) transform(lp.tag_inames, [(veciname, "vec")])
...@@ -807,9 +824,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -807,9 +824,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
assert len(set(k.within_inames for k in kernels)) == 1 assert len(set(k.within_inames for k in kernels)) == 1
assert len(set(k.predicates for k in kernels)) == 1 assert len(set(k.predicates for k in kernels)) == 1
# Assert properties of the matrix sequence of the underlying kernels
# For now we don't mix direct and non_direct input. Could be done in an upper/lower way. # For now we don't mix direct and non_direct input. Could be done in an upper/lower way.
assert len(set(tuple(k.interface.direct_is_possible for k in kernels))) == 1 assert len(set(tuple(k.interface.direct_is_possible for k in kernels))) == 1
# Assert properties of the matrix sequence of the underlying kernels
for i in range(kernels[0].length): for i in range(kernels[0].length):
assert len(set(tuple(k.matrix_sequence_quadrature_permuted[i].rows for k in kernels))) == 1 assert len(set(tuple(k.matrix_sequence_quadrature_permuted[i].rows for k in kernels))) == 1
assert len(set(tuple(k.matrix_sequence_quadrature_permuted[i].cols for k in kernels))) == 1 assert len(set(tuple(k.matrix_sequence_quadrature_permuted[i].cols for k in kernels))) == 1
...@@ -819,9 +837,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -819,9 +837,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
# Join the instruction dependencies of all subkernels # Join the instruction dependencies of all subkernels
insn_dep = insn_dep.union(k.insn_dep for k in kernels) insn_dep = insn_dep.union(k.insn_dep for k in kernels)
# Assert that quadrature permutation is the same for all kernels # Assert that the cost_permutation is the same for all kernels
for k in kernels: for k in kernels:
assert k.interface.quadrature_permutation == kernels[0].interface.quadrature_permutation assert k.interface.cost_permutation == kernels[0].interface.cost_permutation
# We currently assume that all subkernels are consecutive, 0-based within the vector # We currently assume that all subkernels are consecutive, 0-based within the vector
assert None not in kernels assert None not in kernels
...@@ -881,7 +899,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -881,7 +899,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property @property
def matrix_sequence_quadrature_permuted(self): def matrix_sequence_quadrature_permuted(self):
# TODO: This should be turned into a RuntimeError # Construct quadrature permuted matrix sequence from scalar case
return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence_quadrature_permuted[i] for k in self.kernels), return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence_quadrature_permuted[i] for k in self.kernels),
width=self.vector_width, width=self.vector_width,
) )
...@@ -889,13 +907,20 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -889,13 +907,20 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property @property
def matrix_sequence_cost_permuted(self): def matrix_sequence_cost_permuted(self):
perm = sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage) # Construct cost permuted matrix sequence from scalar case
matrix_sequence_cost_permuted = permute_forward(self.matrix_sequence_quadrature_permuted, perm) matrix_sequence = tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence_cost_permuted[i] for k in self.kernels),
return matrix_sequence_cost_permuted width=self.vector_width,)
for i in range(self.length))
# This should already be cost optimal
perm = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
assert perm == tuple(i for i in range(len(perm)))
return matrix_sequence
@property @property
def cost_permutation(self): def cost_permutation(self):
return sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage) return self.kernels[0].cost_permutation
@property @property
def stage(self): def stage(self):
...@@ -907,7 +932,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -907,7 +932,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property @property
def quadrature_permutation(self): def quadrature_permutation(self):
return self.kernels[0].interface.quadrature_permutation # The quadrature_permutations of the underlying scalar kernels can be
# different from kernel to kernel. So there is no well defined
# quadrature_permutation on the VectorizedSumfactKernel.
raise RuntimeError("quadrature_permutation should not be used on VectorizedSumfactKernel.")
@property @property
def within_inames(self): def within_inames(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment