From 3fa94dc6ca514ba210e060edf71e69a4b7097b09 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 24 Jan 2019 10:57:02 +0100
Subject: [PATCH] Remove backend mechanism

---
 python/dune/codegen/__init__.py               |  3 +-
 .../dune/codegen/blockstructured/argument.py  |  4 +-
 python/dune/codegen/blockstructured/basis.py  |  5 +-
 python/dune/codegen/generation/__init__.py    |  6 ---
 python/dune/codegen/generation/backend.py     | 44 ----------------
 python/dune/codegen/generation/cache.py       |  6 ---
 python/dune/codegen/pdelab/argument.py        |  2 -
 python/dune/codegen/pdelab/basis.py           | 14 +++--
 python/dune/codegen/pdelab/function.py        | 52 -------------------
 python/dune/codegen/pdelab/geometry.py        |  4 +-
 python/dune/codegen/pdelab/localoperator.py   | 16 +++---
 python/dune/codegen/pdelab/quadrature.py      |  4 +-
 python/dune/codegen/sumfact/accumulation.py   |  1 -
 python/dune/codegen/sumfact/basis.py          |  3 +-
 python/dune/codegen/sumfact/geometry.py       |  3 +-
 python/dune/codegen/sumfact/quadrature.py     |  3 +-
 python/dune/codegen/sumfact/realization.py    |  1 -
 python/dune/codegen/sumfact/switch.py         | 11 ++--
 python/dune/codegen/sumfact/vectorization.py  |  4 +-
 .../dune/codegen/generation/test_backend.py   | 21 --------
 20 files changed, 31 insertions(+), 176 deletions(-)
 delete mode 100644 python/dune/codegen/generation/backend.py
 delete mode 100644 python/dune/codegen/pdelab/function.py
 delete mode 100644 python/test/dune/codegen/generation/test_backend.py

diff --git a/python/dune/codegen/__init__.py b/python/dune/codegen/__init__.py
index 243911af..e892baf4 100644
--- a/python/dune/codegen/__init__.py
+++ b/python/dune/codegen/__init__.py
@@ -6,8 +6,7 @@ os.environ["OMP_NUM_THREADS"] = "1"
 # Trigger imports that involve monkey patching!
 import dune.codegen.loopy.symbolic  # noqa
 
-# Trigger some imports that are needed to have all backend implementations visible
-# to the selection mechanisms
+# Trigger some imports that are needed to have all mixin implementations visible
 import dune.codegen.pdelab  # noqa
 import dune.codegen.sumfact  # noqa
 import dune.codegen.blockstructured  # noqa
diff --git a/python/dune/codegen/blockstructured/argument.py b/python/dune/codegen/blockstructured/argument.py
index b0940e66..420773e8 100644
--- a/python/dune/codegen/blockstructured/argument.py
+++ b/python/dune/codegen/blockstructured/argument.py
@@ -1,5 +1,4 @@
-from dune.codegen.generation import (backend,
-                                     kernel_cached,
+from dune.codegen.generation import (kernel_cached,
                                      valuearg, instruction, globalarg)
 from dune.codegen.options import get_form_option
 from dune.codegen.pdelab.argument import CoefficientAccess
@@ -26,7 +25,6 @@ def name_alias(container, lfs, element):
 
 
 # TODO remove the need for element
-@backend(interface="pymbolic_coefficient", name="blockstructured")
 @kernel_cached
 def pymbolic_coefficient(container, lfs, element, index):
     # TODO introduce a proper type for local function spaces!
diff --git a/python/dune/codegen/blockstructured/basis.py b/python/dune/codegen/blockstructured/basis.py
index ea90582e..5af8bb48 100644
--- a/python/dune/codegen/blockstructured/basis.py
+++ b/python/dune/codegen/blockstructured/basis.py
@@ -1,9 +1,7 @@
 from loopy import Reduction
 
-from dune.codegen.generation import (backend,
-                                     basis_mixin,
+from dune.codegen.generation import (basis_mixin,
                                      kernel_cached,
-                                     get_backend,
                                      instruction,
                                      temporary_variable,
                                      globalarg,
@@ -139,7 +137,6 @@ class BlockStructuredBasisMixin(GenericBasisMixin):
 
 
 # define FE basis explicitly in localoperator
-@backend(interface="typedef_localbasis", name="blockstructured")
 @class_member(classtag="operator")
 def typedef_localbasis(element, name):
     df = "typename {}::Traits::GridView::ctype".format(type_leaf_gfs(element))
diff --git a/python/dune/codegen/generation/__init__.py b/python/dune/codegen/generation/__init__.py
index b8207414..d0cf1d4d 100644
--- a/python/dune/codegen/generation/__init__.py
+++ b/python/dune/codegen/generation/__init__.py
@@ -1,9 +1,3 @@
-from __future__ import absolute_import
-
-from dune.codegen.generation.backend import (backend,
-                                             get_backend,
-                                             )
-
 from dune.codegen.generation.counter import (get_counter,
                                              get_counted_variable,
                                              )
diff --git a/python/dune/codegen/generation/backend.py b/python/dune/codegen/generation/backend.py
deleted file mode 100644
index ba793674..00000000
--- a/python/dune/codegen/generation/backend.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from dune.codegen.generation.cache import _RegisteredFunction
-from dune.codegen.options import option_switch
-from pytools import ImmutableRecord
-
-
-_backend_mapping = {}
-
-
-class FuncProxy(ImmutableRecord):
-    def __init__(self, interface, name, func):
-        ImmutableRecord.__init__(self, interface=interface, name=name, func=func)
-
-    def __call__(self, *args, **kwargs):
-        return self.func(*args, **kwargs)
-
-
-def register_backend(interface, name, func):
-    _backend_mapping.setdefault(interface, {})
-    _backend_mapping[interface][name] = func
-
-
-def backend(interface=None, name='default'):
-    assert interface
-
-    def _dec(func):
-        if not isinstance(func, _RegisteredFunction):
-            # Allow order independence of the generator decorators
-            # and the backend decorator by delaying the registration
-            func = FuncProxy(interface, name, func)
-
-        register_backend(interface, name, func)
-
-        return func
-
-    return _dec
-
-
-def get_backend(interface=None, selector=option_switch("sumfact"), **kwargs):
-    assert interface and selector
-
-    select = selector(**kwargs)
-    assert select in _backend_mapping[interface], "Implementation '{}' for interface '{}' missing!".format(select, interface)
-
-    return _backend_mapping[interface][select]
diff --git a/python/dune/codegen/generation/cache.py b/python/dune/codegen/generation/cache.py
index c727fe05..2e555ae9 100644
--- a/python/dune/codegen/generation/cache.py
+++ b/python/dune/codegen/generation/cache.py
@@ -86,12 +86,6 @@ class _RegisteredFunction(object):
         # Initialize the memoization cache
         self._memoize_cache = {}
 
-        # Allow order independence of the backend and the generator decorators.
-        # If backend was applied first, we resolve the issued FuncProxy object
-        from dune.codegen.generation.backend import FuncProxy
-        if isinstance(self.func, FuncProxy):
-            raise NotImplementedError("Please use @backend as the outer decorator if combining with generator decorators")
-
     def _get_content(self, key):
         return self._memoize_cache[key].value
 
diff --git a/python/dune/codegen/pdelab/argument.py b/python/dune/codegen/pdelab/argument.py
index f9c6c56b..8d3c3a3d 100644
--- a/python/dune/codegen/pdelab/argument.py
+++ b/python/dune/codegen/pdelab/argument.py
@@ -11,7 +11,6 @@ from dune.codegen.generation import (domain,
                                      valuearg,
                                      get_global_context_value,
                                      kernel_cached,
-                                     backend
                                      )
 from dune.codegen.loopy.target import dtype_floatingpoint
 from dune.codegen.pdelab.index import name_index
@@ -96,7 +95,6 @@ def name_applycontainer(restriction):
     return name
 
 
-@backend(interface="pymbolic_coefficient")
 @kernel_cached
 def pymbolic_coefficient(container, lfs, index):
     # TODO introduce a proper type for local function spaces!
diff --git a/python/dune/codegen/pdelab/basis.py b/python/dune/codegen/pdelab/basis.py
index 583c3738..03c069ac 100644
--- a/python/dune/codegen/pdelab/basis.py
+++ b/python/dune/codegen/pdelab/basis.py
@@ -1,9 +1,7 @@
 """ Generators for basis evaluations """
 
-from dune.codegen.generation import (backend,
-                                     basis_mixin,
+from dune.codegen.generation import (basis_mixin,
                                      class_member,
-                                     get_backend,
                                      include_file,
                                      instruction,
                                      kernel_cached,
@@ -278,7 +276,6 @@ def declare_grid_function_range(gridfunction):
     return _decl
 
 
-@backend(interface="typedef_localbasis")
 @class_member(classtag="operator")
 def typedef_localbasis(element, name):
     basis_type = "{}::Traits::FiniteElementMap::Traits::FiniteElementType::Traits::LocalBasisType".format(type_leaf_gfs(element))
@@ -294,7 +291,14 @@ def type_localbasis(element):
         name = "DG{}_LocalBasis".format(element._degree)
     else:
         raise NotImplementedError("Element type not known in code generation")
-    get_backend("typedef_localbasis", selector=option_switch("blockstructured"))(element, name)
+
+    # TODO get rid of this
+    if get_form_option("blockstructured"):
+        from dune.codegen.blockstructured.basis import typedef_localbasis as bs_typedef_localbasis
+        bs_typedef_localbasis(element, name)
+    else:
+        typedef_localbasis(element, name)
+
     return name
 
 
diff --git a/python/dune/codegen/pdelab/function.py b/python/dune/codegen/pdelab/function.py
deleted file mode 100644
index 74e6b6b6..00000000
--- a/python/dune/codegen/pdelab/function.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from dune.codegen.generation import (get_backend,
-                                     instruction,
-                                     kernel_cached,
-                                     preamble,
-                                     temporary_variable,
-                                     )
-from dune.codegen.pdelab.geometry import (name_cell,
-                                          world_dimension,
-                                          )
-from dune.codegen.pdelab.localoperator import name_gridfunction_member
-
-import pymbolic.primitives as prim
-
-
-@preamble
-def bind_gridfunction_to_element(gf, restriction):
-    element = name_cell(restriction)
-    return "{}.bind({});".format(gf, element)
-
-
-def declare_grid_function_range(gridfunction):
-    def _decl(name, kernel, decl_info):
-        return "typename decltype({})::Range {};".format(gridfunction, name)
-
-    return _decl
-
-
-@kernel_cached
-def pymbolic_evaluate_gridfunction(name, coeff, restriction, grad):
-    diffOrder = 1 if grad else 0
-
-    gridfunction = name_gridfunction_member(coeff, restriction, diffOrder)
-    bind_gridfunction_to_element(gridfunction, restriction)
-
-    temporary_variable(name,
-                       shape=(1,) + (world_dimension(),) * diffOrder,
-                       decl_method=declare_grid_function_range(gridfunction),
-                       managed=False,
-                       )
-
-    quadpos = get_backend(interface="qp_in_cell")(restriction)
-    instruction(code="{} = {}({});".format(name, gridfunction, quadpos),
-                assignees=frozenset({name}),
-                within_inames=frozenset(get_backend(interface="quad_inames")()),
-                within_inames_is_final=True,
-                )
-
-
-def pymbolic_gridfunction(coeff, restriction, grad):
-    name = "coeff{}{}".format(coeff.count(), "_grad" if grad else "")
-    pymbolic_evaluate_gridfunction(name, coeff, restriction, grad)
-    return prim.Subscript(prim.Variable(name), (0,))
diff --git a/python/dune/codegen/pdelab/geometry.py b/python/dune/codegen/pdelab/geometry.py
index 4ade8a22..5ecd8c19 100644
--- a/python/dune/codegen/pdelab/geometry.py
+++ b/python/dune/codegen/pdelab/geometry.py
@@ -1,10 +1,8 @@
 from dune.codegen.ufl.modified_terminals import Restriction
 from dune.codegen.pdelab.restriction import restricted_name
-from dune.codegen.generation import (backend,
-                                     class_member,
+from dune.codegen.generation import (class_member,
                                      domain,
                                      geometry_mixin,
-                                     get_backend,
                                      get_global_context_value,
                                      globalarg,
                                      iname,
diff --git a/python/dune/codegen/pdelab/localoperator.py b/python/dune/codegen/pdelab/localoperator.py
index 65627da0..70ab9f47 100644
--- a/python/dune/codegen/pdelab/localoperator.py
+++ b/python/dune/codegen/pdelab/localoperator.py
@@ -10,7 +10,6 @@ from dune.codegen.options import (get_form_option,
                                   option_switch,
                                   set_form_option)
 from dune.codegen.generation import (accumulation_mixin,
-                                     backend,
                                      base_class,
                                      class_basename,
                                      class_member,
@@ -21,7 +20,6 @@ from dune.codegen.generation import (accumulation_mixin,
                                      end_of_file,
                                      function_mangler,
                                      generator_factory,
-                                     get_backend,
                                      get_global_context_value,
                                      global_context,
                                      iname,
@@ -545,9 +543,13 @@ def generate_kernel(integrals):
     return knl
 
 
-@backend(interface="generate_kernels_per_integral")
 def generate_kernels_per_integral(integrals):
-    yield generate_kernel(integrals)
+    if get_form_option("sumfact"):
+        from dune.codegen.sumfact.switch import sumfact_generate_kernels_per_integral
+        for k in sumfact_generate_kernels_per_integral(integrals):
+            yield k
+    else:
+        yield generate_kernel(integrals)
 
 
 def extract_kernel_from_cache(tag, name, signature, wrap_in_cgen=True, add_timings=True):
@@ -826,7 +828,7 @@ def generate_residual_kernels(form, original_form):
             with global_context(integral_type=measure):
                 from dune.codegen.pdelab.signatures import assembler_routine_name
                 with global_context(kernel=assembler_routine_name()):
-                    kernel = [k for k in get_backend(interface="generate_kernels_per_integral")(form.integrals_by_type(measure))]
+                    kernel = [k for k in generate_kernels_per_integral(form.integrals_by_type(measure))]
 
                 # The integrals might vanish due to unfulfillable boundary conditions.
                 # We only generate the local operator enums/base classes if they did not.
@@ -913,7 +915,7 @@ def generate_jacobian_kernels(form, original_form):
                 with global_context(integral_type=measure):
                     from dune.codegen.pdelab.signatures import assembler_routine_name
                     with global_context(kernel=assembler_routine_name()):
-                        kernel = [k for k in get_backend(interface="generate_kernels_per_integral")(jac_apply_form.integrals_by_type(measure))]
+                        kernel = [k for k in generate_kernels_per_integral(jac_apply_form.integrals_by_type(measure))]
                 operator_kernels[(measure, 'jacobian_apply')] = kernel
 
                 # Generate dummy functions for those kernels, that vanished in the differentiation process
@@ -942,7 +944,7 @@ def generate_jacobian_kernels(form, original_form):
                     with global_context(integral_type=measure):
                         from dune.codegen.pdelab.signatures import assembler_routine_name
                         with global_context(kernel=assembler_routine_name()):
-                            kernel = [k for k in get_backend(interface="generate_kernels_per_integral")(jacform.integrals_by_type(measure))]
+                            kernel = [k for k in generate_kernels_per_integral(jacform.integrals_by_type(measure))]
                     operator_kernels[(measure, 'jacobian')] = kernel
 
                 # Generate dummy functions for those kernels, that vanished in the differentiation process
diff --git a/python/dune/codegen/pdelab/quadrature.py b/python/dune/codegen/pdelab/quadrature.py
index 71972c2c..ae6a7e2d 100644
--- a/python/dune/codegen/pdelab/quadrature.py
+++ b/python/dune/codegen/pdelab/quadrature.py
@@ -1,9 +1,7 @@
 import numpy
 
-from dune.codegen.generation import (backend,
-                                     class_member,
+from dune.codegen.generation import (class_member,
                                      domain,
-                                     get_backend,
                                      get_global_context_value,
                                      globalarg,
                                      iname,
diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py
index ba8c9bba..3be02ae0 100644
--- a/python/dune/codegen/sumfact/accumulation.py
+++ b/python/dune/codegen/sumfact/accumulation.py
@@ -4,7 +4,6 @@ from dune.codegen.pdelab.argument import (name_accumulation_variable,
                                           PDELabAccumulationFunction,
                                           )
 from dune.codegen.generation import (accumulation_mixin,
-                                     backend,
                                      domain,
                                      dump_accumulate_timer,
                                      generator_factory,
diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py
index a6fdbb30..d9ecc4a7 100644
--- a/python/dune/codegen/sumfact/basis.py
+++ b/python/dune/codegen/sumfact/basis.py
@@ -5,8 +5,7 @@ multiplication with the test function is part of the sum factorization kernel.
 """
 import itertools
 
-from dune.codegen.generation import (backend,
-                                     basis_mixin,
+from dune.codegen.generation import (basis_mixin,
                                      domain,
                                      get_counted_variable,
                                      get_counter,
diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py
index aa396340..feda6083 100644
--- a/python/dune/codegen/sumfact/geometry.py
+++ b/python/dune/codegen/sumfact/geometry.py
@@ -1,7 +1,6 @@
 """ Sum factorized geometry evaluations """
 
-from dune.codegen.generation import (backend,
-                                     class_member,
+from dune.codegen.generation import (class_member,
                                      domain,
                                      geometry_mixin,
                                      get_counted_variable,
diff --git a/python/dune/codegen/sumfact/quadrature.py b/python/dune/codegen/sumfact/quadrature.py
index 86982443..91e99c4c 100644
--- a/python/dune/codegen/sumfact/quadrature.py
+++ b/python/dune/codegen/sumfact/quadrature.py
@@ -1,5 +1,4 @@
-from dune.codegen.generation import (backend,
-                                     domain,
+from dune.codegen.generation import (domain,
                                      function_mangler,
                                      get_global_context_value,
                                      globalarg,
diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py
index d00f2a01..a40be948 100644
--- a/python/dune/codegen/sumfact/realization.py
+++ b/python/dune/codegen/sumfact/realization.py
@@ -17,7 +17,6 @@ from dune.codegen.generation import (barrier,
                                      transform,
                                      )
 from dune.codegen.loopy.flatten import flatten_index
-from dune.codegen.pdelab.argument import pymbolic_coefficient
 from dune.codegen.pdelab.basis import shape_as_pymbolic
 from dune.codegen.pdelab.geometry import world_dimension
 from dune.codegen.options import (get_form_option,
diff --git a/python/dune/codegen/sumfact/switch.py b/python/dune/codegen/sumfact/switch.py
index c057e731..67cb9df3 100644
--- a/python/dune/codegen/sumfact/switch.py
+++ b/python/dune/codegen/sumfact/switch.py
@@ -2,13 +2,11 @@
 
 import csv
 
-from dune.codegen.generation import (backend,
-                                     get_backend,
-                                     get_global_context_value,
+from dune.codegen.generation import (get_global_context_value,
                                      global_context,
                                      )
 from dune.codegen.pdelab.geometry import world_dimension
-from dune.codegen.pdelab.localoperator import generate_kernel
+from dune.codegen.pdelab.localoperator import generate_kernel, generate_kernels_per_integral
 from dune.codegen.pdelab.signatures import (assembly_routine_args,
                                             assembly_routine_signature,
                                             kernel_name,
@@ -17,8 +15,7 @@ from dune.codegen.options import get_form_option, get_option, set_form_option
 from dune.codegen.cgen.clazz import ClassMember
 
 
-@backend(interface="generate_kernels_per_integral", name="sumfact")
-def generate_kernels_per_integral(integrals):
+def sumfact_generate_kernels_per_integral(integrals):
     dim = world_dimension()
     measure = get_global_context_value("integral_type")
 
@@ -29,7 +26,7 @@ def generate_kernels_per_integral(integrals):
         # Maybe skip sum factorization on boundary integrals
         if not get_form_option("sumfact_on_boundary"):
             set_form_option("sumfact", False)
-            for k in get_backend(interface="generate_kernels_per_integral")(integrals):
+            for k in generate_kernels_per_integral(integrals):
                 yield k
             set_form_option("sumfact", True)
             return
diff --git a/python/dune/codegen/sumfact/vectorization.py b/python/dune/codegen/sumfact/vectorization.py
index 0b6b3a23..fc555f2c 100644
--- a/python/dune/codegen/sumfact/vectorization.py
+++ b/python/dune/codegen/sumfact/vectorization.py
@@ -7,9 +7,7 @@ import logging
 from dune.codegen.loopy.target import dtype_floatingpoint
 from dune.codegen.loopy.vcl import get_vcl_type_size
 from dune.codegen.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel
-from dune.codegen.generation import (backend,
-                                     generator_factory,
-                                     get_backend,
+from dune.codegen.generation import (generator_factory,
                                      get_counted_variable,
                                      get_global_context_value,
                                      kernel_cached,
diff --git a/python/test/dune/codegen/generation/test_backend.py b/python/test/dune/codegen/generation/test_backend.py
deleted file mode 100644
index e7b11e56..00000000
--- a/python/test/dune/codegen/generation/test_backend.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from dune.codegen.generation import (backend,
-                                     get_backend,
-                                     generator_factory,
-                                     )
-
-
-@backend(interface="foo", name="f1")
-@generator_factory()
-def f1():
-    return 1
-
-
-@backend(interface="bar", name="f3")
-@generator_factory()
-def f3():
-    return 3
-
-
-def test_backend():
-    assert get_backend(interface="foo", selector=lambda: "f1")() == 1
-    assert get_backend(interface="bar", selector=lambda: "f3")() == 3
-- 
GitLab