Skip to content
Snippets Groups Projects
Commit 7d5a613f authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Merge branch 'feature/rework-accumulation-splitting' into 'master'

Rework accumulation splitting

See merge request !119
parents 21d4cd57 542ddd4c
No related branches found
No related tags found
No related merge requests found
......@@ -314,10 +314,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# If this is a gradient, we generate an iname
additional_inames = frozenset({})
if accterm.argument.index:
if accterm.new_indices is not None:
from ufl.domain import find_geometric_dimension
dim = find_geometric_dimension(accterm.argument.expr)
for i in accterm.argument.index._indices:
for i in accterm.new_indices:
if i not in visitor.dimension_indices:
additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
......@@ -326,7 +326,12 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
return
# We also traverse the test function to get its pymbolic equivalent
test_expr = visitor(accterm.argument.expr)
if accterm.new_indices is not None:
from ufl.classes import Indexed, MultiIndex
accum_expr = Indexed(accterm.argument.expr, MultiIndex(accterm.new_indices))
else:
accum_expr = accterm.argument.expr
test_expr = visitor(accum_expr)
# Combine expression and test function
from pymbolic.primitives import Product
......@@ -393,15 +398,15 @@ def visit_integrals(integrals):
traverse_lfs_tree(ma)
# Now split the given integrand into accumulation expressions
from dune.perftool.ufl.transformations.extract_accumulation_terms import split_into_accumulation_terms
from dune.perftool.ufl.extract_accumulation_terms import split_into_accumulation_terms
accterms = split_into_accumulation_terms(integrand)
# Iterate over the terms and generate a kernel
for term in accterms:
for accterm in accterms:
# Adjust the index map for the visitor
from copy import deepcopy
indexmap = deepcopy(dimension_indices)
for i, j in term.indexmap.items():
for i, j in accterm.indexmap.items():
if i in indexmap:
indexmap[j] = indexmap[i]
......@@ -422,15 +427,19 @@ def visit_integrals(integrals):
set_subst_rule(name, expr)
# Ensure CSE on detjac * quadrature weight
domain = term.argument.argexpr.ufl_domain()
domain = accterm.argument.argexpr.ufl_domain()
if measure == "cell":
set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)))
set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain))
set_subst_rule("integration_factor_cell1",
uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)))
set_subst_rule("integration_factor_cell2",
uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain))
else:
set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain))
set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain))
set_subst_rule("integration_factor_facet1",
uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain))
set_subst_rule("integration_factor_facet2",
uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain))
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
get_backend(interface="accum_insn")(visitor, accterm, measure, subdomain_id)
def generate_kernel(integrals):
......
......@@ -129,8 +129,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# If this is a gradient, we find the gradient iname
additional_inames = frozenset({})
if accterm.argument.index:
for i in accterm.argument.index._indices:
if accterm.new_indices is not None:
for i in accterm.new_indices:
if i not in visitor.dimension_indices:
from dune.perftool.pdelab.localoperator import grad_iname
additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
......@@ -138,7 +138,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
def emit_sumfact_kernel(i, restriction, insn_dep):
# Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence(transpose=True,
derivative=i if accterm.argument.index else None,
derivative=i if accterm.new_indices else None,
facedir=get_facedir(accterm.argument.restriction),
facemod=get_facemod(accterm.argument.restriction),
)
......@@ -159,8 +159,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
index = ()
vectag = frozenset()
base_storage_size = product(max(mat.rows, mat.cols) for mat in a_matrices)
temp = initialize_buffer(buf,
base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
base_storage_size=base_storage_size,
num=2
).get_temporary(shape=shape,
dim_tags=dim_tags,
......@@ -168,9 +169,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
)
# Those input fields, that are padded need to be set to zero
# in order to do a horizontal_add lateron
# in order to do a horizontal_add later on
for pad in padding:
instruction(assignee=prim.Subscript(prim.Variable(temp), tuple(Variable(i) for i in quadrature_inames()) + (pad,)),
assignee = prim.Subscript(prim.Variable(temp),
tuple(Variable(i) for i in quadrature_inames()) + (pad,))
instruction(assignee=assignee,
expression=0,
forced_iname_deps=frozenset(quadrature_inames() + visitor.inames),
forced_iname_deps_is_final=True,
......@@ -219,7 +222,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
within_inames=frozenset(visitor.inames))})
inames = tuple(accum_iname((accterm.argument.restriction, restriction), mat.rows, i) for i, mat in enumerate(a_matrices))
inames = tuple(accum_iname((accterm.argument.restriction, restriction), mat.rows, i)
for i, mat in enumerate(a_matrices))
# Collect the lfs and lfs indices for the accumulate call
test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure)
......@@ -234,7 +238,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
if rank == 2:
# TODO the next line should get its inames from
# elsewhere. This is *NOT* robust (but works right now)
ansatz_lfs.index = flatten_index(tuple(Variable(visitor.inames[i]) for i in range(world_dimension())),
ansatz_lfs.index = flatten_index(tuple(Variable(visitor.inames[i])
for i in range(world_dimension())),
(basis_functions_per_direction(),) * dim,
order="f"
)
......@@ -258,14 +263,15 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Add a sum factorization kernel that implements the multiplication
# with the test function (stage 3)
pref_pos = i if accterm.argument.index else None
pref_pos = i if accterm.new_indices else None
result, insn_dep = sum_factorization_kernel(a_matrices,
buf,
3,
insn_dep=insn_dep,
additional_inames=frozenset(visitor.inames),
preferred_position=pref_pos,
restriction=(accterm.argument.restriction, restriction),
restriction=(accterm.argument.restriction,
restriction),
direct_output=direct_output,
visitor=visitor
)
......@@ -313,19 +319,21 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
depends_on=insn_dep,
)
# Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
# Mark the transformation that moves the quadrature loop
# inside the trialfunction loops for application
transform(nest_quadrature_loops, visitor.inames)
return insn_dep
# Extract the restrictions on argument-1:
jac_restrictions = frozenset(tuple(ma.restriction for ma in extract_modified_arguments(accterm.term, argnumber=1)))
jac_restrictions = frozenset(tuple(ma.restriction for ma in
extract_modified_arguments(accterm.term, argnumber=1, do_index=True)))
if not jac_restrictions:
jac_restrictions = frozenset({0})
insn_dep = None
for restriction in jac_restrictions:
if accterm.argument.index:
if accterm.new_indices:
for i in range(world_dimension()):
insn_dep = emit_sumfact_kernel(i, restriction, insn_dep)
else:
......@@ -458,9 +466,9 @@ def sum_factorization_kernel(a_matrices,
It can make sense to permute the order of directions. If you have
a small m_l (e.g. stage 1 on faces) it is better to do direction l
first. This can be done permuting:
first. This can be done by:
- The order of the A matrices.
- Permuting the order of the A matrices.
- Permuting the input tensor.
- Permuting the output tensor (this assures that the directions of
the output tensor are again ordered from 0 to d-1).
......
"""
This module defines an UFL transformation, that takes a UFL expression
and transforms it into a sum of accumulation terms.
and transforms it into a list of accumulation terms.
"""
from dune.perftool.ufl.flatoperators import construct_binary_operator
from dune.perftool.ufl.modified_terminals import extract_modified_arguments
from dune.perftool.ufl.transformations import ufl_transformation
from dune.perftool.ufl.transformations.replace import replace_expression
from dune.perftool.ufl.transformations.identitypropagation import identity_propagation
from dune.perftool.ufl.transformations.zeropropagation import zero_propagation
from dune.perftool.ufl.transformations.reindexing import reindexing
from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument
from dune.perftool.pdelab.restriction import Restriction
from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex
from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex, Product
from ufl.core.multiindex import indices
from pytools import Record
......@@ -21,33 +23,50 @@ class AccumulationTerm(Record):
term,
argument,
indexmap={},
new_indices=None
):
assert isinstance(argument, ModifiedArgument)
Record.__init__(self,
term=term,
argument=argument,
indexmap=indexmap,
new_indices=new_indices
)
@ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l])
def split_into_accumulation_terms(expr, indexmap={}):
def split_into_accumulation_terms(expr):
"""Split an UFL expression into several accumulation parts and return a list
For a residual evaluation we split for different test functions
and according to the restriction (self/neighbor at skeletons). For
the jacobians we also need to split according to the ansatz
functions (and their restriction).
Note: This function is not an UFL transformation. Nonetheless it
has the @ufl_transformation decorator for debugging purposes.
Arguments:
----------
expr: UFL expression we want to split
"""
# Store AccumulationTerms in this list
ret = []
# Extract a list of modified terminals for the test function
# One accumulation instruction will be generated for each of these.
test_args = extract_modified_arguments(expr, argnumber=0)
test_args = extract_modified_arguments(expr, argnumber=0, do_index=False)
# Extract a list of modified terminals for the ansatz function
# in jacobian forms.
all_jacobian_args = extract_modified_arguments(expr, argnumber=1)
all_jacobian_args = extract_modified_arguments(expr, argnumber=1, do_index=False)
for test_arg in test_args:
# Do this as a multi step replacement procedure to avoid UFL nagging
# about too much stuff in reconstructing the expressions
# 1) We first cut the expression to the relevant modified test_function
# Build a replacement dictionary
# by replacing all other test functions with zero
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
......@@ -55,39 +74,103 @@ def split_into_accumulation_terms(expr, indexmap={}):
replacement[test_arg.expr] = test_arg.expr
replace_expr = replace_expression(expr, replacemap=replacement)
# 2) Cut the test function itself from the expression
# 2) Propagate indexed zeros to simplify expression
replace_expr = zero_propagation(replace_expr)
# 3) Cut the test function itself from the expression
#
# This is done by replacing the test function with an
# appropriate product of identity matrices. This way we can
# make sure that the indices of the result will be right. This
# is best explained by an example:
#
# Suppose we have the following expression:
#
# \sum_{i,j} a_{i,j} (\nabla v)_{i,j} + \sum_{k,l} b_{k,l} (\nabla v)_{k,l}
#
# If we want to cut the gradient of the test function v we
# need to make sure, that both sums have the right indices:
#
# \sum_{m,n} (a_{m,n} + b_{m,n}) (\nabla v)_{m,n}
#
# and we extract (a_{m,n} + b_{m,n}). We achieve that by the
# following replacements:
#
# (\nabla v)_{i,j} -> I_{m,i} I_{n,j}
# (\nabla v)_{k,l} -> I_{m,k} I_{n,l}
#
# Resulting in:
#
# \sum_{i,j} a_{i,j} I_{m,i} I_{n,j} + \sum_{k,l} b_{k,l} I_{m,k} I_{n,l}
#
# In step 4 this will collapse to: a_{m,n} + b_{m,n}
replacement = {}
indexmap = {}
if test_arg.index:
newi = indices(len(test_arg.index))
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices))
indexmap = {i: j for i, j in zip(test_arg.index._indices, newi)}
from dune.perftool.ufl.flatoperators import construct_binary_operator
from ufl.classes import Product
replacement = {test_arg.expr: construct_binary_operator(identities, Product)}
test_arg = analyse_modified_argument(reindexing(test_arg.expr, replacemap=indexmap))
else:
replacement = {test_arg.expr: IntValue(1)}
newi = None
# Get all appearances of test functions with their indices
indexed_test_args = extract_modified_arguments(replace_expr, argnumber=0, do_index=True)
for indexed_test_arg in indexed_test_args:
if indexed_test_arg.index:
# If the test function is indexed, create a new multiindex of this shape
# -> (m,n) in the example above
if newi is None:
newi = indices(len(indexed_test_arg.index))
# Replace indexed test function with a product of identities.
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
for i, j in zip(newi, indexed_test_arg.index._indices))
replacement.update({indexed_test_arg.expr:
construct_binary_operator(identities, Product)})
indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
indexed_test_arg = analyse_modified_argument(reindexing(indexed_test_arg.expr,
replacemap=indexmap))
else:
replacement.update({indexed_test_arg.expr: IntValue(1)})
replace_expr = replace_expression(replace_expr, replacemap=replacement)
# 3) Collapse any identity nodes that may have been introduced by replacing vectors
# 4) Collapse any identity nodes that may have been introduced by replacing vectors
replace_expr = identity_propagation(replace_expr)
# 4) Further split according to trial function in jacobian terms
# 5) Further split according to trial function in jacobian terms
#
# Note: We need to split according to the FunctionView. For
# example in Stokes with test functions v and q we have to
# split between those.
#
# But: We don't want to split according to gradients since
# this would break the input buffers of stage 3 of
# sumfactorization
if all_jacobian_args:
jac_args = extract_modified_arguments(replace_expr, argnumber=1)
for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
trial_args = extract_modified_arguments(replace_expr,
argnumber=1,
do_index=False,
do_gradient=False)
for trial_arg in trial_args:
# 5.1) Restrict to this trial argument
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
if ma.restriction != restriction else ma.expr
for ma in jac_args}
for ma in trial_args}
replacement[trial_arg.expr] = trial_arg.expr
jac_expr = replace_expression(replace_expr, replacemap=replacement)
if not isinstance(jac_expr, Zero):
ret.append(AccumulationTerm(jac_expr, test_arg, indexmap))
# 5.2) Propagate indexed zeros to simplify expression
jac_expr = zero_propagation(jac_expr)
# 5.3) Accumulate according to restriction
indexed_jac_args = extract_modified_arguments(jac_expr, argnumber=1, do_index=True)
for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
if ma.restriction != restriction else ma.expr
for ma in indexed_jac_args}
acc_expr = replace_expression(jac_expr, replacemap=replacement)
if not isinstance(jac_expr, Zero):
ret.append(AccumulationTerm(acc_expr, test_arg, indexmap, newi))
else:
if not isinstance(replace_expr, Zero):
ret.append(AccumulationTerm(replace_expr, test_arg, indexmap))
ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi))
return ret
......@@ -92,7 +92,8 @@ class ModifiedTerminalTracker(MultiFunction):
class ModifiedArgumentAnalysis(ModifiedTerminalTracker):
def __init__(self):
def __init__(self, do_index=False):
self.do_index = do_index
self.index = None
ModifiedTerminalTracker.__init__(self)
......@@ -101,7 +102,8 @@ class ModifiedArgumentAnalysis(ModifiedTerminalTracker):
return self.call(o)
def indexed(self, o):
self.index = o.ufl_operands[1]
if self.do_index:
self.index = o.ufl_operands[1]
return self.call(o.ufl_operands[0])
def form_argument(self, o):
......@@ -116,22 +118,27 @@ class ModifiedArgumentAnalysis(ModifiedTerminalTracker):
)
def analyse_modified_argument(expr):
return ModifiedArgumentAnalysis()(expr)
def analyse_modified_argument(expr, **kwargs):
    """Analyse a modified argument expression.

    A fresh ``ModifiedArgumentAnalysis`` is constructed with ``kwargs``
    (e.g. ``do_index``) and applied to ``expr``; its result is returned.
    """
    analysis = ModifiedArgumentAnalysis(**kwargs)
    return analysis(expr)
class _ModifiedArgumentExtractor(MultiFunction):
""" A multifunction that extracts and returns the set of modified arguments """
def __call__(self, o, argnumber=None, coeffcount=None):
def __call__(self, o, argnumber=None, coeffcount=None, do_index=False, do_gradient=True):
self.argnumber = argnumber
self.coeffcount = coeffcount
self.do_index = do_index
self.do_gradient = do_gradient
self.modified_arguments = set()
ret = self.call(o)
if ret:
# This indicates that this entire expression was a modified thing...
self.modified_arguments.add(ret)
return tuple(analyse_modified_argument(ma) for ma in self.modified_arguments)
return tuple(analyse_modified_argument(ma,
do_index=self.do_index
)
for ma in self.modified_arguments)
def expr(self, o):
for op in o.ufl_operands:
......@@ -143,11 +150,27 @@ class _ModifiedArgumentExtractor(MultiFunction):
if self.call(o.ufl_operands[0]):
return o
indexed = pass_on
def indexed(self, o):
    """Handle an ``Indexed`` node.

    With ``do_index`` set, the node is handed to ``pass_on``. Otherwise it
    is treated like a generic expression via ``expr`` and nothing is
    returned.
    """
    if not self.do_index:
        # The index is not considered part of the modified argument:
        # fall back to the generic expression handler.
        self.expr(o)
        return None
    return self.pass_on(o)
def reference_grad(self, o):
    """Handle a ``ReferenceGrad`` node.

    With ``do_gradient`` set, the node is handed to ``pass_on``. Otherwise
    it is treated like a generic expression via ``expr`` and nothing is
    returned.
    """
    if not self.do_gradient:
        # Gradients are not considered part of the modified argument:
        # fall back to the generic expression handler.
        self.expr(o)
        return None
    return self.pass_on(o)
def grad(self, o):
    """Handle a ``Grad`` node.

    With ``do_gradient`` set, the node is handed to ``pass_on``. Otherwise
    it is treated like a generic expression via ``expr`` and nothing is
    returned.
    """
    if not self.do_gradient:
        # Gradients are not considered part of the modified argument:
        # fall back to the generic expression handler.
        self.expr(o)
        return None
    return self.pass_on(o)
positive_restricted = pass_on
negative_restricted = pass_on
grad = pass_on
reference_grad = pass_on
function_view = pass_on
reference_value = pass_on
......
"""
A transformation to help the form splitting algorithm split
vector and tensor expressions. In a nutshell does:
vector and tensor expressions. In a nutshell it does:
\sum_i f(i)I(i,k) => f(k)
"""
......
"""
A transformation that propagates zeros. It transforms:
Indexed(Zero((shape),(),()), multiindex) -> Zero((),(free_indices),(index_dimensions))
"""
from dune.perftool.ufl.transformations import ufl_transformation
from ufl.algorithms import MultiFunction
from ufl.classes import Zero
class ZeroPropagation(MultiFunction):
    """Collapse ``Indexed`` nodes whose operand is ``Zero`` into a ``Zero``.

    The resulting ``Zero`` carries the shape, free indices and index
    dimensions of the ``Indexed`` node it replaces. All other nodes are
    rebuilt from their transformed operands.
    """
    # Keep the MultiFunction dispatch reachable under the name ``call``,
    # because ``__call__`` is overridden below.
    call = MultiFunction.__call__

    def __call__(self, expr):
        """Apply the transformation to ``expr`` and return the result."""
        return self.call(expr)

    def expr(self, o):
        """Fallback handler: transform every operand of ``o``."""
        return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))

    def indexed(self, o):
        """Replace ``o`` by ``Zero`` if its (transformed) operand is ``Zero``."""
        op, i = o.ufl_operands
        # Transform the operand first: a zero may only appear after the
        # subexpression below this node has been simplified.
        op = self.call(op)
        if isinstance(op, Zero):
            return Zero(shape=o.ufl_shape,
                        free_indices=o.ufl_free_indices,
                        index_dimensions=o.ufl_index_dimensions)
        return self.reuse_if_untouched(o, op, self.call(i))
@ufl_transformation(name='zero')
def zero_propagation(expr):
    """Run ``ZeroPropagation`` on ``expr`` and return the transformed expression."""
    propagator = ZeroPropagation()
    return propagator(expr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment