diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 04bf64bad3349da40f153261cf308f0ba6a2edb7..f50c478a804bc22911593d7c235f682d5b5c22a4 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -146,77 +146,44 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v shape_impl = ('arr',) * rank temporary_variable(name, shape=shape, shape_impl=shape_impl) - # Whether direct indexing into the output is possible. This happens - # if the positioning within a SIMD vectors coincides with the index! - direct_indexing_is_possible = True - - expressions = {} - insn_dep = frozenset() - for indices in itertools.product(*map(range, shape)): - # Do not consider derivatives of rank > 2 - assert len(indices) <= 2 - - # Construct the matrix sequence for this sum factorization - matrix_sequence = construct_basis_matrix_sequence(derivative=indices[-1], - facedir=get_facedir(restriction), - facemod=get_facemod(restriction), - basis_size=basis_size, - ) - - # Index needed to acces the right coefficient container for vector valued functions + if len(visitor_indices) == 1: coeff_func_index = None - if len(indices) == 2: - coeff_func_index = indices[0] - - inp = LFSSumfactKernelInput(coeff_func=coeff_func, - coeff_func_index=coeff_func_index, - element=element, - component=component, - restriction=restriction, - ) + grad_index, = visitor_indices + else: + grad_index, coeff_func_index = visitor_indices - # The sum factorization kernel object gathering all relevant information - sf = SumfactKernel(matrix_sequence=matrix_sequence, - preferred_position=indices[-1], - input=inp, - ) - - from dune.perftool.sumfact.vectorization import attach_vectorization_info - vsf = attach_vectorization_info(sf) - - if indices[-1] != vsf.vec_index(sf): - direct_indexing_is_possible = False - - # Add a sum factorization kernel that implements the - # evaluation of the gradients of basis functions at quadrature - # points (stage 1) - from dune.perftool.sumfact.realization import realize_sum_factorization_kernel - var, insn_dep = realize_sum_factorization_kernel(vsf.copy(insn_dep=vsf.insn_dep.union(insn_dep))) - - expressions.update({indices: prim.Subscript(var, vsf.quadrature_index(sf))}) - - # Check whether we want to return early with something that has the indexing - # already handled! This happens with vectorization when the index coincides - # with the position in the vector register. - if direct_indexing_is_possible: - assert len(visitor_indices) == 1 - return maybe_wrap_subscript(var, vsf.quadrature_index(sf, visitor_indices)), None - - # TODO this should be quite conditional!!! - for indices in itertools.product(*map(range, shape)): - # Write solution from sumfactorization to gradient variable - assignee = prim.Subscript(prim.Variable(name), indices) - instruction(assignee=assignee, - expression=expressions[indices], - forced_iname_deps=frozenset(get_backend("quad_inames")()), - forced_iname_deps_is_final=True, - ) + # Construct the matrix sequence for this sum factorization + matrix_sequence = construct_basis_matrix_sequence(derivative=grad_index, + facedir=get_facedir(restriction), + facemod=get_facemod(restriction), + basis_size=basis_size, + ) + + inp = LFSSumfactKernelInput(coeff_func=coeff_func, + coeff_func_index=coeff_func_index, + element=element, + component=component, + restriction=restriction, + ) + + # The sum factorization kernel object gathering all relevant information + sf = SumfactKernel(matrix_sequence=matrix_sequence, + preferred_position=grad_index, + input=inp, + ) + + from dune.perftool.sumfact.vectorization import attach_vectorization_info + vsf = attach_vectorization_info(sf) + + from dune.perftool.sumfact.realization import realize_sum_factorization_kernel + var, insn_dep = realize_sum_factorization_kernel(vsf) - return prim.Variable(name), visitor_indices + return prim.Subscript(var, vsf.quadrature_index(sf)), None @kernel_cached def pymbolic_coefficient(element, restriction, component, coeff_func, visitor_indices): + assert visitor_indices is None # Basis functions per direction basis_size = _basis_functions_per_direction(element, component) @@ -243,9 +210,7 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor_in from dune.perftool.sumfact.realization import realize_sum_factorization_kernel var, _ = realize_sum_factorization_kernel(vsf) - return prim.Subscript(var, - vsf.quadrature_index(sf) - ), visitor_indices + return prim.Subscript(var, vsf.quadrature_index(sf)), None @iname