Skip to content
Snippets Groups Projects
Commit 26f55622 authored by René Heß's avatar René Heß
Browse files

Do not split test function when not doing sumfactorization

parent db2b3524
No related branches found
No related tags found
No related merge requests found
...@@ -309,6 +309,8 @@ def grad_iname(index, dim): ...@@ -309,6 +309,8 @@ def grad_iname(index, dim):
@backend(interface="accum_insn") @backend(interface="accum_insn")
def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
assert(accterm.argument.expr is None)
# First we do the tree traversal to get a pymbolic expression representing this expression # First we do the tree traversal to get a pymbolic expression representing this expression
pymbolic_expr = visitor(accterm.term) pymbolic_expr = visitor(accterm.term)
...@@ -326,22 +328,27 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -326,22 +328,27 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
return return
# We also traverse the test function to get its pymbolic equivalent # We also traverse the test function to get its pymbolic equivalent
if accterm.new_indices is not None: from ufl.classes import Identity
from ufl.classes import Indexed, MultiIndex if accterm.argument.expr is not None:
accum_expr = Indexed(accterm.argument.expr, MultiIndex(accterm.new_indices)) if accterm.new_indices is not None:
else: from ufl.classes import Indexed, MultiIndex
accum_expr = accterm.argument.expr accum_expr = Indexed(accterm.argument.expr, MultiIndex(accterm.new_indices))
test_expr = visitor(accum_expr) else:
accum_expr = accterm.argument.expr
test_expr = visitor(accum_expr)
# Combine expression and test function # Combine expression and test function
from pymbolic.primitives import Product from pymbolic.primitives import Product
pymbolic_expr = Product((pymbolic_expr, test_expr)) pymbolic_expr = Product((pymbolic_expr, test_expr))
# Collect the lfs and lfs indices for the accumulate call if accterm.argument.expr is None:
test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure) test_lfs = determine_accumulation_space(accterm.term, 0, measure)
else:
test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure)
# In the jacobian case, also determine the space for the ansatz space # In the jacobian case, also determine the space for the ansatz space
ansatz_lfs = determine_accumulation_space(accterm.term, 1, measure) ansatz_lfs = determine_accumulation_space(accterm.term, 1, measure)
# Collect the lfs and lfs indices for the accumulate call
from dune.perftool.pdelab.argument import name_accumulation_variable from dune.perftool.pdelab.argument import name_accumulation_variable
accumvar = name_accumulation_variable((test_lfs.get_restriction() + ansatz_lfs.get_restriction())) accumvar = name_accumulation_variable((test_lfs.get_restriction() + ansatz_lfs.get_restriction()))
...@@ -394,16 +401,23 @@ def visit_integrals(integrals): ...@@ -394,16 +401,23 @@ def visit_integrals(integrals):
from dune.perftool.pdelab.spaces import traverse_lfs_tree from dune.perftool.pdelab.spaces import traverse_lfs_tree
traverse_lfs_tree(ma) traverse_lfs_tree(ma)
# Now split the given integrand into accumulation expressions # Now split the given integrand into accumulation
# expressions. If we do sumfactorization we cut the test
# argument from the rest of the expression. This gives the
# right input for the sumfactorization kernel of stage 3.
from dune.perftool.ufl.extract_accumulation_terms import split_into_accumulation_terms from dune.perftool.ufl.extract_accumulation_terms import split_into_accumulation_terms
accterms = split_into_accumulation_terms(integrand) if get_option('sumfact'):
accterms = split_into_accumulation_terms(integrand, cut_test_arg=True)
else:
accterms = split_into_accumulation_terms(integrand)
# Iterate over the terms and generate a kernel # Iterate over the terms and generate a kernel
for accterm in accterms: for accterm in accterms:
# Get dimension indices # Get dimension indices
indexmap = {}
from dune.perftool.ufl.dimensionindex import dimension_index_mapping from dune.perftool.ufl.dimensionindex import dimension_index_mapping
indexmap = dimension_index_mapping(accterm.test_arg()) if accterm.argument.expr is not None:
# For jacobian there can also be dimension indices in the expression indexmap.update(dimension_index_mapping(accterm.indexed_test_arg()))
indexmap.update(dimension_index_mapping(accterm.term)) indexmap.update(dimension_index_mapping(accterm.term))
# Get a transformer instance for this kernel # Get a transformer instance for this kernel
...@@ -423,7 +437,7 @@ def visit_integrals(integrals): ...@@ -423,7 +437,7 @@ def visit_integrals(integrals):
set_subst_rule(name, expr) set_subst_rule(name, expr)
# Ensure CSE on detjac * quadrature weight # Ensure CSE on detjac * quadrature weight
domain = accterm.argument.argexpr.ufl_domain() domain = accterm.term.ufl_domain()
if measure == "cell": if measure == "cell":
set_subst_rule("integration_factor_cell1", set_subst_rule("integration_factor_cell1",
uc.QuadratureWeight(domain) * uc.Abs(uc.JacobianDeterminant(domain))) uc.QuadratureWeight(domain) * uc.Abs(uc.JacobianDeterminant(domain)))
......
""" """Transform an UFL expression into list of accumulation terms."""
This module defines an UFL transformation, that takes a UFL expression
and transforms it into a list of accumulation terms.
"""
from dune.perftool.ufl.flatoperators import construct_binary_operator from dune.perftool.ufl.flatoperators import construct_binary_operator
from dune.perftool.ufl.modified_terminals import extract_modified_arguments from dune.perftool.ufl.modified_terminals import extract_modified_arguments
from dune.perftool.ufl.transformations import ufl_transformation from dune.perftool.ufl.transformations import ufl_transformation
from dune.perftool.ufl.transformations.replace import replace_expression from dune.perftool.ufl.transformations.replace import replace_expression
from dune.perftool.ufl.transformations.identitypropagation import identity_propagation from dune.perftool.ufl.transformations.identitypropagation import identity_propagation
from dune.perftool.ufl.transformations.zeropropagation import zero_propagation from dune.perftool.ufl.transformations.zeropropagation import zero_propagation
from dune.perftool.ufl.transformations.reindexing import reindexing from dune.perftool.ufl.modified_terminals import ModifiedArgument
from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument
from dune.perftool.pdelab.restriction import Restriction from dune.perftool.pdelab.restriction import Restriction
from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex, Product, IndexSum from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex, Product, IndexSum
...@@ -19,37 +16,62 @@ from pytools import Record ...@@ -19,37 +16,62 @@ from pytools import Record
class AccumulationTerm(Record): class AccumulationTerm(Record):
"""Store informations about accumulation terms
Arguments:
----------
term: The UFL expression we want to accumulate
argument: Corresponding test function (in case we split it) without indices
new_indices: Indices of the test function
"""
def __init__(self, def __init__(self,
term, term,
argument, argument=None,
new_indices=None new_indices=None
): ):
assert isinstance(argument, ModifiedArgument) assert (isinstance(argument, ModifiedArgument) or argument is None)
Record.__init__(self, Record.__init__(self,
term=term, term=term,
argument=argument, argument=argument,
new_indices=new_indices new_indices=new_indices
) )
def test_arg(self): def indexed_test_arg(self):
"""Return test argument of this accumulation term with its indices"""
if self.new_indices is None: if self.new_indices is None:
return self.argument.expr return self.argument.expr
else: else:
return Indexed(self.argument.expr, MultiIndex(self.new_indices)) return Indexed(self.argument.expr, MultiIndex(self.new_indices))
@ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l]) def split_into_accumulation_terms(expr, cut_test_arg=False):
def split_into_accumulation_terms(expr): """Split an expression into a list of AccumulationTerms
Arguments:
----------
expr: The UFL expression we split into AccumulationTerms
cut_test_arg: Wheter we want to return AccumulationTerms where the
test function is cut from the rest of the expression.
"""
expression_list = split_expression(expr) expression_list = split_expression(expr)
acc_term_list = [] acc_term_list = []
for e in expression_list: for e in expression_list:
acc_term_list.append(cut_accumulation_term(e)) if cut_test_arg:
acc_term_list.append(cut_accumulation_term(e))
else:
acc_term_list.append(AccumulationTerm(e, ModifiedArgument(expr=None)))
return acc_term_list return acc_term_list
@ufl_transformation(name="splitt_expression", extraction_lambda=lambda l: [e for e in l])
def split_expression(expr): def split_expression(expr):
# TODO: doc me """Split expression into a list of expressions
Note: This is not an ufl transformation since it does not return
an expression. It is nonetheless decorated with ufl_transformation
since the decorator can handle lists of expressions and it can be
useful for debugging to get the corresponding trees.
"""
# Store AccumulationTerms in this list # Store AccumulationTerms in this list
ret = [] ret = []
...@@ -62,8 +84,9 @@ def split_expression(expr): ...@@ -62,8 +84,9 @@ def split_expression(expr):
all_jacobian_args = extract_modified_arguments(expr, argnumber=1, do_index=False) all_jacobian_args = extract_modified_arguments(expr, argnumber=1, do_index=False)
for test_arg in test_args: for test_arg in test_args:
# Do this as a multi step replacement procedure to avoid UFL nagging # Do this as a multi step replacement procedure to avoid UFL
# about too much stuff in reconstructing the expressions # nagging about too much stuff in reconstructing the
# expressions
# 1) We first cut the expression to the relevant modified test_function # 1) We first cut the expression to the relevant modified test_function
# by replacing all other test functions with zero # by replacing all other test functions with zero
...@@ -123,8 +146,15 @@ def split_expression(expr): ...@@ -123,8 +146,15 @@ def split_expression(expr):
return ret return ret
@ufl_transformation(name="cut_accum", extraction_lambda=lambda l: [l.term])
def cut_accumulation_term(expr): def cut_accumulation_term(expr):
# TODO: doc me """Cut test function from expression and return AccumulationTerm
Note: This assumes that there is only one test function in the
expression. You need to make sure to split your expression into
appropriate parts before calling this!
"""
# If there are multiple test arguments something went wrong!
test_args = extract_modified_arguments(expr, argnumber=0, do_index=False) test_args = extract_modified_arguments(expr, argnumber=0, do_index=False)
assert len(test_args) == 1 assert len(test_args) == 1
test_arg = test_args[0] test_arg = test_args[0]
......
""" Define the general infrastructure for debuggable UFL transformations""" """Infrastructure for printing trees of UFL expressions."""
class UFLTransformationWrapper(object): class UFLTransformationWrapper(object):
...@@ -47,7 +47,6 @@ class UFLTransformationWrapper(object): ...@@ -47,7 +47,6 @@ class UFLTransformationWrapper(object):
# We do also assume that the transformation returns an ufl expression or a list there of # We do also assume that the transformation returns an ufl expression or a list there of
ret_for_print = self.extractExpressionListFromResult(ret) ret_for_print = self.extractExpressionListFromResult(ret)
assert isinstance(ret_for_print, list) and all(isinstance(e, Expr) for e in ret_for_print) assert isinstance(ret_for_print, list) and all(isinstance(e, Expr) for e in ret_for_print)
# Maybe output the returned expression # Maybe output the returned expression
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment