diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py
index 5ae88fa6b5ff5adc61c8ef2013b3df6bdca94276..9a6e6ae44d1013096af81bfb9246816c44472e2a 100644
--- a/python/dune/perftool/pdelab/__init__.py
+++ b/python/dune/perftool/pdelab/__init__.py
@@ -19,8 +19,8 @@ from dune.perftool.pdelab.geometry import (component_iname,
                                            pymbolic_facet_jacobian_determinant,
                                            pymbolic_jacobian_determinant,
                                            pymbolic_jacobian_inverse_transposed,
-                                           name_unit_inner_normal,
-                                           name_unit_outer_normal,
+                                           pymbolic_unit_inner_normal,
+                                           pymbolic_unit_outer_normal,
                                            to_global,
                                            )
 from dune.perftool.pdelab.index import (name_index,
@@ -137,11 +137,11 @@ class PDELabInterface(object):
     def pymbolic_jacobian_inverse_transposed(self, i, j, restriction):
         return pymbolic_jacobian_inverse_transposed(i, j, restriction)
 
-    def name_unit_inner_normal(self):
-        return name_unit_inner_normal()
+    def pymbolic_unit_inner_normal(self):
+        return pymbolic_unit_inner_normal()
 
-    def name_unit_outer_normal(self):
-        return name_unit_outer_normal()
+    def pymbolic_unit_outer_normal(self):
+        return pymbolic_unit_outer_normal()
 
     def name_cell_volume(self, restriction):
         return name_cell_volume(restriction)
diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py
index 7808f231286d22435616f14644dbdec7eaac53a1..9690985311ebc3171e3545f5e20da4a5a9684d5e 100644
--- a/python/dune/perftool/pdelab/geometry.py
+++ b/python/dune/perftool/pdelab/geometry.py
@@ -262,7 +262,7 @@ def declare_normal(name, shape, shape_impl):
     return "auto {} = {}.centerUnitOuterNormal();".format(name, ig)
 
 
-def name_unit_outer_normal():
+def pymbolic_unit_outer_normal():
     name = "outer_normal"
     if not get_option("diagonal_transformation_matrix"):
         temporary_variable(name, shape=(world_dimension(),), decl_method=declare_normal)
@@ -270,7 +270,7 @@ def name_unit_outer_normal():
     else:
         declare_normal(name, None, None)
         globalarg(name, shape=(world_dimension(),), dtype=np.float64)
-    return "outer_normal"
+    return prim.Variable(name)
 
 
 def evaluate_unit_inner_normal(name):
@@ -281,11 +281,11 @@ def evaluate_unit_inner_normal(name):
                                )
 
 
-def name_unit_inner_normal():
+def pymbolic_unit_inner_normal():
     name = "inner_normal"
     temporary_variable(name, shape=(world_dimension(),), decl_method=declare_normal)
     evaluate_unit_inner_normal(name)
-    return "inner_normal"
+    return prim.Variable(name)
 
 
 def type_jacobian_inverse_transposed(restriction):
diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py
index 123c4e1e8b81e76f7eae6328fbff70667cd94853..df246dee44aa60c12827f3ec272b70b6beea7e65 100644
--- a/python/dune/perftool/sumfact/__init__.py
+++ b/python/dune/perftool/sumfact/__init__.py
@@ -76,3 +76,15 @@ class SumFactInterface(PDELabInterface):
         ret, indices = get_backend(interface="spatial_coordinate", selector=option_switch("diagonal_transformation_matrix"))(self.visitor.indices, self.visitor.do_predicates, self.visitor)
         self.visitor.indices = indices
         return ret
+
+    def pymbolic_unit_outer_normal(self):
+        from dune.perftool.sumfact.geometry import pymbolic_unit_outer_normal
+        ret, indices = pymbolic_unit_outer_normal(self.visitor.indices)
+        self.visitor.indices = indices
+        return ret
+
+    def pymbolic_unit_inner_normal(self):
+        from dune.perftool.sumfact.geometry import pymbolic_unit_inner_normal
+        ret, indices = pymbolic_unit_inner_normal(self.visitor.indices)
+        self.visitor.indices = indices
+        return ret
diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py
index c698a2be28f9066450d0b24bbd52b576e8227e02..94e7a55fa773d39b5db4251a483eee8a81a419f8 100644
--- a/python/dune/perftool/sumfact/geometry.py
+++ b/python/dune/perftool/sumfact/geometry.py
@@ -18,7 +18,8 @@ from dune.perftool.pdelab.geometry import (local_dimension,
                                            )
 from dune.perftool.sumfact.symbolic import SumfactKernelInputBase
 from dune.perftool.sumfact.vectorization import attach_vectorization_info
-from dune.perftool.options import option_switch
+from dune.perftool.options import get_option, option_switch
+from dune.perftool.ufl.modified_terminals import Restriction
 
 from pytools import ImmutableRecord
 
@@ -159,3 +160,37 @@ def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates, visi
         x = pymbolic_quadrature_position(iindex, visitor)
 
     return prim.Subscript(prim.Variable(lowcorner), (index,)) + x * prim.Subscript(prim.Variable(meshwidth), (index,)), None
+
+
+def pymbolic_unit_outer_normal(visitor_indices):
+    index, = visitor_indices
+    assert isinstance(index, int)
+    if get_option("diagonal_transformation_matrix"):
+        from dune.perftool.sumfact.switch import get_facedir, get_facemod
+        if index == get_facedir(Restriction.POSITIVE):
+            if get_facemod(Restriction.POSITIVE):
+                return 1, None
+            else:
+                return -1, None
+        else:
+            return 0, None
+    else:
+        from dune.perftool.pdelab.geometry import pymbolic_unit_outer_normal as _norm
+        return _norm(), visitor_indices
+
+
+def pymbolic_unit_outer_normal(visitor_indices):
+    index, = visitor_indices
+    assert isinstance(index, int)
+    if get_option("diagonal_transformation_matrix"):
+        from dune.perftool.sumfact.switch import get_facedir, get_facemod
+        if index == get_facedir(Restriction.NEGATIVE):
+            if get_facemod(Restriction.NEGATIVE):
+                return -1, None
+            else:
+                return 1, None
+        else:
+            return 0, None
+    else:
+        from dune.perftool.pdelab.geometry import pymbolic_unit_inner_normal as _norm
+        return _norm(), visitor_indices
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index 2ac22c3e4076fdb9b9bcb0bb7357c48318ebb6a5..e4983773ded3655989b72b74913b05282392780f 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -389,24 +389,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         # The normal must be restricted to be well-defined
         assert self.restriction is not Restriction.NONE
 
-        # Optimize facet normal on axiparallel grids
-        # TODO move this into the sumfact backend, it is only valid there
-        from dune.perftool.options import get_option
-        if get_option("diagonal_transformation_matrix") and get_option("sumfact"):
-            index, = self.indices
-            from dune.perftool.sumfact.switch import get_facedir
-            if isinstance(index, int) and index != get_facedir(self.restriction):
-                self.indices = None
-                return 0
-
         if self.restriction == Restriction.POSITIVE:
-            return Variable(self.interface.name_unit_outer_normal())
+            return self.interface.pymbolic_unit_outer_normal()
         if self.restriction == Restriction.NEGATIVE:
             # It is highly unnatural to have this generator function,
             # but I do run into subtle trouble with return -1*outer
             # as the indexing into the normal happens only later.
             # Not investing more time into this cornercase right now.
-            return Variable(self.interface.name_unit_inner_normal())
+            return self.interface.pymbolic_unit_inner_normal()
 
     def quadrature_weight(self, o):
         return self.interface.pymbolic_quadrature_weight()