Skip to content
Snippets Groups Projects
Commit ca0c381f authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Cleanup the indexpushdown transformation

parent 0f957482
No related branches found
No related tags found
No related merge requests found
...@@ -25,13 +25,13 @@ def preprocess_form(form): ...@@ -25,13 +25,13 @@ def preprocess_form(form):
def apply_default_transformations(form): def apply_default_transformations(form):
#
# This is the list of transformations we unconditionally apply to
# all forms we want to generate code for.
#
from dune.perftool.ufl.transformations import transform_form from dune.perftool.ufl.transformations import transform_form
from dune.perftool.ufl.transformations.indexpushdown import pushdown_indexed from dune.perftool.ufl.transformations.indexpushdown import pushdown_indexed
from dune.perftool.ufl.transformations.reindexing import reindexing
from dune.perftool.ufl.transformations.unroll import unroll_dimension_loops
# form = transform_form(form, unroll_dimension_loops)
form = transform_form(form, pushdown_indexed) form = transform_form(form, pushdown_indexed)
# form = transform_form(form, reindexing)
return form return form
from __future__ import absolute_import from __future__ import absolute_import
from ufl.algorithms import MultiFunction from ufl.algorithms import MultiFunction
from ufl.classes import Sum, Indexed
from dune.perftool.ufl.flatoperators import get_operands, construct_binary_operator from dune.perftool.ufl.flatoperators import get_operands, construct_binary_operator
from dune.perftool.ufl.transformations import ufl_transformation from dune.perftool.ufl.transformations import ufl_transformation
import ufl.classes as uc
class IndexPushDown(MultiFunction): class IndexPushDown(MultiFunction):
def expr(self, o): def expr(self, o):
...@@ -11,9 +12,9 @@ class IndexPushDown(MultiFunction): ...@@ -11,9 +12,9 @@ class IndexPushDown(MultiFunction):
def indexed(self, o): def indexed(self, o):
expr, idx = o.ufl_operands expr, idx = o.ufl_operands
if isinstance(expr, Sum): if isinstance(expr, uc.Sum):
terms = [Indexed(term, idx) for term in get_operands(expr)] terms = [uc.Indexed(term, idx) for term in get_operands(expr)]
return construct_binary_operator(terms, Sum) return construct_binary_operator(terms, uc.Sum)
else: else:
# This is a normal indexed, we treat it as any other. # This is a normal indexed, we treat it as any other.
return self.expr(o) return self.expr(o)
...@@ -21,4 +22,10 @@ class IndexPushDown(MultiFunction): ...@@ -21,4 +22,10 @@ class IndexPushDown(MultiFunction):
@ufl_transformation(name="index_pushdown") @ufl_transformation(name="index_pushdown")
def pushdown_indexed(e): def pushdown_indexed(e):
"""
Removes the following antipattern from UFL expressions:
(a+b)[i] -> a[i] + b[i]
If similar antipatterns arise with a node other than sum,
add the corresponding handlers here.
"""
return IndexPushDown()(e) return IndexPushDown()(e)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment