Skip to content
Snippets Groups Projects
Commit 67b7b6e4 authored by René Heß's avatar René Heß
Browse files

Do not split gradien in non sumfact code path

parent 26f55622
No related branches found
No related tags found
No related merge requests found
...@@ -309,42 +309,19 @@ def grad_iname(index, dim): ...@@ -309,42 +309,19 @@ def grad_iname(index, dim):
@backend(interface="accum_insn") @backend(interface="accum_insn")
def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# When we do not do sumfactorization we do not split the test function
assert(accterm.argument.expr is None) assert(accterm.argument.expr is None)
# First we do the tree traversal to get a pymbolic expression representing this expression # First we do the tree traversal to get a pymbolic expression representing this expression
pymbolic_expr = visitor(accterm.term) pymbolic_expr = visitor(accterm.term)
# If this is a gradient, we generate an iname
additional_inames = frozenset({})
if accterm.new_indices is not None:
from ufl.domain import find_geometric_dimension
dim = find_geometric_dimension(accterm.argument.expr)
for i in accterm.new_indices:
if i not in visitor.dimension_indices:
additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
# It may happen that an entire accumulation term vanishes. We do nothing in that case # It may happen that an entire accumulation term vanishes. We do nothing in that case
if pymbolic_expr == 0: if pymbolic_expr == 0:
return return
# We also traverse the test function to get its pymbolic equivalent # Collect the lfs and lfs indices for the accumulate call
from ufl.classes import Identity test_lfs = determine_accumulation_space(accterm.term, 0, measure)
if accterm.argument.expr is not None:
if accterm.new_indices is not None:
from ufl.classes import Indexed, MultiIndex
accum_expr = Indexed(accterm.argument.expr, MultiIndex(accterm.new_indices))
else:
accum_expr = accterm.argument.expr
test_expr = visitor(accum_expr)
# Combine expression and test function
from pymbolic.primitives import Product
pymbolic_expr = Product((pymbolic_expr, test_expr))
if accterm.argument.expr is None:
test_lfs = determine_accumulation_space(accterm.term, 0, measure)
else:
test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure)
# In the jacobian case, also determine the space for the ansatz space # In the jacobian case, also determine the space for the ansatz space
ansatz_lfs = determine_accumulation_space(accterm.term, 1, measure) ansatz_lfs = determine_accumulation_space(accterm.term, 1, measure)
...@@ -368,7 +345,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -368,7 +345,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
instruction(assignees=(), instruction(assignees=(),
expression=expr, expression=expr,
forced_iname_deps=additional_inames.union(frozenset(visitor.inames).union(frozenset(quad_inames))), forced_iname_deps=frozenset(visitor.inames).union(frozenset(quad_inames)),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
predicates=predicates predicates=predicates
) )
...@@ -407,7 +384,7 @@ def visit_integrals(integrals): ...@@ -407,7 +384,7 @@ def visit_integrals(integrals):
# right input for the sumfactorization kernel of stage 3. # right input for the sumfactorization kernel of stage 3.
from dune.perftool.ufl.extract_accumulation_terms import split_into_accumulation_terms from dune.perftool.ufl.extract_accumulation_terms import split_into_accumulation_terms
if get_option('sumfact'): if get_option('sumfact'):
accterms = split_into_accumulation_terms(integrand, cut_test_arg=True) accterms = split_into_accumulation_terms(integrand, cut_test_arg=True, split_gradients=True)
else: else:
accterms = split_into_accumulation_terms(integrand) accterms = split_into_accumulation_terms(integrand)
......
...@@ -44,7 +44,7 @@ class AccumulationTerm(Record): ...@@ -44,7 +44,7 @@ class AccumulationTerm(Record):
return Indexed(self.argument.expr, MultiIndex(self.new_indices)) return Indexed(self.argument.expr, MultiIndex(self.new_indices))
def split_into_accumulation_terms(expr, cut_test_arg=False): def split_into_accumulation_terms(expr, cut_test_arg=False, split_gradients=False):
"""Split an expression into a list of AccumulationTerms """Split an expression into a list of AccumulationTerms
Arguments: Arguments:
...@@ -53,7 +53,7 @@ def split_into_accumulation_terms(expr, cut_test_arg=False): ...@@ -53,7 +53,7 @@ def split_into_accumulation_terms(expr, cut_test_arg=False):
cut_test_arg: Wheter we want to return AccumulationTerms where the cut_test_arg: Wheter we want to return AccumulationTerms where the
test function is cut from the rest of the expression. test function is cut from the rest of the expression.
""" """
expression_list = split_expression(expr) expression_list = split_expression(expr, split_gradients=split_gradients)
acc_term_list = [] acc_term_list = []
for e in expression_list: for e in expression_list:
if cut_test_arg: if cut_test_arg:
...@@ -64,7 +64,7 @@ def split_into_accumulation_terms(expr, cut_test_arg=False): ...@@ -64,7 +64,7 @@ def split_into_accumulation_terms(expr, cut_test_arg=False):
@ufl_transformation(name="splitt_expression", extraction_lambda=lambda l: [e for e in l]) @ufl_transformation(name="splitt_expression", extraction_lambda=lambda l: [e for e in l])
def split_expression(expr): def split_expression(expr, split_gradients=False):
"""Split expression into a list of expressions """Split expression into a list of expressions
Note: This is not an ufl transformation since it does not return Note: This is not an ufl transformation since it does not return
...@@ -77,11 +77,11 @@ def split_expression(expr): ...@@ -77,11 +77,11 @@ def split_expression(expr):
# Extract a list of modified terminals for the test function. We # Extract a list of modified terminals for the test function. We
# will split the expression into one part for each moidified argument. # will split the expression into one part for each moidified argument.
test_args = extract_modified_arguments(expr, argnumber=0, do_index=False) test_args = extract_modified_arguments(expr, argnumber=0, do_index=False, do_gradient=split_gradients)
# Extract a list of modified terminals for the ansatz function in # Extract a list of modified terminals for the ansatz function in
# jacobian forms. # jacobian forms.
all_jacobian_args = extract_modified_arguments(expr, argnumber=1, do_index=False) all_jacobian_args = extract_modified_arguments(expr, argnumber=1, do_index=False, do_gradient=split_gradients)
for test_arg in test_args: for test_arg in test_args:
# Do this as a multi step replacement procedure to avoid UFL # Do this as a multi step replacement procedure to avoid UFL
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment