Skip to content
Snippets Groups Projects
Commit 9b80bd3f authored by René Heß's avatar René Heß
Browse files

Implement splitting for Stokes with diagonal_jacobian

parent da1e31f4
No related branches found
No related tags found
No related merge requests found
...@@ -31,7 +31,7 @@ from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight, ...@@ -31,7 +31,7 @@ from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight,
) )
from dune.perftool.pdelab.spaces import (lfs_inames, from dune.perftool.pdelab.spaces import (lfs_inames,
) )
from dune.perftool.pdelab.tensors import pymbolic_list_tensor from dune.perftool.pdelab.tensors import pymbolic_list_tensor, pymbolic_identity
class PDELabInterface(object): class PDELabInterface(object):
...@@ -92,6 +92,9 @@ class PDELabInterface(object): ...@@ -92,6 +92,9 @@ class PDELabInterface(object):
def pymbolic_list_tensor(self, o, visitor): def pymbolic_list_tensor(self, o, visitor):
return pymbolic_list_tensor(o, visitor) return pymbolic_list_tensor(o, visitor)
def pymbolic_identity(self, o, visitor):
return pymbolic_identity(o, visitor)
# #
# Geometry related generator functions # Geometry related generator functions
# #
......
...@@ -13,10 +13,7 @@ from dune.perftool.generation import (class_basename, ...@@ -13,10 +13,7 @@ from dune.perftool.generation import (class_basename,
from dune.perftool.pdelab.geometry import (name_cell, from dune.perftool.pdelab.geometry import (name_cell,
name_intersection, name_intersection,
) )
from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_position, from dune.perftool.pdelab.quadrature import quadrature_preamble
pymbolic_quadrature_position_in_cell,
quadrature_preamble,
)
from dune.perftool.tools import get_pymbolic_basename from dune.perftool.tools import get_pymbolic_basename
from dune.perftool.cgen.clazz import AccessModifier from dune.perftool.cgen.clazz import AccessModifier
from dune.perftool.pdelab.localoperator import (class_type_from_cache, from dune.perftool.pdelab.localoperator import (class_type_from_cache,
......
""" Code generation for explicitly specified tensors """ """ Code generation for explicitly specified tensors """
from dune.perftool.generation import (get_counted_variable, from dune.perftool.generation import (get_counted_variable,
domain,
kernel_cached, kernel_cached,
iname,
instruction, instruction,
temporary_variable, temporary_variable,
) )
...@@ -33,3 +35,31 @@ def pymbolic_list_tensor(expr, visitor): ...@@ -33,3 +35,31 @@ def pymbolic_list_tensor(expr, visitor):
) )
define_list_tensor(name, expr, visitor) define_list_tensor(name, expr, visitor)
return prim.Variable(name) return prim.Variable(name)
@iname
def identity_iname(name, bound):
name = "id_{}_{}".format(name, bound)
domain(name, bound)
return name
def define_identity(name, expr, visitor):
i = identity_iname("i", expr.ufl_shape[0])
j = identity_iname("j", expr.ufl_shape[1])
instruction(assignee=prim.Subscript(prim.Variable(name), (prim.Variable(i), prim.Variable(j))),
expression=prim.If(prim.Comparison(prim.Variable(i),"==",prim.Variable(j)),1,0),
forced_iname_deps_is_final=True,
)
@kernel_cached
def pymbolic_identity(expr, visitor):
name = get_counted_variable("identity")
temporary_variable(name,
shape=expr.ufl_shape,
shape_impl=('fm',),
dtype=np.float64,
)
define_identity(name, expr, visitor)
return prim.Variable(name)
...@@ -12,7 +12,7 @@ from dune.perftool.ufl.transformations.reindexing import reindexing ...@@ -12,7 +12,7 @@ from dune.perftool.ufl.transformations.reindexing import reindexing
from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument
from dune.perftool.pdelab.restriction import Restriction from dune.perftool.pdelab.restriction import Restriction
from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex, Product from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex, Product, IndexSum
from ufl.core.multiindex import indices from ufl.core.multiindex import indices
from pytools import Record from pytools import Record
...@@ -107,6 +107,7 @@ def split_into_accumulation_terms(expr): ...@@ -107,6 +107,7 @@ def split_into_accumulation_terms(expr):
replacement = {} replacement = {}
indexmap = {} indexmap = {}
newi = None newi = None
backmap = {}
# Get all appearances of test functions with their indices # Get all appearances of test functions with their indices
indexed_test_args = extract_modified_arguments(replace_expr, argnumber=0, do_index=True) indexed_test_args = extract_modified_arguments(replace_expr, argnumber=0, do_index=True)
for indexed_test_arg in indexed_test_args: for indexed_test_arg in indexed_test_args:
...@@ -115,20 +116,53 @@ def split_into_accumulation_terms(expr): ...@@ -115,20 +116,53 @@ def split_into_accumulation_terms(expr):
# -> (m,n) in the example above # -> (m,n) in the example above
if newi is None: if newi is None:
newi = indices(len(indexed_test_arg.index)) newi = indices(len(indexed_test_arg.index))
# Replace indexed test function with a product of identities.
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) # This handles the special case with two identical
for i, j in zip(newi, indexed_test_arg.index._indices)) # indices on an test function. E.g. in Stokes on an
replacement.update({indexed_test_arg.expr: # axiparallel grid you get a term:
construct_binary_operator(identities, Product)}) #
indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)}) # -(\sum_i K_{i,i} (\nabla v)_{i,i}) w
indexed_test_arg = analyse_modified_argument(reindexing(indexed_test_arg.expr, # = \sum_k \sum_l (-K_{k,k} w I_{k,l} (\nabla v)_{k,l})
replacemap=indexmap)) #
# and we want to split
#
# -K_{k,k} w I_{k,l} corresponding to (\nabla v)_{k,l}.
#
# This is done by:
# - Replacing (\nabla v)_{i,i} with I_{k,i}*(\nabla
# v)_{k,l}. Here (\nabla v)_{k,l} serves as a
# placeholder and will be replaced later on.
# - Propagating the identity in step 4.
# - Replacing (\nabla v)_{k,l} by I_{k,l} after step 4.
if len(set(indexed_test_arg.index._indices)) < len(indexed_test_arg.index._indices):
if len(indexed_test_arg.index._indices)>2:
raise NotImplementedError("Test argument with more than three indices and double occurence ist not implemted.")
mod_index_map = {indexed_test_arg.index: MultiIndex((newi[0], newi[1]))}
mod_indexed_test_arg = replace_expression(indexed_test_arg.expr,
replacemap = mod_index_map)
rep = Product(Indexed(Identity(2),
MultiIndex((newi[0],indexed_test_arg.index[0]))),
mod_indexed_test_arg)
backmap.update({mod_indexed_test_arg:
Indexed(Identity(2), MultiIndex((newi[0],newi[1])))})
replacement.update({indexed_test_arg.expr: rep})
indexmap.update({indexed_test_arg.index[0]: newi[0]})
else:
# Replace indexed test function with a product of identities.
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
for i, j in zip(newi, indexed_test_arg.index._indices))
replacement.update({indexed_test_arg.expr:
construct_binary_operator(identities, Product)})
indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
else: else:
replacement.update({indexed_test_arg.expr: IntValue(1)}) replacement.update({indexed_test_arg.expr: IntValue(1)})
replace_expr = replace_expression(replace_expr, replacemap=replacement) replace_expr = replace_expression(replace_expr, replacemap=replacement)
# 4) Collapse any identity nodes that may have been introduced by replacing vectors # 4) Collapse any identity nodes that may have been introduced
# by replacing vectors and maybe replace placeholder from last step
replace_expr = identity_propagation(replace_expr) replace_expr = identity_propagation(replace_expr)
replace_expr = replace_expression(replace_expr, replacemap=backmap)
# 5) Further split according to trial function in jacobian terms # 5) Further split according to trial function in jacobian terms
# #
......
...@@ -234,6 +234,9 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -234,6 +234,9 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# Those handlers would be valid in any code going from UFL to pymbolic # Those handlers would be valid in any code going from UFL to pymbolic
# #
def identity(self, o):
return self.interface.pymbolic_identity(o, self)
def product(self, o): def product(self, o):
return Product(tuple(self.call(op) for op in o.ufl_operands)) return Product(tuple(self.call(op) for op in o.ufl_operands))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment