Skip to content
Snippets Groups Projects
Commit 9bda3dfe authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Update adoint control code path

parent 3cb3511c
No related branches found
No related tags found
No related merge requests found
......@@ -55,9 +55,6 @@ class CodegenGlobalOptionsArray(ImmutableRecord):
target_name = CodegenOption(default=None, helpstr="The target name from CMake")
operator_to_build = CodegenOption(default=None, helpstr="The operators from the list that is about to be build now. CMake sets this one!!!")
debug_interpolate_input = CodegenOption(default=False, helpstr="Should the input for printresidual and printmatix be interpolated (instead of random input).")
geometry_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for geometries. Currently implemented mixins: generic, axiparallel, equidistant")
quadrature_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for quadrature. Currently implemented: generic")
basis_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for basis function evaluation. Currently implemented: generic")
# Arguments that are mainly to be set by logic depending on other options
max_vector_width = CodegenOption(default=256, helpstr=None)
......@@ -113,7 +110,7 @@ class CodegenFormOptionsArray(ImmutableRecord):
geometry_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for geometries. Currently implemented mixins: generic, axiparallel, equidistant, sumfact_multilinear, sumfact_axiparallel, sumfact_equidistant")
quadrature_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for quadrature. Currently implemented: generic, sumfact")
basis_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for basis function evaluation. Currently implemented: generic, sumfact")
accumulation_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for accumulation. Currently implemented: generic, sumfact")
accumulation_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for accumulation. Currently implemented: generic, sumfact, control")
enable_volume = CodegenOption(default=True, helpstr="Whether to assemble volume integrals")
enable_skeleton = CodegenOption(default=True, helpstr="Whether to assemble skeleton integrals")
enable_boundary = CodegenOption(default=True, helpstr="Whether to assemble boundary integrals")
......@@ -187,6 +184,9 @@ def process_form_options(opt, form):
accumulation_mixins="sumfact",
)
if opt.control:
opt = opt.copy(accumulation_mixins="control")
if opt.numerical_jacobian:
opt = opt.copy(generate_jacobians=False, generate_jacobian_apply=False)
......
......@@ -20,34 +20,10 @@ class PDELabInterface(object):
# The visitor instance will be registered by its init method
self.visitor = None
#
# TODO: The following ones are actually entirely PDELab independent!
# They should be placed elsewhere and be used directly in the visitor.
#
def component_iname(self, context=None, count=None):
return component_iname(context=context, count=count)
def name_index(self, ind):
return name_index(ind)
#
# Local function space related generator functions
#
def lfs_inames(self, element, restriction, number=None, context=''):
return lfs_inames(element, restriction, number, context)
def initialize_function_spaces(self, expr, visitor):
from dune.codegen.pdelab.spaces import initialize_function_spaces
return initialize_function_spaces(expr, visitor)
#
# Test and trial function related generator functions
#
def pymbolic_gridfunction(self, coeff, restriction, grad):
return pymbolic_gridfunction(coeff, restriction, grad)
#
# Tensor expression related generator functions
#
......
......@@ -8,7 +8,8 @@ from loopy.types import NumpyType
import pymbolic.primitives as prim
from dune.codegen.generation import (class_member,
from dune.codegen.generation import (accumulation_mixin,
class_member,
constructor_parameter,
function_mangler,
get_global_context_value,
......@@ -24,6 +25,8 @@ from dune.codegen.pdelab import PDELabInterface
from dune.codegen.pdelab.localoperator import (boundary_predicates,
determine_accumulation_space,
extract_kernel_from_cache,
GenericAccumulationMixin,
get_visitor,
)
......@@ -61,7 +64,7 @@ def generate_accumulation_instruction(expr, visitor, accumulation_index, number_
expr = prim.Sum((assignee, expr))
from dune.codegen.generation import instruction
quad_inames = visitor.interface.quadrature_inames()
quad_inames = visitor.quadrature_inames()
instruction(assignee=assignee,
expression=expr,
forced_iname_deps=frozenset(quad_inames),
......@@ -69,44 +72,22 @@ def generate_accumulation_instruction(expr, visitor, accumulation_index, number_
)
def list_accumulation_infos(expr, visitor):
return ["control", ]
class ControlInterface(PDELabInterface):
"""Interface for generating the control localoperator
In this case we will not accumulate in the residual vector but use
a class member representing dJdm instead.
"""
def __init__(self, accumulation_index, number_of_controls):
"""Create ControlInterface
Arguments:
----------
accumulation_index: In which component of the dJdm should be accumulated.
number_of_controls: Number of components of dJdm. Needed for creating the member variable.
"""
@accumulation_mixin("control")
class AdjointAccumulationMixin(GenericAccumulationMixin):
def set_adjoint_information(self, accumulation_index, number_of_controls):
self.accumulation_index = accumulation_index
self.number_of_controls = number_of_controls
def list_accumulation_infos(self, expr, visitor):
return list_accumulation_infos(expr, visitor)
def list_accumulation_infos(self, expr):
return ["control"]
def generate_accumulation_instruction(self, expr, visitor):
def generate_accumulation_instruction(self, expr):
return generate_accumulation_instruction(expr,
visitor,
self,
self.accumulation_index,
self.number_of_controls)
def get_visitor(measure, subdomain_id, accumulation_index, number_of_controls):
interface = ControlInterface(accumulation_index, number_of_controls)
from dune.codegen.ufl.visitor import UFL2LoopyVisitor
return UFL2LoopyVisitor(interface, measure, subdomain_id)
def visit_integral(integral, accumulation_index, number_of_controls):
integrand = integral.integrand()
measure = integral.integral_type()
......@@ -114,7 +95,8 @@ def visit_integral(integral, accumulation_index, number_of_controls):
# The visitor needs to know about the current index and the number
# of controls in order to generate the accumulation instruction
visitor = get_visitor(measure, subdomain_id, accumulation_index, number_of_controls)
visitor = get_visitor(measure, subdomain_id)
visitor.set_adjoint_information(accumulation_index, number_of_controls)
# Start the visiting process!
visitor.accumulate(integrand)
......@@ -162,7 +144,6 @@ def generate_kernel(forms):
return knl
# @backend(interface="generate_kernels_per_integral")
def control_generate_kernels_per_integral(forms):
"""For the control problem forms will have one form for every
measure. Every form will only contain integrals of one type.
......
......@@ -28,6 +28,7 @@ from dune.codegen.pdelab.geometry import (component_iname,
)
from dune.codegen.pdelab.localoperator import (lop_template_ansatz_gfs,
lop_template_test_gfs,
name_gridfunction_member,
)
from dune.codegen.tools import (get_pymbolic_basename,
get_pymbolic_indices,
......@@ -69,10 +70,13 @@ class BasisMixinBase(object):
def implement_apply_function_gradient(self, element, restriction, index):
raise NotImplementedError("Basis Mixins should implement linearization point gradient evaluation")
def implement_grid_function(self, coeff, restriction, grad):
raise NotImplementedError("Basis Mixins should implement grid function evaluation")
@basis_mixin("generic")
class GenericBasisMixin(BasisMixinBase):
def lfs_inames(self, element, restriction, number, context):
def lfs_inames(self, element, restriction, number, context=""):
return (lfs_iname(element, restriction, number, context),)
def implement_basis(self, element, restriction, number, context=''):
......@@ -210,7 +214,7 @@ class GenericBasisMixin(BasisMixinBase):
from dune.codegen.tools import maybe_wrap_subscript
basis = maybe_wrap_subscript(basis, prim.Variable(dimindex))
# TODO get rif of this
# TODO get rid of this
if get_form_option("blockstructured"):
from dune.codegen.blockstructured.argument import pymbolic_coefficient
coeff = pymbolic_coefficient(container, lfs, sub_element, basisindex)
......@@ -227,6 +231,43 @@ class GenericBasisMixin(BasisMixinBase):
forced_iname_deps_is_final=True,
)
def implement_gridfunction(self, coeff, restriction, grad):
name = "coeff{}{}".format(coeff.count(), "_grad" if grad else "")
self.define_grid_function(name, coeff, restriction, grad)
return prim.Subscript(prim.Variable(name), (0,))
def define_grid_function(self, name, coeff, restriction, grad):
diffOrder = 1 if grad else 0
gridfunction = name_gridfunction_member(coeff, restriction, diffOrder)
bind_gridfunction_to_element(gridfunction, restriction)
temporary_variable(name,
shape=(1,) + (world_dimension(),) * diffOrder,
decl_method=declare_grid_function_range(gridfunction),
managed=False,
)
quadpos = self.to_cell(self.quadrature_position())
instruction(code="{} = {}({});".format(name, gridfunction, quadpos),
assignees=frozenset({name}),
within_inames=frozenset(self.quadrature_inames()),
within_inames_is_final=True,
)
@preamble
def bind_gridfunction_to_element(gf, restriction):
element = name_cell(restriction)
return "{}.bind({});".format(gf, element)
def declare_grid_function_range(gridfunction):
def _decl(name, kernel, decl_info):
return "typename decltype({})::Range {};".format(gridfunction, name)
return _decl
@backend(interface="typedef_localbasis")
@class_member(classtag="operator")
......
......@@ -416,10 +416,10 @@ def get_accumulation_info(expr, visitor):
if visitor.measure == 'exterior_facet':
restriction = Restriction.POSITIVE
inames = visitor.interface.lfs_inames(leaf_element,
restriction,
expr.number()
)
inames = visitor.lfs_inames(leaf_element,
restriction,
expr.number()
)
return PDELabAccumulationInfo(element=expr.ufl_element(),
element_index=element_index,
......
......@@ -176,7 +176,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
if self.reference_grad:
raise CodegenUFLError("Coefficient gradients should not be transformed to reference element")
return self.interface.pymbolic_gridfunction(o, restriction, self.grad)
return self.implement_gridfunction(o, restriction, self.grad)
def variable(self, o):
# Right now only scalar varibables are supported
......@@ -232,7 +232,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
if index in self.indexmap:
return self.indexmap[index]
else:
return Variable(self.interface.name_index(index))
raise CodegenUFLError("Index should have been unrolled!")
def multi_index(self, o):
return tuple(self._index_or_fixed_index(i) for i in o)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment