Skip to content
Snippets Groups Projects
Commit db2b3524 authored by René Heß's avatar René Heß
Browse files

Simplify dimension index handling

parent 6ac72292
No related branches found
No related tags found
No related merge requests found
......@@ -380,10 +380,6 @@ def visit_integrals(integrals):
from dune.perftool.ufl.transformations.axiparallel import diagonal_jacobian
integrand = diagonal_jacobian(integrand)
# Gather dimension indices
from dune.perftool.ufl.dimensionindex import dimension_index_mapping
dimension_indices = dimension_index_mapping(integrand)
# Generate code for the LFS trees present in the form
from dune.perftool.ufl.modified_terminals import extract_modified_arguments
test_ma = extract_modified_arguments(integrand, argnumber=0)
......@@ -404,12 +400,11 @@ def visit_integrals(integrals):
# Iterate over the terms and generate a kernel
for accterm in accterms:
# Adjust the index map for the visitor
from copy import deepcopy
indexmap = deepcopy(dimension_indices)
for i, j in accterm.indexmap.items():
if i in indexmap:
indexmap[j] = indexmap[i]
# Get dimension indices
from dune.perftool.ufl.dimensionindex import dimension_index_mapping
indexmap = dimension_index_mapping(accterm.test_arg())
# For jacobian there can also be dimension indices in the expression
indexmap.update(dimension_index_mapping(accterm.term))
# Get a transformer instance for this kernel
if get_option('sumfact'):
......
......@@ -22,17 +22,21 @@ class AccumulationTerm(Record):
def __init__(self,
term,
argument,
indexmap={},
new_indices=None
):
assert isinstance(argument, ModifiedArgument)
Record.__init__(self,
term=term,
argument=argument,
indexmap=indexmap,
new_indices=new_indices
)
def test_arg(self):
if self.new_indices is None:
return self.argument.expr
else:
return Indexed(self.argument.expr, MultiIndex(self.new_indices))
@ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l])
def split_into_accumulation_terms(expr):
......@@ -153,7 +157,6 @@ def cut_accumulation_term(expr):
#
# In step 2 this will collaps to: a_{m,n} + b_{m,n}
replacement = {}
indexmap = {}
newi = None
backmap = {}
# Get all appearances of test functions with their indices
......@@ -194,7 +197,6 @@ def cut_accumulation_term(expr):
backmap.update({mod_indexed_test_arg:
Indexed(Identity(2), MultiIndex((newi[0], newi[1])))})
replacement.update({indexed_test_arg.expr: rep})
indexmap.update({indexed_test_arg.index[0]: newi[0]})
else:
# Replace indexed test function with a product of identities.
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
......@@ -202,7 +204,6 @@ def cut_accumulation_term(expr):
replacement.update({indexed_test_arg.expr:
construct_binary_operator(identities, Product)})
indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
else:
replacement.update({indexed_test_arg.expr: IntValue(1)})
expr = replace_expression(expr, replacemap=replacement)
......@@ -212,4 +213,4 @@ def cut_accumulation_term(expr):
expr = identity_propagation(expr)
expr = replace_expression(expr, replacemap=backmap)
return AccumulationTerm(expr, test_arg, indexmap, newi)
return AccumulationTerm(expr, test_arg, newi)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment