From 08d79cb291efe8fad477f0d0b9c614e98663442c Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 27 Mar 2018 10:25:19 +0200
Subject: [PATCH] Unify treatment of input (stage 1) and output (stage 3)

I was tired of the amount of ifs that changed behaviour
depending on stage 1 or 3. This is a cleaner approach.
---
 python/dune/perftool/sumfact/accumulation.py |  41 +++-
 python/dune/perftool/sumfact/basis.py        |  32 ++-
 python/dune/perftool/sumfact/geometry.py     |   5 +-
 python/dune/perftool/sumfact/realization.py  |  32 +--
 python/dune/perftool/sumfact/symbolic.py     | 229 ++++++++++---------
 5 files changed, 187 insertions(+), 152 deletions(-)

diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index d576562e..255d0909 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -24,6 +24,7 @@ from dune.perftool.options import (get_form_option,
                                    )
 from dune.perftool.loopy.flatten import flatten_index
 from dune.perftool.sumfact.quadrature import nest_quadrature_loops
+from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.localoperator import determine_accumulation_space
 from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.pdelab.signatures import assembler_routine_name
@@ -35,7 +36,7 @@ from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
 from dune.perftool.sumfact.switch import (get_facedir,
                                           get_facemod,
                                           )
-from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelOutputBase
+from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase
 from dune.perftool.ufl.modified_terminals import extract_modified_arguments
 from dune.perftool.tools import get_pymbolic_basename, get_leaf
 from dune.perftool.error import PerftoolError
@@ -84,7 +85,7 @@ def accum_iname(element, bound, i):
     return sumfact_iname(bound, "accum{}".format(suffix))
 
 
-class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
+class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
     def __init__(self,
                  accumvar=None,
                  restriction=None,
@@ -106,6 +107,14 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
     def __repr__(self):
         return ImmutableRecord.__repr__(self)
 
+    @property
+    def stage(self):
+        return 3
+
+    @property
+    def direct_is_possible(self):
+        return get_form_option("fastdg")
+
     @property
     def within_inames(self):
         if self.trial_element is None:
@@ -202,7 +211,17 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
                                       **args)})
 
     @property
-    def fastdg_args(self):
+    def function_name_suffix(self):
+        if get_form_option("fastdg"):
+            suffix = "_fastdg1_{}comp{}".format(FEM_name_mangling(self.test_element), self.test_element_index)
+            if self.within_inames:
+                suffix = "{}x{}comp{}".format(suffix, FEM_name_mangling(self.trial_element), self.trial_element_index)
+            return suffix
+        else:
+            return ""
+
+    @property
+    def function_args(self):
         if get_form_option("fastdg"):
             ret = ("{}.data()".format(self.accumvar),)
             if get_form_option("fastdg") and self.within_inames:
@@ -214,6 +233,17 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
         else:
             return ()
 
+    @property
+    def signature_args(self):
+        if get_form_option('fastdg'):
+            ret = ("double* fastdg0",)
+            if self.within_inames:
+                ret = ret + ("unsigned int jacobian_offset0",)
+            return ret
+        else:
+            return ()
+
+
 
 def _local_sizes(element):
     from ufl import FiniteElement, MixedElement
@@ -411,9 +441,8 @@ def generate_accumulation_instruction(expr, visitor):
                                 )
 
     sf = SumfactKernel(matrix_sequence=matrix_sequence,
-                       stage=3,
                        position_priority=priority,
-                       output=output,
+                       interface=output,
                        predicates=predicates,
                        )
 
@@ -497,7 +526,7 @@ def generate_accumulation_instruction(expr, visitor):
     result, insn_dep = realize_sum_factorization_kernel(vsf.copy(insn_dep=vsf.insn_dep.union(insn_dep)))
 
     if not get_form_option("fastdg"):
-        insn_dep = vsf.output.realize(vsf, result, insn_dep)
+        insn_dep = vsf.interface.realize(vsf, result, insn_dep)
 
     if get_option("instrumentation_level") >= 4:
         assert vsf.stage == 3
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index ffb2eef9..3c48f419 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -32,7 +32,7 @@ from dune.perftool.pdelab.argument import name_coefficientcontainer
 from dune.perftool.pdelab.geometry import (local_dimension,
                                            world_dimension,
                                            )
-from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase
+from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase
 from dune.perftool.options import get_form_option
 from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.restriction import restricted_name
@@ -50,7 +50,7 @@ from loopy.match import Writes
 import pymbolic.primitives as prim
 
 
-class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
+class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
     def __init__(self,
                  coeff_func=None,
                  element=None,
@@ -71,7 +71,11 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
         return repr(self)
 
     @property
-    def direct_input_is_possible(self):
+    def stage(self):
+        return 1
+
+    @property
+    def direct_is_possible(self):
         return get_form_option("fastdg")
 
     def realize(self, sf, insn_dep, index=0):
@@ -113,13 +117,27 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
         return prim.Subscript(prim.Variable(arg), inames)
 
     @property
-    def fastdg_args(self):
-        if self.direct_input_is_possible:
+    def function_name_suffix(self):
+        if get_form_option("fastdg"):
+            return "_fastdg1_{}comp{}".format(FEM_name_mangling(self.element), self.element_index)
+        else:
+            return ""
+
+    @property
+    def function_args(self):
+        if get_form_option("fastdg"):
             func = self.coeff_func(self.restriction)
             return ("{}.data()".format(func),)
         else:
             return ()
 
+    @property
+    def signature_args(self):
+        if get_form_option("fastdg"):
+            return ("const double* fastdg0",)
+        else:
+            return ()
+
 
 def _basis_functions_per_direction(element):
     """Number of basis functions per direction """
@@ -166,7 +184,7 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
     # The sum factorization kernel object gathering all relevant information
     sf = SumfactKernel(matrix_sequence=matrix_sequence,
                        position_priority=grad_index,
-                       input=inp,
+                       interface=inp,
                        )
 
     from dune.perftool.sumfact.vectorization import attach_vectorization_info
@@ -207,7 +225,7 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor):
                                 )
 
     sf = SumfactKernel(matrix_sequence=matrix_sequence,
-                       input=inp,
+                       interface=inp,
                        position_priority=3,
                        )
 
diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py
index d0e96c00..7b78de41 100644
--- a/python/dune/perftool/sumfact/geometry.py
+++ b/python/dune/perftool/sumfact/geometry.py
@@ -17,7 +17,7 @@ from dune.perftool.pdelab.geometry import (local_dimension,
                                            name_geometry,
                                            )
 from dune.perftool.sumfact.switch import get_facedir
-from dune.perftool.sumfact.symbolic import SumfactKernelInputBase
+from dune.perftool.sumfact.symbolic import SumfactKernelInterfaceBase
 from dune.perftool.sumfact.vectorization import attach_vectorization_info
 from dune.perftool.options import get_form_option, option_switch
 from dune.perftool.ufl.modified_terminals import Restriction
@@ -35,7 +35,7 @@ def corner_iname():
     return name
 
 
-class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
+class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
     def __init__(self, dir):
         ImmutableRecord.__init__(self, dir=dir)
 
@@ -45,7 +45,6 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
         temporary_variable(name,
                            shape=(2 ** local_dimension(), sf.vector_width),
                            custom_base_storage=name_buffer_storage(sf.buffer, 0),
-                           decl_method=buffer_decl(storage, get_sumfact_dtype(sf)),
                            managed=True,
                            )
 
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 703f9a6b..55e14108 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -19,7 +19,6 @@ from dune.perftool.generation import (barrier,
 from dune.perftool.loopy.flatten import flatten_index
 from dune.perftool.pdelab.argument import pymbolic_coefficient
 from dune.perftool.pdelab.basis import shape_as_pymbolic
-from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.geometry import world_dimension
 from dune.perftool.options import (get_form_option,
                                    get_option,
@@ -30,8 +29,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
                                                permute_forward,
                                                )
 from dune.perftool.sumfact.quadrature import quadrature_points_per_direction
-from dune.perftool.sumfact.symbolic import (get_input_output_tuple,
-                                            SumfactKernel,
+from dune.perftool.sumfact.symbolic import (SumfactKernel,
                                             VectorizedSumfactKernel,
                                             )
 from dune.perftool.sumfact.vectorization import attach_vectorization_info
@@ -113,22 +111,15 @@ def _realize_sum_factorization_kernel(sf):
                            )
 
     # Realize the input if it is not direct
-    if not sf.input.direct_input_is_possible:
-        insn_dep = insn_dep.union(sf.input.realize(sf, insn_dep))
-
-    # Collect function call arguments
-    fastdg_args = ()
-    if sf.stage == 1:
-        fastdg_args = sf.input.fastdg_args
-    if sf.stage == 3:
-        fastdg_args = sf.output.fastdg_args
+    if sf.stage == 1 and not sf.interface.direct_is_possible:
+        insn_dep = insn_dep.union(sf.interface.realize(sf, insn_dep))
 
     # Trigger generation of the sum factorization kernel function
     qp = quadrature_points_per_direction()
     necessary_kernel_implementations((sf, qp))
 
     # Call the function
-    code = "{}({});".format(sf.function_name, ", ".join(buffers + fastdg_args))
+    code = "{}({});".format(sf.function_name, ", ".join(buffers + sf.interface.function_args))
     tag = "sumfact_stage{}".format(sf.stage)
     insn_dep = frozenset({instruction(code=code,
                                       depends_on=insn_dep,
@@ -227,12 +218,12 @@ def realize_sumfact_kernel_function(sf):
         # * a global data structure (if FastDGGridOperator is in use)
         # * a value from a global data structure, broadcasted to a vector type (vectorized + FastDGGridOperator)
         input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
-        if l == 0 and sf.input.direct_input_is_possible:
+        if l == 0 and sf.stage == 1 and sf.interface.direct_is_possible:
             # See comment below
             input_inames = permute_backward(input_inames, perm)
             inp_shape = permute_backward(inp_shape, perm)
 
-            input_summand = sf.input.realize_direct(inp_shape, input_inames)
+            input_summand = sf.interface.realize_direct(inp_shape, input_inames)
         else:
             # If we did permute the order of a matrices above we also
             # permuted the order of out_inames. Unfortunately the
@@ -297,7 +288,7 @@ def realize_sumfact_kernel_function(sf):
         if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3:
             if sf.vectorized:
                 insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name})
-            insn_dep = sf.output.realize_direct(matprod, output_inames, out_shape, **insn_args)
+            insn_dep = sf.interface.realize_direct(matprod, output_inames, out_shape, **insn_args)
         else:
             # Issue the reduction instruction that implements the multiplication
             # at the same time store the instruction ID for the next instruction to depend on
@@ -309,14 +300,7 @@ def realize_sumfact_kernel_function(sf):
 
     # Construct a loopy kernel object
     from dune.perftool.pdelab.localoperator import extract_kernel_from_cache
-    args = ["const char* buffer0", "const char* buffer1"]
-    if get_form_option('fastdg'):
-        const = "const " if sf.stage == 1 else ""
-        for i in range(len(get_input_output_tuple(sf))):
-            args.append("{}double* fastdg{}".format(const, i))
-            if sf.within_inames:
-                args.append("unsigned int jacobian_offset{}".format(i))
-
+    args = ("const char* buffer0", "const char* buffer1") + sf.interface.signature_args
     signature = "void {}({}) const".format(sf.function_name, ", ".join(args))
     kernel = extract_kernel_from_cache("kernel_default", sf.function_name, [signature], add_timings=False)
     delete_cache_items("kernel_default")
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index 5ef99849..5f979cd7 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -5,7 +5,6 @@ from dune.perftool.generation import (get_counted_variable,
                                       subst_rule,
                                       transform,
                                       )
-from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.geometry import local_dimension, world_dimension
 from dune.perftool.sumfact.quadrature import quadrature_inames
 from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
@@ -23,58 +22,85 @@ import frozendict
 import inspect
 
 
-class SumfactKernelInputBase(object):
+class SumfactKernelInterfaceBase(object):
+    """ A base class for the input/output of a sum factorization kernel
+    In stage 1, this represents the input object, in stage 3 the output object.
+    """
+    def realize(self, *a, **kw):
+        raise NotImplementedError
+
+    def realize_direct(self, *a, **kw):
+        raise NotImplementedError
+
     @property
-    def direct_input_is_possible(self):
-        return False
+    def within_inames(self):
+        return ()
 
-    def realize(self, sf, dep, index=0):
-        return dep
+    @property
+    def direct_is_possible(self):
+        return False
 
-    def realize_direct(self, inames):
+    @property
+    def stage(self):
         raise NotImplementedError
 
+    @property
+    def function_args(self):
+        return ()
+
+    @property
+    def signature_args(self):
+        return ()
+    
+    @property
+    def function_name_suffix(self):
+        return ""
+
     def __repr__(self):
-        return "SumfactKernelInputBase()"
+        return "SumfactKernelInterfaceBase()"
 
 
-class VectorSumfactKernelInput(SumfactKernelInputBase):
-    def __init__(self, inputs):
-        assert(isinstance(inputs, tuple))
-        self.inputs = inputs
+class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
+    def __init__(self, interfaces):
+        assert(isinstance(interfaces, tuple))
+        self.interfaces = interfaces 
 
     def __repr__(self):
-        return "_".join(repr(i) for i in self.inputs)
+        return "_".join(repr(i) for i in self.interfaces)
+
+    @property
+    def stage(self):
+        return 1
 
     @property
-    def direct_input_is_possible(self):
-        return all(i.direct_input_is_possible for i in self.inputs)
+    def direct_is_possible(self):
+        return all(i.direct_is_possible for i in self.interfaces)
 
     def realize(self, sf, dep):
-        for i, inp in enumerate(self.inputs):
+        for i, inp in enumerate(self.interfaces):
             dep = dep.union(inp.realize(sf, dep, index=i))
         return dep
 
     def realize_direct(self, shape, inames):
         # Check whether the input exhibits a favorable structure
         # (whether we can broadcast scalar values into SIMD registers)
-        total = set(self.inputs)
-        lower = set(self.inputs[:len(self.inputs) // 2])
-        upper = set(self.inputs[len(self.inputs) // 2:])
+        total = set(self.interfaces)
+        lower = set(self.interfaces[:len(self.interfaces) // 2])
+        upper = set(self.interfaces[len(self.interfaces) // 2:])
 
         if len(total) == 1:
             # All input coefficients use the exact same input coefficient.
             # We implement this by broadcasting it into a SIMD register
             return prim.Call(ExplicitVCLCast(dtype_floatingpoint()),
-                             (self.inputs[0].realize_direct(shape, inames),)
+                             (self.interfaces[0].realize_direct(shape, inames),)
                              )
         elif len(total) == 2 and len(lower) == 1 and len(upper) == 1:
             # The lower and the upper part of the SIMD register use
             # the same input coefficient, we combine the SIMD register
             # from two shorter SIMD types
             return prim.Call(VCLLowerUpperLoad(dtype_floatingpoint()),
-                             (self.inputs[0].realize_direct(shape, inames),
-                              self.inputs[len(self.inputs) // 2].realize_direct(shape, inames, which=1),
+                             (self.interfaces[0].realize_direct(shape, inames),
+                              self.interfaces[len(self.interfaces) // 2].realize_direct(shape, inames, which=1),
                               )
                              )
         else:
@@ -83,36 +109,37 @@ class VectorSumfactKernelInput(SumfactKernelInputBase):
             raise NotImplementedError("SIMD loads from scalars not implemented!")
 
     @property
-    def fastdg_args(self):
-        return sum((i.fastdg_args for i in remove_duplicates(self.inputs)), ())
+    def function_args(self):
+        return sum((i.function_args for i in remove_duplicates(self.interfaces)), ())
 
+    @property
+    def signature_args(self):
+        return tuple("const double* fastdg{}".format(i)for i, _ in enumerate(remove_duplicates(self.interfaces)))
 
-class SumfactKernelOutputBase(object):
     @property
-    def within_inames(self):
-        return ()
+    def function_name_suffix(self):
+        return "".join(i.function_name_suffix for i in remove_duplicates(self.interfaces))
 
-    def realize(self, sf, result, insn_dep):
-        return dep
 
-    def realize_direct(self, result, inames, shape, args):
-        raise NotImplementedError
+class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
+    def __init__(self, interfaces):
+        self.interfaces = interfaces
 
     def __repr__(self):
-        return "SumfactKernelOutputBase()"
-
+        return "_".join(repr(o) for o in self.interfaces)
 
-class VectorSumfactKernelOutput(SumfactKernelOutputBase):
-    def __init__(self, outputs):
-        self.outputs = outputs
+    @property
+    def stage(self):
+        return 3
 
-    def __repr__(self):
-        return "_".join(repr(o) for o in self.outputs)
+    @property
+    def within_inames(self):
+        return self.interfaces[0].within_inames
 
     def _add_hadd(self, o, result):
         hadd_function = "horizontal_add"
-        if len(set(self.outputs)) > 1:
-            pos = self.outputs.index(o)
+        if len(set(self.interfaces)) > 1:
+            pos = self.interfaces.index(o)
             if pos == 0:
                 hadd_function = "horizontal_add_lower"
             else:
@@ -121,10 +148,10 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase):
         return prim.Call(prim.Variable(hadd_function), (result,))
 
     def realize(self, sf, result, insn_dep):
-        outputs = set(self.outputs)
+        outputs = set(self.interfaces)
 
-        trial_element, = set(o.trial_element for o in self.outputs)
-        trial_element_index, = set(o.trial_element_index for o in self.outputs)
+        trial_element, = set(o.trial_element for o in self.interfaces)
+        trial_element_index, = set(o.trial_element_index for o in self.interfaces)
         from dune.perftool.sumfact.accumulation import accum_iname
         element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None
         inames = tuple(accum_iname(element, mat.rows, i)
@@ -140,7 +167,7 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase):
         return deps
 
     def realize_direct(self, result, inames, shape, **args):
-        outputs = set(self.outputs)
+        outputs = set(self.interfaces)
 
         # If multiple horizontal_add's are to be performed with 'result'
         # we need to precompute the result!
@@ -153,14 +180,33 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase):
         deps = frozenset()
         for o in outputs:
             hadd_result = self._add_hadd(o, result)
-            which = tuple(remove_duplicates(self.outputs)).index(o)
+            which = tuple(remove_duplicates(self.interfaces)).index(o)
             deps = deps.union(o.realize_direct(hadd_result, inames, shape, which=which, **args))
 
         return deps
 
     @property
-    def fastdg_args(self):
-        return sum((i.fastdg_args for i in remove_duplicates(self.outputs)), ())
+    def function_args(self):
+        if get_form_option("fastdg"):
+            return sum((i.function_args for i in remove_duplicates(self.interfaces)), ())
+        else:
+            return()
+
+    @property
+    def signature_args(self):
+        if get_form_option("fastdg"):
+            def _get_pair(i):
+                ret = ("double* fastdg{}".format(i),)
+                if self.within_inames:
+                    ret = ret + ("unsigned int jacobian_offset{}".format(i),)
+                return ret
+            return sum((_get_pair(i) for i, _ in enumerate(remove_duplicates(self.interfaces))), ())
+        else:
+            return ()
+
+    @property
+    def function_name_suffix(self):
+        return "".join(i.function_name_suffix for i in remove_duplicates(self.interfaces))
 
 
 class SumfactKernelBase(object):
@@ -171,11 +217,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
     def __init__(self,
                  matrix_sequence=None,
                  buffer=None,
-                 stage=1,
                  position_priority=None,
                  insn_dep=frozenset(),
-                 input=SumfactKernelInputBase(),
-                 output=SumfactKernelOutputBase(),
+                 interface=SumfactKernelInterfaceBase(),
                  predicates=frozenset(),
                  ):
         """Create a sum factorization kernel
@@ -229,31 +273,18 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
             for intermediate results. The memory is expected to be
             pre-initialized with the input or you have to provide
             direct_input (FastDGGridOperator).
-        stage: 1 or 3
         position_priority: Will be used in the dry run to order kernels
             when doing vectorization e.g. (dx u,dy u,dz u, u).
-        restriction: Restriction for faces values.
         insn_dep: An instruction ID that the first issued instruction
             should depend upon. All following ones will depend on each
             other.
-        input: An SumfactKernelInputBase instance describing the input of the kernel
-        accumvar: The accumulation variable to accumulate into
-        trial_element: The leaf element of the trial function space.
-            Used to correctly nest stage 3 in the jacobian case.
-        test_element: The leaf element of the test function space
-            Used to compute offsets in the fastdg case.
-        test_element_index: the component of the test_element
-        trial_element_index: the component of the trial_element
+        interface: An SumfactKernelInterfaceBase instance describing the input
+            (stage 1) or output (stage 3) of the kernel
         """
         # Assert the inputs!
         assert isinstance(matrix_sequence, tuple)
         assert all(isinstance(m, BasisTabulationMatrixBase) for m in matrix_sequence)
-
-        assert stage in (1, 3)
-
-        assert isinstance(input, SumfactKernelInputBase)
-        assert isinstance(output, SumfactKernelOutputBase)
-
+        assert isinstance(interface, SumfactKernelInterfaceBase)
         assert isinstance(insn_dep, frozenset)
 
         # The following construction is a bit weird: Dict comprehensions do not have
@@ -279,7 +310,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
     def __str__(self):
         # Above stringifier just calls back into this
         return "SF{}:[{}]->[{}]".format(self.stage,
-                                        str(self.input),
+                                        str(self.interface),
                                         ", ".join(str(m) for m in self.matrix_sequence))
 
     mapper_method = "map_sumfact_kernel"
@@ -291,16 +322,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
     @property
     def function_name(self):
         """ The name of the function that implements this kernel """
-        name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence))
-        if get_form_option("fastdg"):
-            if self.stage == 1:
-                fastdg = "{}comp{}".format(FEM_name_mangling(self.input.element), self.input.element_index)
-            if self.stage == 3:
-                fastdg = "{}comp{}".format(FEM_name_mangling(self.output.test_element), self.output.test_element_index)
-                if self.within_inames:
-                    fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(self.output.trial_element), self.output.trial_element_index)
-            name = "{}_fastdg{}_{}".format(name, self.stage, fastdg)
-        return name
+        return "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence),
+                                    self.interface.function_name_suffix)
 
     @property
     def parallel_key(self):
@@ -327,7 +350,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         work on the same input coefficient (stage 1) or accumulate
         into the same thing (stage 3)
         """
-        return (repr(self.input), repr(self.output))
+        return repr(self.interface)
 
     #
     # Some convenience methods to extract information about the sum factorization kernel
@@ -337,7 +360,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         if self.parallel_key != other.parallel_key:
             return self.parallel_key < other.parallel_key
         if self.inout_key != other.inout_key:
-            return self.input_key < other.input_key
+            return self.inout_key < other.inout_key
         if self.position_priority == other.position_priority:
             return repr(self) < repr(other)
         if self.position_priority is None:
@@ -361,7 +384,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
 
     @property
     def within_inames(self):
-        return self.output.within_inames
+        return self.interface.within_inames
 
     def vec_index(self, sf):
         """ Map an unvectorized sumfact kernel object to its position
@@ -447,6 +470,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
     def tag(self):
         return "sumfac"
 
+    @property
+    def stage(self):
+        return self.interface.stage
+
     #
     # Define properties for conformity with the interface of VectorizedSumfactKernel
     #
@@ -548,7 +575,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     def __str__(self):
         # Above stringifier just calls back into this
         return "VSF{}:[{}]->[{}]".format(self.stage,
-                                         ", ".join(str(k.input) for k in self.kernels),
+                                         ", ".join(str(k.interface) for k in self.kernels),
                                          ", ".join(str(mat) for mat in self.matrix_sequence))
 
     mapper_method = "map_vectorized_sumfact_kernel"
@@ -561,18 +588,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     #
     @property
     def function_name(self):
-        name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence))
-        if get_form_option("fastdg"):
-            if self.stage == 1:
-                fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(self.input.inputs))
-            if self.stage == 3:
-                fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(self.output.outputs))
-                if self.within_inames:
-                    fastdg = "{}x{}".format(fastdg,
-                                            "_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(self.output.outputs))
-                                            )
-            name = "{}_fastdg{}_{}".format(name, self.stage, fastdg)
-        return name
+        return "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence),
+                                    self.interface.function_name_suffix)
 
     @property
     def cache_key(self):
@@ -627,14 +644,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     #
     # Define the same properties the normal SumfactKernel defines
     #
-
     @property
-    def input(self):
-        return VectorSumfactKernelInput(tuple(k.input for k in self.kernels))
+    def stage(self):
+        return self.kernels[0].stage
 
     @property
-    def output(self):
-        return VectorSumfactKernelOutput(tuple(k.output for k in self.kernels))
+    def interface(self):
+        if self.stage == 1:
+            return VectorSumfactKernelInput(tuple(k.interface for k in self.kernels))
+        else:
+            return VectorSumfactKernelOutput(tuple(k.interface for k in self.kernels))
 
     @property
     def cache_key(self):
@@ -791,17 +810,3 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
         to be carried out """
         from dune.perftool.sumfact.permutation import flop_cost
         return flop_cost(self.matrix_sequence)
-
-
-def get_input_output_tuple(sf):
-    if sf.stage == 1:
-        if isinstance(sf, SumfactKernel):
-            return (sf.input,)
-        if isinstance(sf, VectorizedSumfactKernel):
-            return tuple(remove_duplicates(sf.input.inputs))
-    if sf.stage == 3:
-        if isinstance(sf, SumfactKernel):
-            return (sf.output,)
-        if isinstance(sf, VectorizedSumfactKernel):
-            return tuple(remove_duplicates(sf.output.outputs))
-    assert(False)
-- 
GitLab