diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py index 2950782a57351d5391e6aabf83ae14c0e25b4cc0..7a2a4d2872783ec9d3e4abde359ddca40b8eb630 100644 --- a/python/dune/perftool/sumfact/amatrix.py +++ b/python/dune/perftool/sumfact/amatrix.py @@ -3,7 +3,7 @@ from dune.perftool.ufl.modified_terminals import Restriction from dune.perftool.options import get_option from dune.perftool.pdelab.argument import name_coefficientcontainer - +from dune.perftool.pdelab.geometry import world_dimension from dune.perftool.generation import (class_member, domain, function_mangler, @@ -242,3 +242,25 @@ def name_large_theta(transpose=False, derivative=False): shape=shape + (4,), potentially_vectorized=True, ) + + +def construct_amatrix_sequence(transpose=False, derivative=None, face=None): + # Get the standard AMatrix: + rows = quadrature_points_per_direction() + cols = basis_functions_per_direction() + + if transpose: + rows, cols = cols, rows + + mat = AMatrix(rows, cols, transpose=transpose) + + # Construct a tuple of it + dim = world_dimension() + result = [mat] * dim + + # Insert special matrices: derivative + if derivative is not None: + assert isinstance(derivative, int) and derivative < dim + result[derivative] = AMatrix(rows, cols, transpose=transpose, derivative=True) + + return tuple(result) diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index fda761c73f08db88c8fc8712512e881bbf6999a6..89a72af1d0cb697f2029d7300605de65f011a313 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -15,6 +15,7 @@ from dune.perftool.generation import (backend, ) from dune.perftool.sumfact.amatrix import (AMatrix, basis_functions_per_direction, + construct_amatrix_sequence, name_theta, quadrature_points_per_direction, ) @@ -55,12 +56,6 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): shape_impl = ('arr',) * rank temporary_variable(name, shape=shape, shape_impl=shape_impl) - # Calculate values with sumfactorization - rows = quadrature_points_per_direction() - cols = basis_functions_per_direction() - theta_matrix = AMatrix(rows, cols) - dtheta_matrix = AMatrix(rows, cols, derivative=True) - # TODO: # - This only covers rank 1 # - Avoid setting up whole gradient if only one component is needed? @@ -70,9 +65,8 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): buffers = [] insn_dep = None for i in range(dim): - a_matrices = [theta_matrix] * dim - a_matrices[i] = dtheta_matrix - a_matrices = tuple(a_matrices) + # Construct the matrix sequence for this sum factorization + a_matrices = construct_amatrix_sequence(derivative=i) # Get the vectorization info. If this happens during the dry run, we get dummies from dune.perftool.sumfact.vectorization import get_vectorization_info @@ -136,11 +130,8 @@ def pymbolic_trialfunction(element, restriction, component, visitor): formdata = get_global_context_value('formdata') dim = formdata.geometric_dimension - # Setup sumfactorization - rows = quadrature_points_per_direction() - cols = basis_functions_per_direction() - a_matrix = AMatrix(rows, cols) - a_matrices = (a_matrix,) * dim + # Construct the matrix sequence for this sum factorization + a_matrices = construct_amatrix_sequence() # Get the vectorization info. If this happens during the dry run, we get dummies from dune.perftool.sumfact.vectorization import get_vectorization_info diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 665fe31306859dc0f40a407ce0cc520ec9ba8407..9fbd638bc5dd7be84df60b538338919662a24d4a 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -40,6 +40,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix, basis_functions_per_direction, name_large_theta, name_theta, + construct_amatrix_sequence, ) from dune.perftool.loopy.symbolic import SumfactKernel from dune.perftool.tools import get_pymbolic_basename @@ -123,24 +124,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): insn_dep = None for i, buf in enumerate(buffers): - # Get the a matrices needed for this accumulation term - rows = basis_functions_per_direction() - cols = quadrature_points_per_direction() - theta_matrix = AMatrix(rows, cols, transpose=True) - - # If this is a gradient we need different matrices - if accterm.argument.index: - rows = basis_functions_per_direction() - cols = quadrature_points_per_direction() - dtheta_matrix = AMatrix(rows, cols, transpose=True, derivative=True) - - a_matrices = [theta_matrix] * dim - a_matrices[i] = dtheta_matrix - a_matrices = tuple(a_matrices) - pref_pos = i - else: - a_matrices = (theta_matrix,) * dim - pref_pos = None + # Construct the matrix sequence for this sum factorization + a_matrices = construct_amatrix_sequence(transpose=True, + derivative=i if accterm.argument.index else None, + ) # Get the vectorization info. If this happens during the dry run, we get dummies from dune.perftool.sumfact.vectorization import get_vectorization_info @@ -194,6 +181,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): # Add a sum factorization kernel that implements the multiplication # with the test function (stage 3) + pref_pos = i if accterm.argument.index else None result, insn_dep = sum_factorization_kernel(a_matrices, buffer, 3,