From 9bda3dfe554e860d767a8e7ea69728a65b7c9940 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 10 Dec 2018 13:36:06 +0100
Subject: [PATCH] Update adoint control code path

---
 python/dune/codegen/options.py              |  8 ++--
 python/dune/codegen/pdelab/__init__.py      | 24 -----------
 python/dune/codegen/pdelab/adjoint.py       | 47 ++++++---------------
 python/dune/codegen/pdelab/basis.py         | 45 +++++++++++++++++++-
 python/dune/codegen/pdelab/localoperator.py |  8 ++--
 python/dune/codegen/ufl/visitor.py          |  4 +-
 6 files changed, 67 insertions(+), 69 deletions(-)

diff --git a/python/dune/codegen/options.py b/python/dune/codegen/options.py
index 4a5a913d..db7a1440 100644
--- a/python/dune/codegen/options.py
+++ b/python/dune/codegen/options.py
@@ -55,9 +55,6 @@ class CodegenGlobalOptionsArray(ImmutableRecord):
     target_name = CodegenOption(default=None, helpstr="The target name from CMake")
     operator_to_build = CodegenOption(default=None, helpstr="The operators from the list that is about to be build now. CMake sets this one!!!")
     debug_interpolate_input = CodegenOption(default=False, helpstr="Should the input for printresidual and printmatix be interpolated (instead of random input).")
-    geometry_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for geometries. Currently implemented mixins: generic, axiparallel, equidistant")
-    quadrature_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for quadrature. Currently implemented: generic")
-    basis_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for basis function evaluation. Currently implemented: generic")
 
     # Arguments that are mainly to be set by logic depending on other options
     max_vector_width = CodegenOption(default=256, helpstr=None)
@@ -113,7 +110,7 @@ class CodegenFormOptionsArray(ImmutableRecord):
     geometry_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for geometries. Currently implemented mixins: generic, axiparallel, equidistant, sumfact_multilinear, sumfact_axiparallel, sumfact_equidistant")
     quadrature_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for quadrature. Currently implemented: generic, sumfact")
     basis_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for basis function evaluation. Currently implemented: generic, sumfact")
-    accumulation_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for accumulation. Currently implemented: generic, sumfact")
+    accumulation_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for accumulation. Currently implemented: generic, sumfact, control")
     enable_volume = CodegenOption(default=True, helpstr="Whether to assemble volume integrals")
     enable_skeleton = CodegenOption(default=True, helpstr="Whether to assemble skeleton integrals")
     enable_boundary = CodegenOption(default=True, helpstr="Whether to assemble boundary integrals")
@@ -187,6 +184,9 @@ def process_form_options(opt, form):
                        accumulation_mixins="sumfact",
                        )
 
+    if opt.control:
+        opt = opt.copy(accumulation_mixins="control")
+
     if opt.numerical_jacobian:
         opt = opt.copy(generate_jacobians=False, generate_jacobian_apply=False)
 
diff --git a/python/dune/codegen/pdelab/__init__.py b/python/dune/codegen/pdelab/__init__.py
index 7ca90b96..e32a61c5 100644
--- a/python/dune/codegen/pdelab/__init__.py
+++ b/python/dune/codegen/pdelab/__init__.py
@@ -20,34 +20,10 @@ class PDELabInterface(object):
         # The visitor instance will be registered by its init method
         self.visitor = None
 
-    #
-    # TODO: The following ones are actually entirely PDELab independent!
-    # They should be placed elsewhere and be used directly in the visitor.
-    #
-
-    def component_iname(self, context=None, count=None):
-        return component_iname(context=context, count=count)
-
-    def name_index(self, ind):
-        return name_index(ind)
-
-    #
-    # Local function space related generator functions
-    #
-
-    def lfs_inames(self, element, restriction, number=None, context=''):
-        return lfs_inames(element, restriction, number, context)
-
     def initialize_function_spaces(self, expr, visitor):
         from dune.codegen.pdelab.spaces import initialize_function_spaces
         return initialize_function_spaces(expr, visitor)
 
-    #
-    # Test and trial function related generator functions
-    #
-    def pymbolic_gridfunction(self, coeff, restriction, grad):
-        return pymbolic_gridfunction(coeff, restriction, grad)
-
     #
     # Tensor expression related generator functions
     #
diff --git a/python/dune/codegen/pdelab/adjoint.py b/python/dune/codegen/pdelab/adjoint.py
index 2681b003..4bf35aa5 100644
--- a/python/dune/codegen/pdelab/adjoint.py
+++ b/python/dune/codegen/pdelab/adjoint.py
@@ -8,7 +8,8 @@ from loopy.types import NumpyType
 
 import pymbolic.primitives as prim
 
-from dune.codegen.generation import (class_member,
+from dune.codegen.generation import (accumulation_mixin,
+                                     class_member,
                                      constructor_parameter,
                                      function_mangler,
                                      get_global_context_value,
@@ -24,6 +25,8 @@ from dune.codegen.pdelab import PDELabInterface
 from dune.codegen.pdelab.localoperator import (boundary_predicates,
                                                determine_accumulation_space,
                                                extract_kernel_from_cache,
+                                               GenericAccumulationMixin,
+                                               get_visitor,
                                                )
 
 
@@ -61,7 +64,7 @@ def generate_accumulation_instruction(expr, visitor, accumulation_index, number_
     expr = prim.Sum((assignee, expr))
 
     from dune.codegen.generation import instruction
-    quad_inames = visitor.interface.quadrature_inames()
+    quad_inames = visitor.quadrature_inames()
     instruction(assignee=assignee,
                 expression=expr,
                 forced_iname_deps=frozenset(quad_inames),
@@ -69,44 +72,22 @@ def generate_accumulation_instruction(expr, visitor, accumulation_index, number_
                 )
 
 
-def list_accumulation_infos(expr, visitor):
-    return ["control", ]
-
-
-class ControlInterface(PDELabInterface):
-    """Interface for generating the control localoperator
-
-    In this case we will not accumulate in the residual vector but use
-    a class member representing dJdm instead.
-
-    """
-    def __init__(self, accumulation_index, number_of_controls):
-        """Create ControlInterface
-
-        Arguments:
-        ----------
-        accumulation_index: In which component of the dJdm should be accumulated.
-        number_of_controls: Number of components of dJdm. Needed for creating the member variable.
-        """
+@accumulation_mixin("control")
+class AdjointAccumulationMixin(GenericAccumulationMixin):
+    def set_adjoint_information(self, accumulation_index, number_of_controls):
         self.accumulation_index = accumulation_index
         self.number_of_controls = number_of_controls
 
-    def list_accumulation_infos(self, expr, visitor):
-        return list_accumulation_infos(expr, visitor)
+    def list_accumulation_infos(self, expr):
+        return ["control"]
 
-    def generate_accumulation_instruction(self, expr, visitor):
+    def generate_accumulation_instruction(self, expr):
         return generate_accumulation_instruction(expr,
-                                                 visitor,
+                                                 self,
                                                  self.accumulation_index,
                                                  self.number_of_controls)
 
 
-def get_visitor(measure, subdomain_id, accumulation_index, number_of_controls):
-    interface = ControlInterface(accumulation_index, number_of_controls)
-    from dune.codegen.ufl.visitor import UFL2LoopyVisitor
-    return UFL2LoopyVisitor(interface, measure, subdomain_id)
-
-
 def visit_integral(integral, accumulation_index, number_of_controls):
     integrand = integral.integrand()
     measure = integral.integral_type()
@@ -114,7 +95,8 @@ def visit_integral(integral, accumulation_index, number_of_controls):
 
     # The visitor needs to know about the current index and the number
     # of controls in order to generate the accumulation instruction
-    visitor = get_visitor(measure, subdomain_id, accumulation_index, number_of_controls)
+    visitor = get_visitor(measure, subdomain_id)
+    visitor.set_adjoint_information(accumulation_index, number_of_controls)
 
     # Start the visiting process!
     visitor.accumulate(integrand)
@@ -162,7 +144,6 @@ def generate_kernel(forms):
     return knl
 
 
-# @backend(interface="generate_kernels_per_integral")
 def control_generate_kernels_per_integral(forms):
     """For the control problem forms will have one form for every
     measure. Every form will only contain integrals of one type.
diff --git a/python/dune/codegen/pdelab/basis.py b/python/dune/codegen/pdelab/basis.py
index ec03a182..b3c0d132 100644
--- a/python/dune/codegen/pdelab/basis.py
+++ b/python/dune/codegen/pdelab/basis.py
@@ -28,6 +28,7 @@ from dune.codegen.pdelab.geometry import (component_iname,
                                           )
 from dune.codegen.pdelab.localoperator import (lop_template_ansatz_gfs,
                                                lop_template_test_gfs,
+                                               name_gridfunction_member,
                                                )
 from dune.codegen.tools import (get_pymbolic_basename,
                                 get_pymbolic_indices,
@@ -69,10 +70,13 @@ class BasisMixinBase(object):
     def implement_apply_function_gradient(self, element, restriction, index):
         raise NotImplementedError("Basis Mixins should implement linearization point gradient evaluation")
 
+    def implement_grid_function(self, coeff, restriction, grad):
+        raise NotImplementedError("Basis Mixins should implement grid function evaluation")
+
 
 @basis_mixin("generic")
 class GenericBasisMixin(BasisMixinBase):
-    def lfs_inames(self, element, restriction, number, context):
+    def lfs_inames(self, element, restriction, number, context=""):
         return (lfs_iname(element, restriction, number, context),)
 
     def implement_basis(self, element, restriction, number, context=''):
@@ -210,7 +214,7 @@ class GenericBasisMixin(BasisMixinBase):
         from dune.codegen.tools import maybe_wrap_subscript
         basis = maybe_wrap_subscript(basis, prim.Variable(dimindex))
 
-        # TODO get rif of this
+        # TODO get rid of this
         if get_form_option("blockstructured"):
             from dune.codegen.blockstructured.argument import pymbolic_coefficient
             coeff = pymbolic_coefficient(container, lfs, sub_element, basisindex)
@@ -227,6 +231,43 @@ class GenericBasisMixin(BasisMixinBase):
                     forced_iname_deps_is_final=True,
                     )
 
+    def implement_gridfunction(self, coeff, restriction, grad):
+        name = "coeff{}{}".format(coeff.count(), "_grad" if grad else "")
+        self.define_grid_function(name, coeff, restriction, grad)
+        return prim.Subscript(prim.Variable(name), (0,))
+
+    def define_grid_function(self, name, coeff, restriction, grad):
+        diffOrder = 1 if grad else 0
+
+        gridfunction = name_gridfunction_member(coeff, restriction, diffOrder)
+        bind_gridfunction_to_element(gridfunction, restriction)
+
+        temporary_variable(name,
+                           shape=(1,) + (world_dimension(),) * diffOrder,
+                           decl_method=declare_grid_function_range(gridfunction),
+                           managed=False,
+                           )
+
+        quadpos = self.to_cell(self.quadrature_position())
+        instruction(code="{} = {}({});".format(name, gridfunction, quadpos),
+                    assignees=frozenset({name}),
+                    within_inames=frozenset(self.quadrature_inames()),
+                    within_inames_is_final=True,
+                    )
+
+
+@preamble
+def bind_gridfunction_to_element(gf, restriction):
+    element = name_cell(restriction)
+    return "{}.bind({});".format(gf, element)
+
+
+def declare_grid_function_range(gridfunction):
+    def _decl(name, kernel, decl_info):
+        return "typename decltype({})::Range {};".format(gridfunction, name)
+
+    return _decl
+
 
 @backend(interface="typedef_localbasis")
 @class_member(classtag="operator")
diff --git a/python/dune/codegen/pdelab/localoperator.py b/python/dune/codegen/pdelab/localoperator.py
index 999f5014..db5ca426 100644
--- a/python/dune/codegen/pdelab/localoperator.py
+++ b/python/dune/codegen/pdelab/localoperator.py
@@ -416,10 +416,10 @@ def get_accumulation_info(expr, visitor):
     if visitor.measure == 'exterior_facet':
         restriction = Restriction.POSITIVE
 
-    inames = visitor.interface.lfs_inames(leaf_element,
-                                          restriction,
-                                          expr.number()
-                                          )
+    inames = visitor.lfs_inames(leaf_element,
+                                restriction,
+                                expr.number()
+                                )
 
     return PDELabAccumulationInfo(element=expr.ufl_element(),
                                   element_index=element_index,
diff --git a/python/dune/codegen/ufl/visitor.py b/python/dune/codegen/ufl/visitor.py
index 376fe9d5..b40f330e 100644
--- a/python/dune/codegen/ufl/visitor.py
+++ b/python/dune/codegen/ufl/visitor.py
@@ -176,7 +176,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
             if self.reference_grad:
                 raise CodegenUFLError("Coefficient gradients should not be transformed to reference element")
 
-            return self.interface.pymbolic_gridfunction(o, restriction, self.grad)
+            return self.implement_gridfunction(o, restriction, self.grad)
 
     def variable(self, o):
         # Right now only scalar varibables are supported
@@ -232,7 +232,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
             if index in self.indexmap:
                 return self.indexmap[index]
             else:
-                return Variable(self.interface.name_index(index))
+                raise CodegenUFLError("Index should have been unrolled!")
 
     def multi_index(self, o):
         return tuple(self._index_or_fixed_index(i) for i in o)
-- 
GitLab