From 24198f26336e388501cffbaf3367c05b74221aae Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 26 Oct 2016 11:31:59 +0200
Subject: [PATCH] Introduce an interface between visitor and generator
 functions

---
 python/dune/perftool/pdelab/__init__.py      | 111 ++++++++++++++--
 python/dune/perftool/pdelab/localoperator.py |   4 +-
 python/dune/perftool/ufl/visitor.py          | 127 +++++++------------
 3 files changed, 151 insertions(+), 91 deletions(-)

diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py
index d8d6ab1d..f994193a 100644
--- a/python/dune/perftool/pdelab/__init__.py
+++ b/python/dune/perftool/pdelab/__init__.py
@@ -2,12 +2,105 @@
 
 # Trigger some imports that are needed to have all backend implementations visible
 # to the selection mechanisms
-import dune.perftool.pdelab.argument
-import dune.perftool.pdelab.basis
-import dune.perftool.pdelab.driver
-import dune.perftool.pdelab.geometry
-import dune.perftool.pdelab.localoperator
-import dune.perftool.pdelab.parameter
-import dune.perftool.pdelab.quadrature
-import dune.perftool.pdelab.signatures
-import dune.perftool.pdelab.spaces
\ No newline at end of file
+from dune.perftool.pdelab.argument import (name_apply_function,
+                                           name_apply_function_gradient,
+                                           name_trialfunction,
+                                           name_trialfunction_gradient,
+                                           )
+from dune.perftool.pdelab.basis import (name_basis,
+                                        name_reference_gradient,
+                                        )
+from dune.perftool.pdelab.geometry import (dimension_iname,
+                                           name_facet_jacobian_determinant,
+                                           name_jacobian_determinant,
+                                           name_jacobian_inverse_transposed,
+                                           name_unit_inner_normal,
+                                           name_unit_outer_normal,
+                                           )
+from dune.perftool.pdelab.index import (name_index,
+                                        )
+from dune.perftool.pdelab.parameter import (cell_parameter_function,
+                                            intersection_parameter_function,
+                                            )
+from dune.perftool.pdelab.quadrature import (name_quadrature_weight,
+                                             )
+from dune.perftool.pdelab.spaces import (lfs_iname,
+                                         )
+
+
+class PDELabInterface(object):
+    #
+    # TODO: The following ones are actually entirely PDELab independent!
+    # They should be placed elsewhere and be used directly in the visitor.
+    #
+
+    def dimension_iname(self, context=None, count=None):
+        return dimension_iname(context=context, count=count)
+
+    def name_index(self, ind):
+        return name_index(ind)
+
+    #
+    # Local function space related generator functions
+    #
+
+    def lfs_iname(self, element, restriction, number):
+        return lfs_iname(element, restriction, number)
+
+    #
+    # Test and trial function related generator functions
+    #
+
+    def name_basis(self, element, restriction):
+        return name_basis(element, restriction)
+
+    def name_reference_gradient(self, element, restriction):
+        return name_reference_gradient(element, restriction)
+
+    def name_trialfunction_gradient(self, element, restriction, component):
+        return name_trialfunction_gradient(element, restriction, component)
+
+    def name_apply_function_gradient(self, element, restriction, component):
+        return name_apply_function_gradient(element, restriction, component)
+
+    def name_trialfunction(self, element, restriction, component):
+        return name_trialfunction(element, restriction, component)
+
+    def name_apply_function(self, element, restriction, component):
+        return name_apply_function(element, restriction, component)
+
+    #
+    # Parameter function related generator functions
+    #
+
+    def intersection_parameter_function(self, name, expr, cellwise_constant):
+        return intersection_parameter_function(name, expr, cellwise_constant)
+
+    def cell_parameter_function(self, name, expr, restriction, cellwise_constant):
+        return cell_parameter_function(name, expr, restriction, cellwise_constant)
+
+    #
+    # Geometry related generator functions
+    #
+
+    def name_facet_jacobian_determinant(self):
+        return name_facet_jacobian_determinant()
+
+    def name_jacobian_determinant(self):
+        return name_jacobian_determinant()
+
+    def name_jacobian_inverse_transposed(self, restriction):
+        return name_jacobian_inverse_transposed(restriction)
+
+    def name_unit_inner_normal(self):
+        return name_unit_inner_normal()
+
+    def name_unit_outer_normal(self):
+        return name_unit_outer_normal()
+
+    #
+    # Quadrature related generator functions
+    #
+
+    def name_quadrature_weight(self):
+        return name_quadrature_weight()
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index f456b8d8..5f9ac0d9 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -459,8 +459,10 @@ def generate_kernel(integrals):
                     indexmap[j] = indexmap[i]
 
             # Get a transformer instance for this kernel
+            from dune.perftool.pdelab import PDELabInterface
+            interface = PDELabInterface()
             from dune.perftool.ufl.visitor import UFL2LoopyVisitor
-            visitor = UFL2LoopyVisitor(measure, indexmap)
+            visitor = UFL2LoopyVisitor(interface, measure, indexmap)
             generate_accumulation_instruction(visitor, term, measure, subdomain_id)
 
     # Extract the information, which is needed to create a loopy kernel.
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index ba3d98c5..898a7b46 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -3,31 +3,40 @@ This module defines the main visitor algorithm transforming ufl expressions
 to pymbolic and loopy.
 """
 
-from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker, Restriction
 from dune.perftool.generation import (domain,
-                                      get_temporary_name,
-                                      global_context,
-                                      globalarg,
-                                      iname,
-                                      instruction,
-                                      temporary_variable,
-                                      valuearg,
+                                      get_global_context_value,
                                       )
+from dune.perftool.ufl.flatoperators import get_operands
+from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker,
+                                                  Restriction,
+                                                  )
+from dune.perftool.ufl.execution import Expression
+
+from loopy import Reduction
+
+from pymbolic.primitives import (Call,
+                                 Product,
+                                 Quotient,
+                                 Subscript,
+                                 Sum,
+                                 Variable,
+                                 )
 
-from dune.perftool.pdelab.spaces import (lfs_iname,
-                                         name_leaf_lfs,
-                                         name_lfs,
-                                         name_lfs_bound,
-                                         traverse_lfs_tree,
-                                         )
-from dune.perftool.pdelab.quadrature import quadrature_iname
-from pymbolic.primitives import Subscript, Variable
 from ufl.algorithms import MultiFunction
+from ufl.checks import is_cellwise_constant
+from ufl.functionview import select_subelement
+from ufl import (VectorElement,
+                 TensorElement,
+                 )
+from ufl.classes import (FixedIndex,
+                         IndexSum,
+                         JacobianDeterminant,
+                         )
 
 
 class UFL2LoopyVisitor(ModifiedTerminalTracker):
-    def __init__(self, measure, dimension_indices):
-        # Some variables describing the integral measure of this integral
+    def __init__(self, interface, measure, dimension_indices):
+        self.interface = interface
         self.measure = measure
         self.dimension_indices = dimension_indices
 
@@ -57,14 +66,12 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
             restriction = Restriction.NEGATIVE
 
         # Select the correct subtree of the finite element
-        from ufl.functionview import select_subelement
         element = select_subelement(o.ufl_element(), self.component)
         leaf_element = element
 
         # Now treat the case of this being a vector finite element
         if element.num_sub_elements() > 0:
             # I cannot handle general mixed elements here...
-            from ufl import VectorElement, TensorElement
             assert isinstance(element, (VectorElement, TensorElement))
 
             # Determine whether this is a non-scalar subargument. This information is later
@@ -72,9 +79,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
             self.argshape = len(element.value_shape())
 
             # If this is a vector element, we need add an additional accumulation loop iname
-            from dune.perftool.pdelab.geometry import dimension_iname
             for i in range(self.argshape):
-                self.inames.append(dimension_iname(context='arg', count=i))
+                self.inames.append(self.interface.dimension_iname(context='arg', count=i))
 
             # For the purpose of basis evaluation, we need to take the leaf element
             leaf_element = element.sub_elements()[0]
@@ -83,16 +89,13 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
             raise ValueError("Gradients should have been transformed to reference gradients!!!")
 
         # Have the issued instruction depend on the iname for this localfunction space
-        from dune.perftool.pdelab.spaces import lfs_iname
-        iname = lfs_iname(leaf_element, restriction, o.number())
+        iname = self.interface.lfs_iname(leaf_element, restriction, o.number())
         self.inames.append(iname)
 
         if self.reference_grad:
-            from dune.perftool.pdelab.basis import name_reference_gradient
-            return Subscript(Variable(name_reference_gradient(leaf_element, restriction)), (Variable(iname), 0))
+            return Subscript(Variable(self.interface.name_reference_gradient(leaf_element, restriction)), (Variable(iname), 0))
         else:
-            from dune.perftool.pdelab.basis import name_basis
-            return Subscript(Variable(name_basis(leaf_element, restriction)), (Variable(iname),))
+            return Subscript(Variable(self.interface.name_basis(leaf_element, restriction)), (Variable(iname),))
 
     def coefficient(self, o):
         # Do something different for trial function and coefficients from jacobian apply
@@ -107,40 +110,30 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
 
             if self.reference_grad:
                 if o.count() == 0:
-                    from dune.perftool.pdelab.argument import name_trialfunction_gradient
-                    return Variable(name_trialfunction_gradient(o.ufl_element(), restriction, self.component))
+                    return Variable(self.interface.name_trialfunction_gradient(o.ufl_element(), restriction, self.component))
                 else:
-                    from dune.perftool.pdelab.argument import name_apply_function_gradient
-                    return Variable(name_apply_function_gradient(o.ufl_element(), restriction, self.component))
+                    return Variable(self.interface.name_apply_function_gradient(o.ufl_element(), restriction, self.component))
             else:
                 if o.count() == 0:
-                    from dune.perftool.pdelab.argument import name_trialfunction
-                    return Variable(name_trialfunction(o.ufl_element(), restriction, self.component))
+                    return Variable(self.interface.name_trialfunction(o.ufl_element(), restriction, self.component))
                 else:
-                    from dune.perftool.pdelab.argument import name_apply_function
-                    return Variable(name_apply_function(o.ufl_element(), restriction, self.component))
+                    return Variable(self.interface.name_apply_function(o.ufl_element(), restriction, self.component))
 
         # Check if this is a parameter function
         else:
             # We expect all coefficients to be of type Expression!
-            from dune.perftool.ufl.execution import Expression
             assert isinstance(o, Expression)
 
             # Determine the name of the parameter function
-            from dune.perftool.generation import get_global_context_value
             name = get_global_context_value("data").object_names[id(o)]
 
-            from ufl.checks import is_cellwise_constant
             cellwise_constant = is_cellwise_constant(o)
 
             # Trigger the generation of code for this thing in the parameter class
-            from dune.perftool.pdelab.parameter import (cell_parameter_function,
-                                                        intersection_parameter_function,
-                                                        )
             if o.on_intersection:
-                intersection_parameter_function(name, o, cellwise_constant)
+                self.interface.intersection_parameter_function(name, o, cellwise_constant)
             else:
-                cell_parameter_function(name, o, self.restriction, cellwise_constant)
+                self.interface.cell_parameter_function(name, o, self.restriction, cellwise_constant)
 
             # And return a symbol
             return Variable(name)
@@ -184,39 +177,31 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         ind = o.ufl_operands[1][0]
         redinames = additional_inames + (ind,)
         shape = o.ufl_operands[0].ufl_index_dimensions[0]
-        from dune.perftool.pdelab.index import name_index
-        domain(name_index(ind), shape)
+        domain(self.interface.name_index(ind), shape)
 
         # If the left operand is an index sum to, we do it in one reduction
-        from ufl.classes import IndexSum
         if isinstance(o.ufl_operands[0], IndexSum):
             return self.index_sum(o.ufl_operands[0], additional_inames=redinames)
         else:
-            from loopy import Reduction
-
             # Recurse to get the summation expression
             term = self.call(o.ufl_operands[0])
             redinames = tuple(i for i in redinames if i not in self.dimension_indices)
             if len(redinames) > 0:
-                ret = Reduction("sum", tuple(name_index(ind) for ind in redinames), term)
+                ret = Reduction("sum", tuple(self.interface.name_index(ind) for ind in redinames), term)
             else:
                 ret = term
 
             return ret
 
     def _index_or_fixed_index(self, index):
-        from ufl.classes import FixedIndex
         if isinstance(index, FixedIndex):
             return index._value
         else:
-            from pymbolic.primitives import Variable
-            from dune.perftool.pdelab.index import name_index
             if index in self.dimension_indices:
-                from dune.perftool.pdelab.geometry import dimension_iname
                 self.inames.append(self.dimension_indices[index])
                 return Variable(self.dimension_indices[index])
             else:
-                return Variable(name_index(index))
+                return Variable(self.interface.name_index(index))
 
     def multi_index(self, o):
         return tuple(self._index_or_fixed_index(i) for i in o)
@@ -230,8 +215,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
     #
 
     def product(self, o):
-        from dune.perftool.ufl.flatoperators import get_operands
-        from pymbolic.primitives import Product
         return Product(tuple(self.call(op) for op in get_operands(o)))
 
     def float_value(self, o):
@@ -241,25 +224,18 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         return o.value()
 
     def division(self, o):
-        assert len(o.ufl_operands) == 2
-
-        from pymbolic.primitives import Quotient
         return Quotient(self.call(o.ufl_operands[0]), self.call(o.ufl_operands[1]))
 
     def sum(self, o):
-        from dune.perftool.ufl.flatoperators import get_operands
-        from pymbolic.primitives import Sum
         return Sum(tuple(self.call(op) for op in get_operands(o)))
 
     def zero(self, o):
         return 0
 
     def abs(self, o):
-        from ufl.classes import JacobianDeterminant
         if isinstance(o.ufl_operands[0], JacobianDeterminant):
             return self.call(o.ufl_operands[0])
         else:
-            from pymbolic.primitives import Call
             return Call('abs', self.call(o.ufl_operands[0]))
 
     #
@@ -270,29 +246,20 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         # The normal must be restricted to be well-defined
         assert self.restriction is not Restriction.NONE
 
-        from pymbolic.primitives import Variable
         if self.restriction == Restriction.POSITIVE:
-            from dune.perftool.pdelab.geometry import name_unit_outer_normal
-            return Variable(name_unit_outer_normal())
+            return Variable(self.interface.name_unit_outer_normal())
         if self.restriction == Restriction.NEGATIVE:
             # It is highly unnatural to have this generator function,
             # but I do run into subtle trouble with return -1*outer
             # as the indexing into the normal happens only later.
             # Not investing more time into this cornercase right now.
-            from dune.perftool.pdelab.geometry import name_unit_inner_normal
-            return Variable(name_unit_inner_normal())
-
-    def facet_area(self, o):
-        from dune.perftool.pdelab.geometry import name_facetarea
-        return Variable(name_facetarea())
+            return Variable(self.interface.name_unit_inner_normal())
 
     def quadrature_weight(self, o):
-        from dune.perftool.pdelab.quadrature import name_quadrature_weight
-        return Variable(name_quadrature_weight())
+        return Variable(self.interface.name_quadrature_weight())
 
     def jacobian_determinant(self, o):
-        from dune.perftool.pdelab.geometry import name_jacobian_determinant
-        return Variable(name_jacobian_determinant())
+        return Variable(self.interface.name_jacobian_determinant())
 
     def jacobian_inverse(self, o):
         restriction = self.restriction
@@ -300,12 +267,10 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
             restriction = Restriction.NEGATIVE
 
         self.transpose_necessary = True
-        from dune.perftool.pdelab.geometry import name_jacobian_inverse_transposed
-        return Variable(name_jacobian_inverse_transposed(restriction))
+        return Variable(self.interface.name_jacobian_inverse_transposed(restriction))
 
     def jacobian(self, o):
         raise NotImplementedError("How did you get Jacobian into your form? We only support JacobianInverse right now. Report!")
 
     def facet_jacobian_determinant(self, o):
-        from dune.perftool.pdelab.geometry import name_facet_jacobian_determinant
-        return Variable(name_facet_jacobian_determinant())
+        return Variable(self.interface.name_facet_jacobian_determinant())
-- 
GitLab