Skip to content
Snippets Groups Projects
Commit 75c89208 authored by René Heß's avatar René Heß
Browse files

Ignore indices when splitting into accumulation terms

Example: Before this commit the expression

\sum_i a_i (\nabla v)_i + \sum_j b_j (\nabla v)_j

was split into two accumulation terms:

1) a_k with corresponding test function (\nabla v)_k
2) b_l with corresponding test function (\nabla v)_l

Now we split into:

a_k + b_k with corresponding test function (\nabla v)_k

This is possible since we have linearity in the test function.

TODO: Jacobians are not yet working
parent 61fa1d0e
No related branches found
No related tags found
No related merge requests found
......@@ -314,10 +314,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# If this is a gradient, we generate an iname
additional_inames = frozenset({})
if accterm.argument.index:
if accterm.new_indices is not None:
from ufl.domain import find_geometric_dimension
dim = find_geometric_dimension(accterm.argument.expr)
for i in accterm.argument.index._indices:
for i in accterm.new_indices:
if i not in visitor.dimension_indices:
additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
......@@ -326,7 +326,12 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
return
# We also traverse the test function to get its pymbolic equivalent
test_expr = visitor(accterm.argument.expr)
if accterm.new_indices is not None:
from ufl.classes import Indexed, MultiIndex
accum_expr = Indexed(accterm.argument.expr, MultiIndex(accterm.new_indices))
else:
accum_expr = accterm.argument.expr
test_expr = visitor(accum_expr)
# Combine expression and test function
from pymbolic.primitives import Product
......@@ -397,11 +402,11 @@ def visit_integrals(integrals):
accterms = split_into_accumulation_terms(integrand)
# Iterate over the terms and generate a kernel
for term in accterms:
for accterm in accterms:
# Adjust the index map for the visitor
from copy import deepcopy
indexmap = deepcopy(dimension_indices)
for i, j in term.indexmap.items():
for i, j in accterm.indexmap.items():
if i in indexmap:
indexmap[j] = indexmap[i]
......@@ -422,13 +427,17 @@ def visit_integrals(integrals):
set_subst_rule(name, expr)
# Ensure CSE on detjac * quadrature weight
domain = term.argument.argexpr.ufl_domain()
domain = accterm.argument.argexpr.ufl_domain()
if measure == "cell":
set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)))
set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain))
set_subst_rule("integration_factor_cell1",
uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)))
set_subst_rule("integration_factor_cell2",
uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain))
else:
set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain))
set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain))
set_subst_rule("integration_factor_facet1",
uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain))
set_subst_rule("integration_factor_facet2",
uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain))
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
......
"""
This module defines an UFL transformation, that takes a UFL expression
and transforms it into a sum of accumulation terms.
and transforms it into a list of accumulation terms.
"""
from dune.perftool.ufl.flatoperators import construct_binary_operator
from dune.perftool.ufl.modified_terminals import extract_modified_arguments
from dune.perftool.ufl.transformations import ufl_transformation
from dune.perftool.ufl.transformations.replace import replace_expression
from dune.perftool.ufl.transformations.identitypropagation import identity_propagation
from dune.perftool.ufl.transformations.zeropropagation import zero_propagation
from dune.perftool.ufl.transformations.reindexing import reindexing
from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument
from dune.perftool.pdelab.restriction import Restriction
from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex
from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex, Product
from ufl.core.multiindex import indices
from pytools import Record
......@@ -21,12 +23,14 @@ class AccumulationTerm(Record):
term,
argument,
indexmap={},
new_indices=None
):
assert isinstance(argument, ModifiedArgument)
Record.__init__(self,
term=term,
argument=argument,
indexmap=indexmap,
new_indices=new_indices
)
......@@ -44,14 +48,14 @@ def split_into_accumulation_terms(expr):
Arguments:
----------
exrp: UFL expression we want to split
expr: UFL expression we want to split
"""
# Store AccumulationTerms in this list
ret = []
# Extract a list of modified terminals for the test function
# One accumulation instruction will be generated for each of these.
test_args = extract_modified_arguments(expr, argnumber=0)
test_args = extract_modified_arguments(expr, argnumber=0, do_index=False)
# Extract a list of modified terminals for the ansatz function
# in jacobian forms.
......@@ -62,7 +66,7 @@ def split_into_accumulation_terms(expr):
# about too much stuff in reconstructing the expressions
# 1) We first cut the expression to the relevant modified test_function
# Build a replacement dictionary
# by replacing all other test functions with zero
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
......@@ -70,25 +74,67 @@ def split_into_accumulation_terms(expr):
replacement[test_arg.expr] = test_arg.expr
replace_expr = replace_expression(expr, replacemap=replacement)
# 2) Cut the test function itself from the expression
# 2) Propagate indexed zeros to simplify expression
replace_expr = zero_propagation(replace_expr)
# 3) Cut the test function itself from the expression
#
# This is done by replacing the test function with an
# appropriate product of identity matrices. This way we can
# make sure that the indices of the result will be right. This
# is best explained by an example:
#
# Suppose we have the following expression:
#
# \sum_{i,j} a_{i,j} (\nabla v)_{i,j} + \sum_{k,l} b_{k,l} (\nable v)_{k,l}
#
# If we want to cut the gradient of the test function v we
# need to make sure, that both sums have the right indices:
#
# \sum_{m,n} (a_{m,n} + b_{m,n}) (\nabla v)_{m,n}
#
# and we extract (a_{m,n} + b_{m,n}). We achieve that by the
# following replacements:
#
# (\nabla v)_{i,j} -> I_{m,i} I_{n,j}
# (\nabla v)_{k,l} -> I_{m,k} I_{n,l}
#
# Resulting in:
#
# \sum_{i,j} a_{i,j} I_{m,i} I_{n,j} + \sum_{k,l} b_{k,l} I_{m,k} I_{n,l}
#
# In step 4 this will collaps to: a_{m,n} + b_{m,n}
replacement = {}
indexmap = {}
if test_arg.index:
newi = indices(len(test_arg.index))
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices))
indexmap = {i: j for i, j in zip(test_arg.index._indices, newi)}
from dune.perftool.ufl.flatoperators import construct_binary_operator
from ufl.classes import Product
replacement = {test_arg.expr: construct_binary_operator(identities, Product)}
test_arg = analyse_modified_argument(reindexing(test_arg.expr, replacemap=indexmap))
else:
replacement = {test_arg.expr: IntValue(1)}
newi = None
# Get all appearances of test functions with their indices
indexed_test_args = extract_modified_arguments(replace_expr, argnumber=0, do_index=True)
for indexed_test_arg in indexed_test_args:
# from pudb import set_trace; set_trace()
if indexed_test_arg.index:
# If the test function is indexed, create a new multiindex of this shape
# -> (m,n) in the example above
if newi is None:
newi = indices(len(indexed_test_arg.index))
# Replace indexed test function with a product of identities.
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
for i, j in zip(newi, indexed_test_arg.index._indices))
replacement.update({indexed_test_arg.expr:
construct_binary_operator(identities, Product)})
indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
indexed_test_arg = analyse_modified_argument(reindexing(indexed_test_arg.expr,
replacemap=indexmap))
else:
replacement.update({indexed_test_arg.expr: IntValue(1)})
replace_expr = replace_expression(replace_expr, replacemap=replacement)
# 3) Collapse any identity nodes that may have been introduced by replacing vectors
# 4) Collapse any identity nodes that may have been introduced by replacing vectors
replace_expr = identity_propagation(replace_expr)
# 4) Further split according to trial function in jacobian terms
# 5) Further split according to trial function in jacobian terms
if all_jacobian_args:
# TODO -> Jacobians not yet implemented!
assert(False)
jac_args = extract_modified_arguments(replace_expr, argnumber=1)
for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
......@@ -100,9 +146,9 @@ def split_into_accumulation_terms(expr):
jac_expr = replace_expression(replace_expr, replacemap=replacement)
if not isinstance(jac_expr, Zero):
ret.append(AccumulationTerm(jac_expr, test_arg, indexmap))
ret.append(AccumulationTerm(jac_expr, test_arg, indexmap, newi))
else:
if not isinstance(replace_expr, Zero):
ret.append(AccumulationTerm(replace_expr, test_arg, indexmap))
ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi))
return ret
......@@ -92,7 +92,8 @@ class ModifiedTerminalTracker(MultiFunction):
class ModifiedArgumentAnalysis(ModifiedTerminalTracker):
def __init__(self):
def __init__(self, do_index=False):
self.do_index = do_index
self.index = None
ModifiedTerminalTracker.__init__(self)
......@@ -101,7 +102,8 @@ class ModifiedArgumentAnalysis(ModifiedTerminalTracker):
return self.call(o)
def indexed(self, o):
self.index = o.ufl_operands[1]
if self.do_index:
self.index = o.ufl_operands[1]
return self.call(o.ufl_operands[0])
def form_argument(self, o):
......@@ -116,22 +118,23 @@ class ModifiedArgumentAnalysis(ModifiedTerminalTracker):
)
def analyse_modified_argument(expr):
return ModifiedArgumentAnalysis()(expr)
def analyse_modified_argument(expr, **kwargs):
return ModifiedArgumentAnalysis(**kwargs)(expr)
class _ModifiedArgumentExtractor(MultiFunction):
""" A multifunction that extracts and returns the set of modified arguments """
def __call__(self, o, argnumber=None, coeffcount=None):
def __call__(self, o, argnumber=None, coeffcount=None, do_index=False):
self.argnumber = argnumber
self.coeffcount = coeffcount
self.do_index = do_index
self.modified_arguments = set()
ret = self.call(o)
if ret:
# This indicates that this entire expression was a modified thing...
self.modified_arguments.add(ret)
return tuple(analyse_modified_argument(ma) for ma in self.modified_arguments)
return tuple(analyse_modified_argument(ma, do_index=self.do_index) for ma in self.modified_arguments)
def expr(self, o):
for op in o.ufl_operands:
......@@ -143,7 +146,12 @@ class _ModifiedArgumentExtractor(MultiFunction):
if self.call(o.ufl_operands[0]):
return o
indexed = pass_on
def indexed(self, o):
if self.do_index:
return self.pass_on(o)
else:
self.expr(o)
positive_restricted = pass_on
negative_restricted = pass_on
grad = pass_on
......
"""
A transformation to help the form splitting algorithm split
vector and tensor expressions. In a nutshell does:
vector and tensor expressions. In a nutshell it does:
\sum_i f(i)I(i,k) => f(k)
"""
......
"""
A transformation that propagates zeros. It transforms:
Indexed(Zero((shape),(),()) -> Zero((),(free_indices),(index_dimensions))
"""
from dune.perftool.ufl.transformations import ufl_transformation
from ufl.algorithms import MultiFunction
from ufl.classes import Zero
class ZeroPropagation(MultiFunction):
call = MultiFunction.__call__
def __call__(self, expr):
# self.replacemap = GetIndexMap()(expr)
return self.call(expr)
def expr(self, o):
return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
def indexed(self, o):
op, i = o.ufl_operands
if isinstance(op, Zero):
return Zero(shape=o.ufl_shape,
free_indices=o.ufl_free_indices,
index_dimensions=o.ufl_index_dimensions)
else:
return self.reuse_if_untouched(o, self.call(op), self.call(i))
@ufl_transformation(name='zero')
def zero_propagation(expr):
return ZeroPropagation()(expr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment