diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 4580d46a244878cbe59a5b41794fb700c48223c5..73952843abd68e90ed4aa1d126872d879a1741fc 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -159,8 +159,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): index = () vectag = frozenset() + base_storage_size = product(max(mat.rows, mat.cols) for mat in a_matrices) temp = initialize_buffer(buf, - base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices), + base_storage_size=base_storage_size, num=2 ).get_temporary(shape=shape, dim_tags=dim_tags, @@ -168,9 +169,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ) # Those input fields, that are padded need to be set to zero - # in order to do a horizontal_add lateron + # in order to do a horizontal_add later on for pad in padding: - instruction(assignee=prim.Subscript(prim.Variable(temp), tuple(Variable(i) for i in quadrature_inames()) + (pad,)), + assignee = prim.Subscript(prim.Variable(temp), + tuple(Variable(i) for i in quadrature_inames()) + (pad,)) + instruction(assignee=assignee, expression=0, forced_iname_deps=frozenset(quadrature_inames() + visitor.inames), forced_iname_deps_is_final=True, @@ -316,13 +319,15 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): depends_on=insn_dep, ) - # Mark the transformation that moves the quadrature loop inside the trialfunction loops for application + # Mark the transformation that moves the quadrature loop + # inside the trialfunction loops for application transform(nest_quadrature_loops, visitor.inames) return insn_dep # Extract the restrictions on argument-1: - jac_restrictions = frozenset(tuple(ma.restriction for ma in extract_modified_arguments(accterm.term, argnumber=1))) + jac_restrictions = frozenset(tuple(ma.restriction for ma in + extract_modified_arguments(accterm.term, argnumber=1, do_index=True))) if not jac_restrictions: jac_restrictions = frozenset({0}) @@ -461,9 +466,9 @@ def sum_factorization_kernel(a_matrices, It can make sense to permute the order of directions. If you have a small m_l (e.g. stage 1 on faces) it is better to do direction l - first. This can be done permuting: + first. This can be done by: - - The order of the A matrices. + - Permuting the order of the A matrices. - Permuting the input tensor. - Permuting the output tensor (this assures that the directions of the output tensor are again ordered from 0 to d-1). diff --git a/python/dune/perftool/ufl/extract_accumulation_terms.py b/python/dune/perftool/ufl/extract_accumulation_terms.py index 33032376b834843ceaf28c41e55e9b9f3ed70251..4ef70070281eb5b9dd7d8c468a94cabe46208292 100644 --- a/python/dune/perftool/ufl/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/extract_accumulation_terms.py @@ -132,31 +132,18 @@ def split_into_accumulation_terms(expr): # 5) Further split according to trial function in jacobian terms if all_jacobian_args: - jac_args = extract_modified_arguments(replace_expr, argnumber=1, do_index=False) - for jac_arg in jac_args: - # 5.1) Cut the expression to this ansatz function + indexed_jac_args = extract_modified_arguments(replace_expr, argnumber=1, do_index=True) + for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape, free_indices=ma.expr.ufl_free_indices, index_dimensions=ma.expr.ufl_index_dimensions) - for ma in jac_args} - replacement[jac_arg.expr] = jac_arg.expr - jac_expr = replace_expression(replace_expr, replacemap=replacement) - - # 5.2) Propagate indexed zeros to simplify expression - jac_expr = zero_propagation(jac_expr) + if ma.restriction != restriction else ma.expr + for ma in indexed_jac_args} - indexed_jac_args = extract_modified_arguments(jac_expr, argnumber=1, do_index=True) - for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): - replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape, - free_indices=ma.expr.ufl_free_indices, - index_dimensions=ma.expr.ufl_index_dimensions) - if ma.restriction != restriction else ma.expr - for ma in indexed_jac_args} - - jac_accum_expr = replace_expression(jac_expr, replacemap=replacement) + jac_expr = replace_expression(replace_expr, replacemap=replacement) - if not isinstance(jac_expr, Zero): - ret.append(AccumulationTerm(jac_accum_expr, test_arg, indexmap, newi)) + if not isinstance(jac_expr, Zero): + ret.append(AccumulationTerm(jac_expr, test_arg, indexmap, newi)) else: if not isinstance(replace_expr, Zero): ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi))