Skip to content
Snippets Groups Projects
Commit 4d1f3cd2 authored by Dominic Kempf
Browse files

Put amatrix sequence construction into a function!

parent f80b58f4
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ from dune.perftool.ufl.modified_terminals import Restriction ...@@ -3,7 +3,7 @@ from dune.perftool.ufl.modified_terminals import Restriction
from dune.perftool.options import get_option from dune.perftool.options import get_option
from dune.perftool.pdelab.argument import name_coefficientcontainer from dune.perftool.pdelab.argument import name_coefficientcontainer
from dune.perftool.pdelab.geometry import world_dimension
from dune.perftool.generation import (class_member, from dune.perftool.generation import (class_member,
domain, domain,
function_mangler, function_mangler,
...@@ -242,3 +242,25 @@ def name_large_theta(transpose=False, derivative=False): ...@@ -242,3 +242,25 @@ def name_large_theta(transpose=False, derivative=False):
shape=shape + (4,), shape=shape + (4,),
potentially_vectorized=True, potentially_vectorized=True,
) )
def construct_amatrix_sequence(transpose=False, derivative=None, face=None):
    """Build the sequence of AMatrix objects (one per direction) for a sum factorization.

    The standard matrix has as many rows as there are quadrature points per
    direction and as many columns as there are basis functions per direction;
    with ``transpose=True`` the two sizes are swapped and the flag is forwarded
    to ``AMatrix``.

    :arg transpose: Whether to build transposed matrices.
    :arg derivative: ``None`` for a sequence of standard matrices only, or the
        index of the direction (``0 <= derivative < world_dimension()``) whose
        matrix is replaced by one constructed with ``derivative=True``.
    :arg face: Accepted for interface purposes; not used by this function yet.
    :returns: A tuple of ``world_dimension()`` AMatrix objects. All
        non-derivative entries refer to the same AMatrix instance.
    """
    # Get the standard AMatrix:
    rows = quadrature_points_per_direction()
    cols = basis_functions_per_direction()
    if transpose:
        rows, cols = cols, rows
    mat = AMatrix(rows, cols, transpose=transpose)

    # Start from a list with the standard matrix in every direction, so
    # that individual entries can still be replaced below
    dim = world_dimension()
    result = [mat] * dim

    # Insert special matrices: derivative
    if derivative is not None:
        # The lower bound matters: a negative index would otherwise be
        # accepted and silently address a direction counted from the end.
        assert isinstance(derivative, int) and 0 <= derivative < dim
        result[derivative] = AMatrix(rows, cols, transpose=transpose, derivative=True)

    return tuple(result)
...@@ -15,6 +15,7 @@ from dune.perftool.generation import (backend, ...@@ -15,6 +15,7 @@ from dune.perftool.generation import (backend,
) )
from dune.perftool.sumfact.amatrix import (AMatrix, from dune.perftool.sumfact.amatrix import (AMatrix,
basis_functions_per_direction, basis_functions_per_direction,
construct_amatrix_sequence,
name_theta, name_theta,
quadrature_points_per_direction, quadrature_points_per_direction,
) )
...@@ -55,12 +56,6 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): ...@@ -55,12 +56,6 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
shape_impl = ('arr',) * rank shape_impl = ('arr',) * rank
temporary_variable(name, shape=shape, shape_impl=shape_impl) temporary_variable(name, shape=shape, shape_impl=shape_impl)
# Calculate values with sumfactorization
rows = quadrature_points_per_direction()
cols = basis_functions_per_direction()
theta_matrix = AMatrix(rows, cols)
dtheta_matrix = AMatrix(rows, cols, derivative=True)
# TODO: # TODO:
# - This only covers rank 1 # - This only covers rank 1
# - Avoid setting up whole gradient if only one component is needed? # - Avoid setting up whole gradient if only one component is needed?
...@@ -70,9 +65,8 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): ...@@ -70,9 +65,8 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
buffers = [] buffers = []
insn_dep = None insn_dep = None
for i in range(dim): for i in range(dim):
a_matrices = [theta_matrix] * dim # Construct the matrix sequence for this sum factorization
a_matrices[i] = dtheta_matrix a_matrices = construct_amatrix_sequence(derivative=i)
a_matrices = tuple(a_matrices)
# Get the vectorization info. If this happens during the dry run, we get dummies # Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info from dune.perftool.sumfact.vectorization import get_vectorization_info
...@@ -136,11 +130,8 @@ def pymbolic_trialfunction(element, restriction, component, visitor): ...@@ -136,11 +130,8 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
formdata = get_global_context_value('formdata') formdata = get_global_context_value('formdata')
dim = formdata.geometric_dimension dim = formdata.geometric_dimension
# Setup sumfactorization # Construct the matrix sequence for this sum factorization
rows = quadrature_points_per_direction() a_matrices = construct_amatrix_sequence()
cols = basis_functions_per_direction()
a_matrix = AMatrix(rows, cols)
a_matrices = (a_matrix,) * dim
# Get the vectorization info. If this happens during the dry run, we get dummies # Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info from dune.perftool.sumfact.vectorization import get_vectorization_info
......
...@@ -40,6 +40,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix, ...@@ -40,6 +40,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
basis_functions_per_direction, basis_functions_per_direction,
name_large_theta, name_large_theta,
name_theta, name_theta,
construct_amatrix_sequence,
) )
from dune.perftool.loopy.symbolic import SumfactKernel from dune.perftool.loopy.symbolic import SumfactKernel
from dune.perftool.tools import get_pymbolic_basename from dune.perftool.tools import get_pymbolic_basename
...@@ -123,24 +124,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -123,24 +124,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
insn_dep = None insn_dep = None
for i, buf in enumerate(buffers): for i, buf in enumerate(buffers):
# Get the a matrices needed for this accumulation term # Construct the matrix sequence for this sum factorization
rows = basis_functions_per_direction() a_matrices = construct_amatrix_sequence(transpose=True,
cols = quadrature_points_per_direction() derivative=i if accterm.argument.index else None,
theta_matrix = AMatrix(rows, cols, transpose=True) )
# If this is a gradient we need different matrices
if accterm.argument.index:
rows = basis_functions_per_direction()
cols = quadrature_points_per_direction()
dtheta_matrix = AMatrix(rows, cols, transpose=True, derivative=True)
a_matrices = [theta_matrix] * dim
a_matrices[i] = dtheta_matrix
a_matrices = tuple(a_matrices)
pref_pos = i
else:
a_matrices = (theta_matrix,) * dim
pref_pos = None
# Get the vectorization info. If this happens during the dry run, we get dummies # Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info from dune.perftool.sumfact.vectorization import get_vectorization_info
...@@ -194,6 +181,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -194,6 +181,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Add a sum factorization kernel that implements the multiplication # Add a sum factorization kernel that implements the multiplication
# with the test function (stage 3) # with the test function (stage 3)
pref_pos = i if accterm.argument.index else None
result, insn_dep = sum_factorization_kernel(a_matrices, result, insn_dep = sum_factorization_kernel(a_matrices,
buffer, buffer,
3, 3,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment