diff --git a/python/dune/perftool/ufl/preprocess.py b/python/dune/perftool/ufl/preprocess.py index 24d436c9e9dca249c3929f32e704ea2ed72eb804..19ca10359de05dec5154dc194d18a9233645664b 100644 --- a/python/dune/perftool/ufl/preprocess.py +++ b/python/dune/perftool/ufl/preprocess.py @@ -1,10 +1,26 @@ """ Preprocessing algorithms for UFL forms """ import ufl.classes as uc +import ufl.algorithms.apply_function_pullbacks as afp from pytools import memoize +class FunctionPullbackApplier(afp.FunctionPullbackApplier): + def argument(self, o): + return afp.apply_single_function_pullbacks(o) + + def coefficient(self, o): + if o.count() in (0, 1): + return afp.apply_single_function_pullbacks(o) + else: + return o + + +# Monkey patch the pullback applier from UFL +afp.FunctionPullbackApplier = FunctionPullbackApplier + + @memoize def preprocess_form(form): from ufl.algorithms import compute_form_data diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index 687afd5c5d95538089d030021aade9c9b5f707af..071102aea95bb5bb83ed44304a17e57e5ffcec73 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -164,7 +164,10 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): # and exports it through a getter method 'getTime' return prim.Call(prim.Variable("getTime"), ()) else: - return self.interface.pymbolic_gridfunction(o, restriction, self.reference_grad) + if self.reference_grad: + raise PerftoolUFLError("Coefficient gradients should not be transformed to reference element") + + return self.interface.pymbolic_gridfunction(o, restriction, self.grad) # # Handlers for all indexing related stuff