From ca0c381f2ac55a304a2b0ae61c55f91b4b4eb3d3 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 25 Jul 2017 11:31:20 +0200 Subject: [PATCH] Cleanup the indexpushdown transformation --- python/dune/perftool/ufl/preprocess.py | 8 ++++---- .../perftool/ufl/transformations/indexpushdown.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/python/dune/perftool/ufl/preprocess.py b/python/dune/perftool/ufl/preprocess.py index df756bf7..0ee27f8a 100644 --- a/python/dune/perftool/ufl/preprocess.py +++ b/python/dune/perftool/ufl/preprocess.py @@ -25,13 +25,13 @@ def preprocess_form(form): def apply_default_transformations(form): + # + # This is the list of transformations we unconditionally apply to + # all forms we want to generate code for. + # from dune.perftool.ufl.transformations import transform_form from dune.perftool.ufl.transformations.indexpushdown import pushdown_indexed - from dune.perftool.ufl.transformations.reindexing import reindexing - from dune.perftool.ufl.transformations.unroll import unroll_dimension_loops -# form = transform_form(form, unroll_dimension_loops) form = transform_form(form, pushdown_indexed) -# form = transform_form(form, reindexing) return form diff --git a/python/dune/perftool/ufl/transformations/indexpushdown.py b/python/dune/perftool/ufl/transformations/indexpushdown.py index 1f1e0b64..f70a6b53 100644 --- a/python/dune/perftool/ufl/transformations/indexpushdown.py +++ b/python/dune/perftool/ufl/transformations/indexpushdown.py @@ -1,9 +1,10 @@ from __future__ import absolute_import from ufl.algorithms import MultiFunction -from ufl.classes import Sum, Indexed from dune.perftool.ufl.flatoperators import get_operands, construct_binary_operator from dune.perftool.ufl.transformations import ufl_transformation +import ufl.classes as uc + class IndexPushDown(MultiFunction): def expr(self, o): @@ -11,9 +12,9 @@ class IndexPushDown(MultiFunction): def indexed(self, o): expr, idx = o.ufl_operands - if isinstance(expr, Sum): - terms = [Indexed(term, idx) for term in get_operands(expr)] - return construct_binary_operator(terms, Sum) + if isinstance(expr, uc.Sum): + terms = [uc.Indexed(term, idx) for term in get_operands(expr)] + return construct_binary_operator(terms, uc.Sum) else: # This is a normal indexed, we treat it as any other. return self.expr(o) @@ -21,4 +22,10 @@ class IndexPushDown(MultiFunction): @ufl_transformation(name="index_pushdown") def pushdown_indexed(e): + """ + Removes the following antipattern from UFL expressions: + (a+b)[i] -> a[i] + b[i] + If similar antipatterns arise with a node other than sum, + add the corresponding handlers here. + """ return IndexPushDown()(e) -- GitLab