Skip to content
Snippets Groups Projects
Commit 6ac72292 authored by René Heß's avatar René Heß
Browse files

Cut split_into_accumulation_terms in two parts

parent acacbeca
No related branches found
No related tags found
No related merge requests found
......@@ -36,29 +36,25 @@ class AccumulationTerm(Record):
@ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l])
def split_into_accumulation_terms(expr):
"""Split an UFL expression into several accumulation parts and return a list
expression_list = split_expression(expr)
acc_term_list = []
for e in expression_list:
acc_term_list.append(cut_accumulation_term(e))
For a residual evaluation we split for different test functions
and according to the restriction (sefl/neighbor at skeletons). For
the jacobians we also need to split according to the ansatz
functions (and their restriction).
return acc_term_list
Note: This function is not an UFL transformation. Nonetheless it
has the @ufl_transformation decorator for debugging purposes.
Arguments:
----------
expr: UFL expression we want to split
"""
def split_expression(expr):
# TODO: doc me
# Store AccumulationTerms in this list
ret = []
# Extract a list of modified terminals for the test function
# One accumulation instruction will be generated for each of these.
# Extract a list of modified terminals for the test function. We
# will split the expression into one part for each moidified argument.
test_args = extract_modified_arguments(expr, argnumber=0, do_index=False)
# Extract a list of modified terminals for the ansatz function
# in jacobian forms.
# Extract a list of modified terminals for the ansatz function in
# jacobian forms.
all_jacobian_args = extract_modified_arguments(expr, argnumber=1, do_index=False)
for test_arg in test_args:
......@@ -77,94 +73,7 @@ def split_into_accumulation_terms(expr):
# 2) Propagate indexed zeros to simplify expression
replace_expr = zero_propagation(replace_expr)
# 3) Cut the test function itself from the expression
#
# This is done by replacing the test function with an
# appropriate product of identity matrices. This way we can
# make sure that the indices of the result will be right. This
# is best explained by an example:
#
# Suppose we have the following expression:
#
# \sum_{i,j} a_{i,j} (\nabla v)_{i,j} + \sum_{k,l} b_{k,l} (\nable v)_{k,l}
#
# If we want to cut the gradient of the test function v we
# need to make sure, that both sums have the right indices:
#
# \sum_{m,n} (a_{m,n} + b_{m,n}) (\nabla v)_{m,n}
#
# and we extract (a_{m,n} + b_{m,n}). We achieve that by the
# following replacements:
#
# (\nabla v)_{i,j} -> I_{m,i} I_{n,j}
# (\nabla v)_{k,l} -> I_{m,k} I_{n,l}
#
# Resulting in:
#
# \sum_{i,j} a_{i,j} I_{m,i} I_{n,j} + \sum_{k,l} b_{k,l} I_{m,k} I_{n,l}
#
# In step 4 this will collaps to: a_{m,n} + b_{m,n}
replacement = {}
indexmap = {}
newi = None
backmap = {}
# Get all appearances of test functions with their indices
indexed_test_args = extract_modified_arguments(replace_expr, argnumber=0, do_index=True)
for indexed_test_arg in indexed_test_args:
if indexed_test_arg.index:
# If the test function is indexed, create a new multiindex of this shape
# -> (m,n) in the example above
if newi is None:
newi = indices(len(indexed_test_arg.index))
# This handles the special case with two identical
# indices on an test function. E.g. in Stokes on an
# axiparallel grid you get a term:
#
# -(\sum_i K_{i,i} (\nabla v)_{i,i}) w
# = \sum_k \sum_l (-K_{k,k} w I_{k,l} (\nabla v)_{k,l})
#
# and we want to split
#
# -K_{k,k} w I_{k,l} corresponding to (\nabla v)_{k,l}.
#
# This is done by:
# - Replacing (\nabla v)_{i,i} with I_{k,i}*(\nabla
# v)_{k,l}. Here (\nabla v)_{k,l} serves as a
# placeholder and will be replaced later on.
# - Propagating the identity in step 4.
# - Replacing (\nabla v)_{k,l} by I_{k,l} after step 4.
if len(set(indexed_test_arg.index._indices)) < len(indexed_test_arg.index._indices):
if len(indexed_test_arg.index._indices) > 2:
raise NotImplementedError("Test argument with more than three indices and double occurence ist not implemented.")
mod_index_map = {indexed_test_arg.index: MultiIndex((newi[0], newi[1]))}
mod_indexed_test_arg = replace_expression(indexed_test_arg.expr,
replacemap=mod_index_map)
rep = Product(Indexed(Identity(2),
MultiIndex((newi[0], indexed_test_arg.index[0]))),
mod_indexed_test_arg)
backmap.update({mod_indexed_test_arg:
Indexed(Identity(2), MultiIndex((newi[0], newi[1])))})
replacement.update({indexed_test_arg.expr: rep})
indexmap.update({indexed_test_arg.index[0]: newi[0]})
else:
# Replace indexed test function with a product of identities.
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
for i, j in zip(newi, indexed_test_arg.index._indices))
replacement.update({indexed_test_arg.expr:
construct_binary_operator(identities, Product)})
indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
else:
replacement.update({indexed_test_arg.expr: IntValue(1)})
replace_expr = replace_expression(replace_expr, replacemap=replacement)
# 4) Collapse any identity nodes that may have been introduced
# by replacing vectors and maybe replace placeholder from last step
replace_expr = identity_propagation(replace_expr)
replace_expr = replace_expression(replace_expr, replacemap=backmap)
# 5) Further split according to trial function in jacobian terms
# 3) Further split according to trial function in jacobian terms
#
# Note: We need to split according to the FunctionView. For
# example in Stokes with test functions v and q we have to
......@@ -179,7 +88,7 @@ def split_into_accumulation_terms(expr):
do_index=False,
do_gradient=False)
for trial_arg in trial_args:
# 5.1) Restrict to this trial argument
# 3.1) Restrict to this trial argument
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
......@@ -187,10 +96,10 @@ def split_into_accumulation_terms(expr):
replacement[trial_arg.expr] = trial_arg.expr
jac_expr = replace_expression(replace_expr, replacemap=replacement)
# 5.2) Propagate indexed zeros to simplify expression
# 3.2) Propagate indexed zeros to simplify expression
jac_expr = zero_propagation(jac_expr)
# 5.3) Accumulate according to restriction
# 3.3) Split according to restriction
indexed_jac_args = extract_modified_arguments(jac_expr, argnumber=1, do_index=True)
for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
......@@ -201,10 +110,106 @@ def split_into_accumulation_terms(expr):
acc_expr = replace_expression(jac_expr, replacemap=replacement)
if not isinstance(jac_expr, Zero):
ret.append(AccumulationTerm(acc_expr, test_arg, indexmap, newi))
if not isinstance(acc_expr, Zero):
ret.append(acc_expr)
else:
if not isinstance(replace_expr, Zero):
ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi))
ret.append(replace_expr)
return ret
def cut_accumulation_term(expr):
# TODO: doc me
test_args = extract_modified_arguments(expr, argnumber=0, do_index=False)
assert len(test_args) == 1
test_arg = test_args[0]
# 1) Cut the test function itself from the expression
#
# This is done by replacing the test function with an
# appropriate product of identity matrices. This way we can
# make sure that the indices of the result will be right. This
# is best explained by an example:
#
# Suppose we have the following expression:
#
# \sum_{i,j} a_{i,j} (\nabla v)_{i,j} + \sum_{k,l} b_{k,l} (\nable v)_{k,l}
#
# If we want to cut the gradient of the test function v we
# need to make sure, that both sums have the right indices:
#
# \sum_{m,n} (a_{m,n} + b_{m,n}) (\nabla v)_{m,n}
#
# and we extract (a_{m,n} + b_{m,n}). We achieve that by the
# following replacements:
#
# (\nabla v)_{i,j} -> I_{m,i} I_{n,j}
# (\nabla v)_{k,l} -> I_{m,k} I_{n,l}
#
# Resulting in:
#
# \sum_{i,j} a_{i,j} I_{m,i} I_{n,j} + \sum_{k,l} b_{k,l} I_{m,k} I_{n,l}
#
# In step 2 this will collaps to: a_{m,n} + b_{m,n}
replacement = {}
indexmap = {}
newi = None
backmap = {}
# Get all appearances of test functions with their indices
indexed_test_args = extract_modified_arguments(expr, argnumber=0, do_index=True)
for indexed_test_arg in indexed_test_args:
if indexed_test_arg.index:
# If the test function is indexed, create a new multiindex of this shape
# -> (m,n) in the example above
if newi is None:
newi = indices(len(indexed_test_arg.index))
# This handles the special case with two identical
# indices on an test function. E.g. in Stokes on an
# axiparallel grid you get a term:
#
# -(\sum_i K_{i,i} (\nabla v)_{i,i}) w
# = \sum_k \sum_l (-K_{k,k} w I_{k,l} (\nabla v)_{k,l})
#
# and we want to split
#
# -K_{k,k} w I_{k,l} corresponding to (\nabla v)_{k,l}.
#
# This is done by:
# - Replacing (\nabla v)_{i,i} with I_{k,i}*(\nabla
# v)_{k,l}. Here (\nabla v)_{k,l} serves as a
# placeholder and will be replaced later on.
# - Propagating the identity in step 4.
# - Replacing (\nabla v)_{k,l} by I_{k,l} after step 4.
if len(set(indexed_test_arg.index._indices)) < len(indexed_test_arg.index._indices):
if len(indexed_test_arg.index._indices) > 2:
raise NotImplementedError("Test argument with more than three indices and double occurence ist not implemented.")
mod_index_map = {indexed_test_arg.index: MultiIndex((newi[0], newi[1]))}
mod_indexed_test_arg = replace_expression(indexed_test_arg.expr,
replacemap=mod_index_map)
rep = Product(Indexed(Identity(2),
MultiIndex((newi[0], indexed_test_arg.index[0]))),
mod_indexed_test_arg)
backmap.update({mod_indexed_test_arg:
Indexed(Identity(2), MultiIndex((newi[0], newi[1])))})
replacement.update({indexed_test_arg.expr: rep})
indexmap.update({indexed_test_arg.index[0]: newi[0]})
else:
# Replace indexed test function with a product of identities.
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
for i, j in zip(newi, indexed_test_arg.index._indices))
replacement.update({indexed_test_arg.expr:
construct_binary_operator(identities, Product)})
indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
else:
replacement.update({indexed_test_arg.expr: IntValue(1)})
expr = replace_expression(expr, replacemap=replacement)
# 2) Collapse any identity nodes that may have been introduced
# by replacing vectors and maybe replace placeholder from last step
expr = identity_propagation(expr)
expr = replace_expression(expr, replacemap=backmap)
return AccumulationTerm(expr, test_arg, indexmap, newi)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment