Skip to content
Snippets Groups Projects
Commit 4d1f3cd2 authored by Dominic Kempf
Browse files

Put amatrix sequence construction into a function!

parent f80b58f4
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ from dune.perftool.ufl.modified_terminals import Restriction
from dune.perftool.options import get_option
from dune.perftool.pdelab.argument import name_coefficientcontainer
from dune.perftool.pdelab.geometry import world_dimension
from dune.perftool.generation import (class_member,
domain,
function_mangler,
......@@ -242,3 +242,25 @@ def name_large_theta(transpose=False, derivative=False):
shape=shape + (4,),
potentially_vectorized=True,
)
def construct_amatrix_sequence(transpose=False, derivative=None, face=None):
    """Build the tuple of AMatrix objects used for one sum factorization.

    The sequence has one entry per world dimension. Every entry is the
    standard matrix, except that the entry at index ``derivative`` (if
    given) is replaced by a matrix constructed with ``derivative=True``.

    :arg transpose: If true, the row and column counts are swapped and
        ``transpose=True`` is forwarded to every AMatrix.
    :arg derivative: ``None`` for no derivative matrix, or the integer
        index of the direction (``0 <= derivative < world_dimension()``)
        that gets the derivative matrix. Booleans are rejected: other
        functions in this module use ``derivative`` as a flag, and a
        stray ``False``/``True`` would otherwise be silently interpreted
        as direction 0/1.
    :arg face: Accepted for interface compatibility; this implementation
        does not use it.
    :returns: A tuple of ``world_dimension()`` AMatrix objects.
    :raises AssertionError: If ``derivative`` is neither ``None`` nor a
        valid direction index.
    """
    # Get the standard AMatrix:
    rows = quadrature_points_per_direction()
    cols = basis_functions_per_direction()
    if transpose:
        rows, cols = cols, rows
    mat = AMatrix(rows, cols, transpose=transpose)

    # Construct a sequence of it, one entry per direction. The entries
    # share one object, which is fine as they are only ever replaced.
    dim = world_dimension()
    result = [mat] * dim

    # Insert special matrices: derivative
    if derivative is not None:
        # bool is a subclass of int, so it has to be excluded explicitly.
        assert isinstance(derivative, int) and not isinstance(derivative, bool), \
            "derivative must be None or an integer direction index"
        # The lower bound prevents negative indices from wrapping around
        # and silently selecting a direction counted from the end.
        assert 0 <= derivative < dim, \
            "derivative direction index out of range"
        result[derivative] = AMatrix(rows, cols, transpose=transpose, derivative=True)

    return tuple(result)
......@@ -15,6 +15,7 @@ from dune.perftool.generation import (backend,
)
from dune.perftool.sumfact.amatrix import (AMatrix,
basis_functions_per_direction,
construct_amatrix_sequence,
name_theta,
quadrature_points_per_direction,
)
......@@ -55,12 +56,6 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
shape_impl = ('arr',) * rank
temporary_variable(name, shape=shape, shape_impl=shape_impl)
# Calculate values with sumfactorization
rows = quadrature_points_per_direction()
cols = basis_functions_per_direction()
theta_matrix = AMatrix(rows, cols)
dtheta_matrix = AMatrix(rows, cols, derivative=True)
# TODO:
# - This only covers rank 1
# - Avoid setting up whole gradient if only one component is needed?
......@@ -70,9 +65,8 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
buffers = []
insn_dep = None
for i in range(dim):
a_matrices = [theta_matrix] * dim
a_matrices[i] = dtheta_matrix
a_matrices = tuple(a_matrices)
# Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence(derivative=i)
# Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info
......@@ -136,11 +130,8 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
formdata = get_global_context_value('formdata')
dim = formdata.geometric_dimension
# Setup sumfactorization
rows = quadrature_points_per_direction()
cols = basis_functions_per_direction()
a_matrix = AMatrix(rows, cols)
a_matrices = (a_matrix,) * dim
# Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence()
# Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info
......
......@@ -40,6 +40,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
basis_functions_per_direction,
name_large_theta,
name_theta,
construct_amatrix_sequence,
)
from dune.perftool.loopy.symbolic import SumfactKernel
from dune.perftool.tools import get_pymbolic_basename
......@@ -123,24 +124,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
insn_dep = None
for i, buf in enumerate(buffers):
# Get the a matrices needed for this accumulation term
rows = basis_functions_per_direction()
cols = quadrature_points_per_direction()
theta_matrix = AMatrix(rows, cols, transpose=True)
# If this is a gradient we need different matrices
if accterm.argument.index:
rows = basis_functions_per_direction()
cols = quadrature_points_per_direction()
dtheta_matrix = AMatrix(rows, cols, transpose=True, derivative=True)
a_matrices = [theta_matrix] * dim
a_matrices[i] = dtheta_matrix
a_matrices = tuple(a_matrices)
pref_pos = i
else:
a_matrices = (theta_matrix,) * dim
pref_pos = None
# Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence(transpose=True,
derivative=i if accterm.argument.index else None,
)
# Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info
......@@ -194,6 +181,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Add a sum factorization kernel that implements the multiplication
# with the test function (stage 3)
pref_pos = i if accterm.argument.index else None
result, insn_dep = sum_factorization_kernel(a_matrices,
buffer,
3,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment