From a67b53506f171f63b13e4461d939ae2b4a638550 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Thu, 15 Feb 2018 13:31:46 +0100
Subject: [PATCH] Generate control local operator through seperated interface

The control operator doesn't accumulate in the residual. Instead the
localoperator has a member vector where the value of the derivative is
accumulated.
---
 python/dune/perftool/pdelab/adjoint.py       | 128 +++++++++++++++++++
 python/dune/perftool/pdelab/localoperator.py |  31 ++++-
 2 files changed, 157 insertions(+), 2 deletions(-)
 create mode 100644 python/dune/perftool/pdelab/adjoint.py

diff --git a/python/dune/perftool/pdelab/adjoint.py b/python/dune/perftool/pdelab/adjoint.py
new file mode 100644
index 00000000..f957656c
--- /dev/null
+++ b/python/dune/perftool/pdelab/adjoint.py
@@ -0,0 +1,128 @@
+import logging
+
+import numpy
+
+from loopy import CallMangleInfo
+from loopy.symbolic import FunctionIdentifier
+from loopy.types import NumpyType
+
+import pymbolic.primitives as prim
+
+from dune.perftool.generation import (class_member,
+                                      constructor_parameter,
+                                      function_mangler,
+                                      get_global_context_value,
+                                      global_context,
+                                      globalarg,
+                                      initializer_list,
+                                      template_parameter,
+                                      )
+from dune.perftool.options import (get_form_option,
+                                   )
+from dune.perftool.loopy.target import dtype_floatingpoint
+from dune.perftool.pdelab import PDELabInterface
+from dune.perftool.pdelab.localoperator import (boundary_predicates,
+                                                determine_accumulation_space,
+                                                extract_kernel_from_cache,
+                                                )
+
+
+@template_parameter(classtag="operator")
+def type_dJdm():
+    return "DJDM_VEC"
+
+
+def name_dJdm_constructor_argument(name):
+    _type = type_dJdm()
+    constructor_name = name + "_"
+    constructor_parameter("{}&".format(_type), constructor_name, classtag="operator")
+    return constructor_name
+
+
+@class_member(classtag="operator")
+def define_dJdm_member(name):
+    _type = type_dJdm()
+    param = name_dJdm_constructor_argument(name)
+    initializer_list(name, [param,], classtag="operator")
+    return "{}& {};".format(_type, name)
+
+
+def generate_accumulation_instruction(expr, visitor):
+    accumvar = "dJdm"
+    assignee = prim.Variable(accumvar)
+    shape = ()
+    define_dJdm_member(accumvar)
+    globalarg(accumvar, shape=shape)
+    quad_inames = visitor.interface.quadrature_inames()
+    from dune.perftool.generation import instruction
+    expr = prim.Sum((assignee, expr))
+    instruction(assignee=assignee,
+                expression=expr,
+                forced_iname_deps=frozenset(quad_inames),
+                forced_iname_deps_is_final=True,
+                )
+
+def list_accumulation_infos(expr, visitor):
+    # TODO: This object should probably contain the flat index of the control?
+    return ["control",]
+
+class AdjointInterface(PDELabInterface):
+    def list_accumulation_infos(self, expr, visitor):
+        return list_accumulation_infos(expr, visitor)
+
+    def generate_accumulation_instruction(self, expr, visitor):
+        return generate_accumulation_instruction(expr, visitor)
+
+
+def get_visitor(measure, subdomain_id):
+    interface = AdjointInterface()
+    from dune.perftool.ufl.visitor import UFL2LoopyVisitor
+    return UFL2LoopyVisitor(interface, measure, subdomain_id)
+
+
+def visit_integral(integral):
+    integrand = integral.integrand()
+    measure = integral.integral_type()
+    subdomain_id = integral.subdomain_id()
+
+    # Start the visiting process!
+    visitor = get_visitor(measure, subdomain_id)
+    visitor.accumulate(integrand)
+
+
+def generate_kernel(integrals):
+    logger = logging.getLogger(__name__)
+
+    # Visit all integrals once to collect information (dry-run)!
+    logger.debug('generate_kernel: visit_integrals (dry run)')
+    with global_context(dry_run=True):
+        for integral in integrals:
+            visit_integral(integral)
+
+    # Now perform some checks on what should be done
+    from dune.perftool.sumfact.vectorization import decide_vectorization_strategy
+    logger.debug('generate_kernel: decide_vectorization_strategy')
+    decide_vectorization_strategy()
+
+    # Delete the cache contents and do the real thing!
+    logger.debug('generate_kernel: visit_integrals (no dry run)')
+    from dune.perftool.generation import delete_cache_items
+    delete_cache_items("kernel_default")
+    for integral in integrals:
+        visit_integral(integral)
+    knl = extract_kernel_from_cache("kernel_default")
+    delete_cache_items("kernel_default")
+
+    # Reset the quadrature degree
+    from dune.perftool.sumfact.tabulation import set_quadrature_points
+    set_quadrature_points(None)
+
+    # Clean the cache from any data collected after the dry run
+    delete_cache_items("dryrundata")
+
+    return knl
+
+
+# @backend(interface="generate_kernels_per_integral")
+def adjoint_generate_kernels_per_integral(integrals):
+    yield generate_kernel(integrals)
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 68bf0698..d7162262 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -879,7 +879,31 @@ def generate_jacobian_kernels(form, original_form):
 
 
 def generate_control_kernels(forms):
-    pass
+    # TODO: Implement case of multiple controls
+    assert len(forms) == 1
+    form = forms[0]
+
+    logger = logging.getLogger(__name__)
+    with global_context(form_type='residual'):
+        operator_kernels = {}
+
+        # Generate the necessary residual methods
+        for measure in set(i.integral_type() for i in form.integrals()):
+            logger.info("generate_control_kernels: measure {}".format(measure))
+            with global_context(integral_type=measure):
+                enum_pattern()
+                pattern_baseclass()
+                enum_alpha()
+
+                from dune.perftool.pdelab.signatures import assembler_routine_name
+                with global_context(kernel=assembler_routine_name()):
+                    # TODO: make backend switch for PDELab/sumfact
+                    from dune.perftool.pdelab.adjoint import adjoint_generate_kernels_per_integral
+                    kernel = [k for k in adjoint_generate_kernels_per_integral(form.integrals_by_type(measure))]
+
+            operator_kernels[(measure, 'residual')] = kernel
+
+        return operator_kernels
 
 
 def generate_localoperator_kernels(operator):
@@ -937,6 +961,9 @@ def generate_localoperator_kernels(operator):
         controls = [data.object_by_name[ctrl.strip()] for ctrl in get_form_option("control_variable").split(",")]
         assert len(controls) == 1
 
+        from ufl import action, diff
+        from ufl.classes import Coefficient
+
         # We need to transform numpy ints to python native ints
         def _unravel(flat_index, shape):
             multi_index = np.unravel_index(flat_index, shape)
@@ -948,7 +975,7 @@ def generate_localoperator_kernels(operator):
         coeff = Coefficient(element, count=3)
         for control in controls:
             shape = control.ufl_shape
-            flat_length = np.prod(shape)
+            flat_length = int(np.prod(shape))
             for i in range(flat_length):
                 c = control[_unravel(i, shape)]
                 control_form = diff(original_form, control)
-- 
GitLab