diff --git a/python/dune/perftool/ufl/preprocess.py b/python/dune/perftool/ufl/preprocess.py index df756bf742628e46d76cb69de431a383463219c0..0ee27f8a75010341103fdcd18342993ee0a0502a 100644 --- a/python/dune/perftool/ufl/preprocess.py +++ b/python/dune/perftool/ufl/preprocess.py @@ -25,13 +25,13 @@ def preprocess_form(form): def apply_default_transformations(form): + # + # This is the list of transformations we unconditionally apply to + # all forms we want to generate code for. + # from dune.perftool.ufl.transformations import transform_form from dune.perftool.ufl.transformations.indexpushdown import pushdown_indexed - from dune.perftool.ufl.transformations.reindexing import reindexing - from dune.perftool.ufl.transformations.unroll import unroll_dimension_loops -# form = transform_form(form, unroll_dimension_loops) form = transform_form(form, pushdown_indexed) -# form = transform_form(form, reindexing) return form diff --git a/python/dune/perftool/ufl/transformations/indexpushdown.py b/python/dune/perftool/ufl/transformations/indexpushdown.py index 1f1e0b64307b72b916423089ca0ca2d6871c6957..f70a6b53a1c0ede151d89f647607b3548702313b 100644 --- a/python/dune/perftool/ufl/transformations/indexpushdown.py +++ b/python/dune/perftool/ufl/transformations/indexpushdown.py @@ -1,9 +1,10 @@ from __future__ import absolute_import from ufl.algorithms import MultiFunction -from ufl.classes import Sum, Indexed from dune.perftool.ufl.flatoperators import get_operands, construct_binary_operator from dune.perftool.ufl.transformations import ufl_transformation +import ufl.classes as uc + class IndexPushDown(MultiFunction): def expr(self, o): @@ -11,9 +12,9 @@ class IndexPushDown(MultiFunction): def indexed(self, o): expr, idx = o.ufl_operands - if isinstance(expr, Sum): - terms = [Indexed(term, idx) for term in get_operands(expr)] - return construct_binary_operator(terms, Sum) + if isinstance(expr, uc.Sum): + terms = [uc.Indexed(term, idx) for term in get_operands(expr)] + return construct_binary_operator(terms, uc.Sum) else: # This is a normal indexed, we treat it as any other. return self.expr(o) @@ -21,4 +22,10 @@ class IndexPushDown(MultiFunction): @ufl_transformation(name="index_pushdown") def pushdown_indexed(e): + """ + Removes the following antipattern from UFL expressions: + (a+b)[i] -> a[i] + b[i] + If similar antipatterns arise with a node other than sum, + add the corresponding handlers here. + """ return IndexPushDown()(e)