From 5a60480d7909cbbd7494de98651b95e2f005cf20 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Thu, 15 Feb 2018 16:09:14 +0100
Subject: [PATCH] First version for multiple controls

---
 python/dune/perftool/pdelab/adjoint.py       | 44 +++++++++++++-------
 python/dune/perftool/pdelab/localoperator.py | 26 +++---------
 2 files changed, 34 insertions(+), 36 deletions(-)

diff --git a/python/dune/perftool/pdelab/adjoint.py b/python/dune/perftool/pdelab/adjoint.py
index f957656c..27786e30 100644
--- a/python/dune/perftool/pdelab/adjoint.py
+++ b/python/dune/perftool/pdelab/adjoint.py
@@ -47,10 +47,10 @@ def define_dJdm_member(name):
     return "{}& {};".format(_type, name)
 
 
-def generate_accumulation_instruction(expr, visitor):
+def generate_accumulation_instruction(expr, visitor, accumulation_index, number_of_controls):
     accumvar = "dJdm"
-    assignee = prim.Variable(accumvar)
-    shape = ()
+    assignee = prim.Subscript(prim.Variable(accumvar), accumulation_index)
+    shape = (number_of_controls,)
     define_dJdm_member(accumvar)
     globalarg(accumvar, shape=shape)
     quad_inames = visitor.interface.quadrature_inames()
@@ -63,41 +63,48 @@ def generate_accumulation_instruction(expr, visitor):
                 )
 
 def list_accumulation_infos(expr, visitor):
-    # TODO: This object should probably contain the flat index of the control?
     return ["control",]
 
 class AdjointInterface(PDELabInterface):
+    def __init__(self, accumulation_index, number_of_controls):
+        self.accumulation_index = accumulation_index
+        self.number_of_controls = number_of_controls
+
     def list_accumulation_infos(self, expr, visitor):
         return list_accumulation_infos(expr, visitor)
 
     def generate_accumulation_instruction(self, expr, visitor):
-        return generate_accumulation_instruction(expr, visitor)
+        return generate_accumulation_instruction(expr,
+                                                 visitor,
+                                                 self.accumulation_index,
+                                                 self.number_of_controls)
 
 
-def get_visitor(measure, subdomain_id):
-    interface = AdjointInterface()
+def get_visitor(measure, subdomain_id, accumulation_index, number_of_controls):
+    interface = AdjointInterface(accumulation_index, number_of_controls)
     from dune.perftool.ufl.visitor import UFL2LoopyVisitor
     return UFL2LoopyVisitor(interface, measure, subdomain_id)
 
 
-def visit_integral(integral):
+def visit_integral(integral, accumulation_index, number_of_controls):
     integrand = integral.integrand()
     measure = integral.integral_type()
     subdomain_id = integral.subdomain_id()
 
     # Start the visiting process!
-    visitor = get_visitor(measure, subdomain_id)
+    visitor = get_visitor(measure, subdomain_id, accumulation_index, number_of_controls)
     visitor.accumulate(integrand)
 
 
-def generate_kernel(integrals):
+def generate_kernel(forms):
     logger = logging.getLogger(__name__)
 
     # Visit all integrals once to collect information (dry-run)!
     logger.debug('generate_kernel: visit_integrals (dry run)')
     with global_context(dry_run=True):
-        for integral in integrals:
-            visit_integral(integral)
+        for i, form in enumerate(forms):
+            for integral in form:
+                visit_integral(integral, i, len(forms))
 
     # Now perform some checks on what should be done
     from dune.perftool.sumfact.vectorization import decide_vectorization_strategy
@@ -108,8 +115,9 @@ def generate_kernel(integrals):
     logger.debug('generate_kernel: visit_integrals (no dry run)')
     from dune.perftool.generation import delete_cache_items
     delete_cache_items("kernel_default")
-    for integral in integrals:
-        visit_integral(integral)
+    for i, form in enumerate(forms):
+        for integral in form:
+            visit_integral(integral, i, len(forms))
     knl = extract_kernel_from_cache("kernel_default")
     delete_cache_items("kernel_default")
 
@@ -124,5 +132,9 @@ def generate_kernel(integrals):
 
 
 # @backend(interface="generate_kernels_per_integral")
-def adjoint_generate_kernels_per_integral(integrals):
-    yield generate_kernel(integrals)
+def adjoint_generate_kernels_per_integral(forms):
+    """For the control problem forms will have one form for every
+    measure. Every form will only contain integrals of one type.
+
+    """
+    yield generate_kernel(forms)
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index d7162262..ce685263 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -879,16 +879,12 @@ def generate_jacobian_kernels(form, original_form):
 
 
 def generate_control_kernels(forms):
-    # TODO: Implement case of multiple controls
-    assert len(forms) == 1
-    form = forms[0]
-
     logger = logging.getLogger(__name__)
     with global_context(form_type='residual'):
         operator_kernels = {}
 
         # Generate the necessary residual methods
-        for measure in set(i.integral_type() for i in form.integrals()):
+        for measure in set(i.integral_type() for form in forms for i in form.integrals()):
             logger.info("generate_control_kernels: measure {}".format(measure))
             with global_context(integral_type=measure):
                 enum_pattern()
@@ -899,7 +895,9 @@ def generate_control_kernels(forms):
                 with global_context(kernel=assembler_routine_name()):
                     # TODO: make backend switch for PDELab/sumfact
                     from dune.perftool.pdelab.adjoint import adjoint_generate_kernels_per_integral
-                    kernel = [k for k in adjoint_generate_kernels_per_integral(form.integrals_by_type(measure))]
+
+                    forms_measure = [form.integrals_by_type(measure) for form in forms]
+                    kernel = [k for k in adjoint_generate_kernels_per_integral(forms_measure)]
 
             operator_kernels[(measure, 'residual')] = kernel
 
@@ -978,28 +976,16 @@ def generate_localoperator_kernels(operator):
             flat_length = int(np.prod(shape))
             for i in range(flat_length):
                 c = control[_unravel(i, shape)]
-                control_form = diff(original_form, control)
+                control_form = diff(original_form, c)
                 control_form = action(control_form, coeff)
                 objective = data.object_by_name[get_form_option("objective_function")]
-                objective_gradient = diff(objective, control)
+                objective_gradient = diff(objective, c)
                 control_form = control_form + objective_gradient
                 forms.append(preprocess_form(control_form).preprocessed_form)
 
         # Used to create local operator default settings
         form = preprocess_form(original_form).preprocessed_form
 
-        # control = data.object_by_name[get_form_option("control_variable")]
-        # assert control.ufl_shape is ()
-
-        # from ufl import diff, replace
-        # from ufl.classes import Coefficient
-        # control_form = diff(original_form, control)
-        # # element = control_form.coefficients()[0].ufl_element()
-        # # coeff = Coefficient(element, count=3)
-        # # control_form = replace(control_form, {control_form.coefficients()[0]: coeff})
-        # original_form = control_form
-        # form = preprocess_form(control_form).preprocessed_form
-
     else:
         form = preprocess_form(original_form).preprocessed_form
 
-- 
GitLab