Skip to content
Snippets Groups Projects
Commit db2b3524 authored by René Heß's avatar René Heß
Browse files

Simplify dimension index handling

parent 6ac72292
No related branches found
No related tags found
No related merge requests found
...@@ -380,10 +380,6 @@ def visit_integrals(integrals): ...@@ -380,10 +380,6 @@ def visit_integrals(integrals):
from dune.perftool.ufl.transformations.axiparallel import diagonal_jacobian from dune.perftool.ufl.transformations.axiparallel import diagonal_jacobian
integrand = diagonal_jacobian(integrand) integrand = diagonal_jacobian(integrand)
# Gather dimension indices
from dune.perftool.ufl.dimensionindex import dimension_index_mapping
dimension_indices = dimension_index_mapping(integrand)
# Generate code for the LFS trees present in the form # Generate code for the LFS trees present in the form
from dune.perftool.ufl.modified_terminals import extract_modified_arguments from dune.perftool.ufl.modified_terminals import extract_modified_arguments
test_ma = extract_modified_arguments(integrand, argnumber=0) test_ma = extract_modified_arguments(integrand, argnumber=0)
...@@ -404,12 +400,11 @@ def visit_integrals(integrals): ...@@ -404,12 +400,11 @@ def visit_integrals(integrals):
# Iterate over the terms and generate a kernel # Iterate over the terms and generate a kernel
for accterm in accterms: for accterm in accterms:
# Adjust the index map for the visitor # Get dimension indices
from copy import deepcopy from dune.perftool.ufl.dimensionindex import dimension_index_mapping
indexmap = deepcopy(dimension_indices) indexmap = dimension_index_mapping(accterm.test_arg())
for i, j in accterm.indexmap.items(): # For jacobian there can also be dimension indices in the expression
if i in indexmap: indexmap.update(dimension_index_mapping(accterm.term))
indexmap[j] = indexmap[i]
# Get a transformer instance for this kernel # Get a transformer instance for this kernel
if get_option('sumfact'): if get_option('sumfact'):
......
...@@ -22,17 +22,21 @@ class AccumulationTerm(Record): ...@@ -22,17 +22,21 @@ class AccumulationTerm(Record):
def __init__(self, def __init__(self,
term, term,
argument, argument,
indexmap={},
new_indices=None new_indices=None
): ):
assert isinstance(argument, ModifiedArgument) assert isinstance(argument, ModifiedArgument)
Record.__init__(self, Record.__init__(self,
term=term, term=term,
argument=argument, argument=argument,
indexmap=indexmap,
new_indices=new_indices new_indices=new_indices
) )
def test_arg(self):
if self.new_indices is None:
return self.argument.expr
else:
return Indexed(self.argument.expr, MultiIndex(self.new_indices))
@ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l]) @ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l])
def split_into_accumulation_terms(expr): def split_into_accumulation_terms(expr):
...@@ -153,7 +157,6 @@ def cut_accumulation_term(expr): ...@@ -153,7 +157,6 @@ def cut_accumulation_term(expr):
# #
# In step 2 this will collaps to: a_{m,n} + b_{m,n} # In step 2 this will collaps to: a_{m,n} + b_{m,n}
replacement = {} replacement = {}
indexmap = {}
newi = None newi = None
backmap = {} backmap = {}
# Get all appearances of test functions with their indices # Get all appearances of test functions with their indices
...@@ -194,7 +197,6 @@ def cut_accumulation_term(expr): ...@@ -194,7 +197,6 @@ def cut_accumulation_term(expr):
backmap.update({mod_indexed_test_arg: backmap.update({mod_indexed_test_arg:
Indexed(Identity(2), MultiIndex((newi[0], newi[1])))}) Indexed(Identity(2), MultiIndex((newi[0], newi[1])))})
replacement.update({indexed_test_arg.expr: rep}) replacement.update({indexed_test_arg.expr: rep})
indexmap.update({indexed_test_arg.index[0]: newi[0]})
else: else:
# Replace indexed test function with a product of identities. # Replace indexed test function with a product of identities.
identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
...@@ -202,7 +204,6 @@ def cut_accumulation_term(expr): ...@@ -202,7 +204,6 @@ def cut_accumulation_term(expr):
replacement.update({indexed_test_arg.expr: replacement.update({indexed_test_arg.expr:
construct_binary_operator(identities, Product)}) construct_binary_operator(identities, Product)})
indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
else: else:
replacement.update({indexed_test_arg.expr: IntValue(1)}) replacement.update({indexed_test_arg.expr: IntValue(1)})
expr = replace_expression(expr, replacemap=replacement) expr = replace_expression(expr, replacemap=replacement)
...@@ -212,4 +213,4 @@ def cut_accumulation_term(expr): ...@@ -212,4 +213,4 @@ def cut_accumulation_term(expr):
expr = identity_propagation(expr) expr = identity_propagation(expr)
expr = replace_expression(expr, replacemap=backmap) expr = replace_expression(expr, replacemap=backmap)
return AccumulationTerm(expr, test_arg, indexmap, newi) return AccumulationTerm(expr, test_arg, newi)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment