Skip to content
Snippets Groups Projects
Commit a3ce6477 authored by Dominic Kempf
Browse files

Working vertical mass matrix example

parent ed879613
No related branches found
No related tags found
No related merge requests found
...@@ -96,7 +96,7 @@ def _realize_sum_factorization_kernel(sf): ...@@ -96,7 +96,7 @@ def _realize_sum_factorization_kernel(sf):
if sf.vectorized: if sf.vectorized:
ftags = ftags + ",vec" ftags = ftags + ",vec"
ctags = ctags + ",vec" ctags = ctags + ",vec"
vec_shape = (sf.horizontal_width,) vec_shape = (sf.vector_width,)
# Measure times and count operations in c++ code # Measure times and count operations in c++ code
if get_option("instrumentation_level") >= 4: if get_option("instrumentation_level") >= 4:
...@@ -133,7 +133,7 @@ def _realize_sum_factorization_kernel(sf): ...@@ -133,7 +133,7 @@ def _realize_sum_factorization_kernel(sf):
out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape)) out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape))
vec_iname = () vec_iname = ()
if matrix.vectorized: if matrix.vectorized:
iname = sumfact_iname(sf.horizontal_width, "vec") iname = sumfact_iname(sf.vector_width, "vec")
vec_iname = (prim.Variable(iname),) vec_iname = (prim.Variable(iname),)
transform(lp.tag_inames, [(iname, "vec")]) transform(lp.tag_inames, [(iname, "vec")])
...@@ -160,7 +160,7 @@ def _realize_sum_factorization_kernel(sf): ...@@ -160,7 +160,7 @@ def _realize_sum_factorization_kernel(sf):
globalarg(direct_input, dtype=np.float64, shape=inp_shape, dim_tags=novec_ftags) globalarg(direct_input, dtype=np.float64, shape=inp_shape, dim_tags=novec_ftags)
if matrix.vectorized: if matrix.vectorized:
input_summand = prim.Call(prim.Variable(get_vcl_typename(np.float64, vector_width=sf.horizontal_width)), input_summand = prim.Call(prim.Variable(get_vcl_typename(np.float64, vector_width=sf.vector_width)),
(prim.Subscript(prim.Variable(direct_input), (prim.Subscript(prim.Variable(direct_input),
input_inames),)) input_inames),))
else: else:
......
...@@ -356,7 +356,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -356,7 +356,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property @property
def matrix_sequence(self): def matrix_sequence(self):
return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence[i] for k in self.kernels), return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence[i] for k in self.kernels),
width=self.horizontal_width, width=self.vector_width,
) )
for i in range(self.length)) for i in range(self.length))
...@@ -405,6 +405,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -405,6 +405,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
indices = set(range(self.horizontal_width)) - set(range(len(self.kernels))) indices = set(range(self.horizontal_width)) - set(range(len(self.kernels)))
return tuple(self.kernels[0].quadrature_index(None) + (i,) for i in indices) return tuple(self.kernels[0].quadrature_index(None) + (i,) for i in indices)
@property
def vector_width(self):
return self.horizontal_width * self.vertical_width
# #
# Define the same properties the normal SumfactKernel defines # Define the same properties the normal SumfactKernel defines
# #
...@@ -433,7 +436,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -433,7 +436,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property @property
def flat_input_shape(self): def flat_input_shape(self):
return (product(mat.cols for mat in self.matrix_sequence), self.horizontal_width) return (product(mat.basis_size for mat in self.matrix_sequence), self.horizontal_width)
@property @property
def quadrature_shape(self): def quadrature_shape(self):
...@@ -444,7 +447,11 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -444,7 +447,11 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
assert isinstance(direct_index, tuple) assert isinstance(direct_index, tuple)
return self.kernels[0].quadrature_index(sf) + direct_index return self.kernels[0].quadrature_index(sf) + direct_index
else: else:
return self.kernels[0].quadrature_index(sf) + (self.kernels.index(sf),) try:
horizontal_index = self.kernels.index(sf)
except ValueError:
horizontal_index = 0
return self.kernels[0].quadrature_index(sf) + (horizontal_index,)
@property @property
def quadrature_dimtags(self): def quadrature_dimtags(self):
......
...@@ -70,6 +70,7 @@ def vertical_vectorization_strategy(sumfact, depth): ...@@ -70,6 +70,7 @@ def vertical_vectorization_strategy(sumfact, depth):
vsf = VectorizedSumfactKernel(kernels=tuple(kernels), vsf = VectorizedSumfactKernel(kernels=tuple(kernels),
buffer=buffer, buffer=buffer,
input=input, input=input,
vertical_width=depth,
) )
return _cache_vectorization_info(sumfact, vsf) return _cache_vectorization_info(sumfact, vsf)
else: else:
......
...@@ -9,7 +9,7 @@ dune_add_formcompiler_system_test(UFLFILE mass_3d.ufl ...@@ -9,7 +9,7 @@ dune_add_formcompiler_system_test(UFLFILE mass_3d.ufl
INIFILE mass_3d.mini INIFILE mass_3d.mini
) )
#dune_add_formcompiler_system_test(UFLFILE mass_3d.ufl dune_add_formcompiler_system_test(UFLFILE mass_3d.ufl
# BASENAME sumfact_mass_sliced BASENAME sumfact_mass_sliced
# INIFILE sliced.mini INIFILE sliced.mini
# ) )
...@@ -19,4 +19,4 @@ vectorize_quad = 1, 0 | expand vec ...@@ -19,4 +19,4 @@ vectorize_quad = 1, 0 | expand vec
sumfact = 1 sumfact = 1
[formcompiler.ufl_variants] [formcompiler.ufl_variants]
degree = 1 degree = 3
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment