Skip to content
Snippets Groups Projects
constraints.py 9 KiB
Newer Older
from dune.codegen.generation import (class_member,
                                     get_counted_variable,
Dominic Kempf's avatar
Dominic Kempf committed
                                     global_context,
                                     include_file,
                                     preamble,
                                     )
Dominic Kempf's avatar
Dominic Kempf committed
from dune.codegen.pdelab.driver import (FEM_name_mangling,
Dominic Kempf's avatar
Dominic Kempf committed
                                        get_trial_element,
                                        )
Dominic Kempf's avatar
Dominic Kempf committed
from dune.codegen.pdelab.driver.gridfunctionspace import (name_gfs,
Dominic Kempf's avatar
Dominic Kempf committed
                                                          name_leafview,
                                                          name_trial_gfs,
                                                          type_leafview,
Dominic Kempf's avatar
Dominic Kempf committed
                                                          type_range,
                                                          type_trial_gfs,
                                                          preprocess_leaf_data,
                                                          )
from ufl.classes import Expr
from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement


def name_assembled_constraints():
    """Return the name of the trial GFS constraints container.

    Makes sure the container definition and the constraints assembly call are
    both registered for the driver before handing out the name.
    """
    container = name_constraintscontainer()
    # Both calls only register generated code; their return values are not needed.
    define_constraintscontainer(container)
    assemble_constraints(container)
    return container


def has_dirichlet_constraints(is_dirichlet):
    """Tell whether ``is_dirichlet`` describes at least one Dirichlet constraint.

    Tuples and lists (possibly nested) are searched recursively and succeed as
    soon as one entry does. A UFL ``Expr`` always counts as a constraint; any
    other value is judged by its truthiness.
    """
    if isinstance(is_dirichlet, (tuple, list)):
        for entry in is_dirichlet:
            if has_dirichlet_constraints(entry):
                return True
        return False
    # Short-circuit keeps bool() from ever being applied to a UFL expression.
    return isinstance(is_dirichlet, Expr) or bool(is_dirichlet)


@preamble(section="constraints", kernel="driver_block")
def assemble_constraints(name):
    """Emit the driver statement that assembles constraints into container ``name``.

    If the trial element carries Dirichlet constraints, the bctype grid function
    is passed as the constraints parameter. Otherwise, the two-argument form
    (GFS, container) is emitted, so the container is still assembled.
    """
    element = get_trial_element()
    gfs = name_trial_gfs()
    is_dirichlet = preprocess_leaf_data(element, "is_dirichlet")
    if has_dirichlet_constraints(is_dirichlet):
        bctype_function = name_bctype_grid_function(element, is_dirichlet)
        return "Dune::PDELab::constraints(*{}, *{}, *{});".format(bctype_function,
                                                                  gfs,
                                                                  name,)
    # Fallback for the no-Dirichlet case: this must live outside the if-branch,
    # otherwise it is unreachable and the preamble yields None.
    return "Dune::PDELab::constraints(*{}, *{});".format(gfs, name)


#
# Bctype lambda
#


def bctype_lambda(func, local=False):
    """Return C++ source for a lambda ``(is, xl)`` describing the Dirichlet boundary.

    ``None`` is treated as ``0.0``. An int or float yields a lambda returning that
    value as a floating point literal. A UFL ``Expr`` is translated with
    ``ufl_to_code`` and spliced into the lambda body. ``local`` is not used.

    Raises ValueError for any other kind of input.
    """
    from ufl.classes import Expr
    value = 0.0 if func is None else func
    if isinstance(value, (int, float)):
        return "[&](const auto& is, const auto& xl){{ return {}; }}".format(float(value))
    if not isinstance(value, Expr):
        raise ValueError("Expression not understood")
    from dune.codegen.pdelab.driver.visitor import ufl_to_code
    # NOTE(review): no 'return' is added here, unlike the numeric branch;
    # presumably ufl_to_code emits the return statement itself - confirm.
    return "[&](const auto& is, const auto& xl){{ {}; }}".format(ufl_to_code(value))


#
# Bctype function
#
# std::function that will be used to create the bctype grid function


@class_member(classtag="driver_block")
def typedef_bctype_function(name):
    """Declare ``name`` as a ``std::function`` returning bool.

    The function takes an intersection of the leaf view and a local coordinate
    on that intersection.
    """
    intersection = "typename {}::Intersection".format(type_leafview())
    local_coordinate = "{}::LocalCoordinate".format(intersection)
    return "using {} = std::function<bool({}, {})>;".format(name, intersection, local_coordinate)


def type_bctype_function():
    """Return the C++ alias name of the bctype std::function, registering the alias."""
    alias = "BctypeFunction"
    typedef_bctype_function(alias)
    return alias


@class_member(classtag="driver_block")
def declare_bctype_function(name, element, is_dirichlet):
    """Declare the driver member ``name`` as a shared pointer to the bctype std::function.

    ``element`` and ``is_dirichlet`` are not used in the body; presumably they
    are kept so the decorator distinguishes calls per element/data - TODO confirm.
    """
    return "std::shared_ptr<{}> {};".format(type_bctype_function(), name)


@preamble(section="constraints", kernel="driver_block")
def define_bctype_function(name, element, is_dirichlet):
    """Emit the initialization of the bctype std::function member ``name``.

    The member is declared first. It is then set to a shared pointer wrapping
    the lambda produced by ``bctype_lambda(is_dirichlet)``. The preamble is
    placed in the "constraints" section like every other constraints snippet
    (the section name was previously misspelled "constriants").
    """
    declare_bctype_function(name, element, is_dirichlet)
    bctype_function_type = type_bctype_function()
    bct_lambda = bctype_lambda(is_dirichlet)
    return "{} = std::make_shared<{}> ({});".format(name, bctype_function_type, bct_lambda)


def name_bctype_function(element, is_dirichlet):
    """Pick a fresh counted name for a bctype std::function, emit its definition, return it."""
    function_name = get_counted_variable("bctype_function")
    define_bctype_function(function_name, element, is_dirichlet)
    return function_name


#
# Bctype Grid Function
#
# GridFunction that returns 1 if the boundary is constrained (else 0)


@class_member(classtag="driver_block")
def typedef_bctype_grid_function(name):
    """Alias ``name`` to PDELab's LocalCallableToBoundaryConditionAdapter.

    The adapter is instantiated with the bctype std::function type.
    """
    template = "using {} = Dune::PDELab::LocalCallableToBoundaryConditionAdapter<{}>;"
    return template.format(name, type_bctype_function())


def type_bctype_grid_function():
    """Return the C++ alias name of the bctype grid function, registering the alias."""
    alias = "BctypeGridFunction"
    typedef_bctype_grid_function(alias)
    return alias


@class_member(classtag="driver_block")
def declare_bctype_grid_function(name):
    """Declare the driver member ``name`` as a shared pointer to the bctype grid function."""
    return "std::shared_ptr<{}> {};".format(type_bctype_grid_function(), name)


@preamble(section="constraints", kernel="driver_block")
def define_bctype_grid_function(name, element, is_dirichlet):
    """Emit the initialization of the bctype grid function member ``name``.

    The adapter is constructed from a dereferenced, freshly defined bctype
    std::function built for ``element``/``is_dirichlet``.
    """
    declare_bctype_grid_function(name)
    adapter_type = type_bctype_grid_function()
    wrapped = name_bctype_function(element, is_dirichlet)
    return "{} = std::make_shared<{}> (*{});".format(name, adapter_type, wrapped)


def name_bctype_grid_function(element, is_dirichlet):
    """Return the name of the constraints-parameter object for ``element``.

    Leaf elements (FiniteElement / TensorProductElement) get a freshly named
    bctype grid function built from their single entry ``is_dirichlet[0]``.
    Mixed elements recurse into their sub elements. Each sub element receives
    the next ``value_size()`` entries of ``is_dirichlet``. The children are
    then combined into composite constraints parameters, named by joining
    the children's names with underscores.
    """
    # Note: Previously, there was a separate code branch for VectorElement here,
    #       which was implemented through PDELab's Power constraints concept.
    #       However this completely fails if you have different constraints for
    #       the children of a VectorElement. We therefore omit this branch completely.
    if isinstance(element, MixedElement):
        offset = 0
        children = []
        for subel in element.sub_elements():
            size = subel.value_size()
            children.append(name_bctype_grid_function(subel, is_dirichlet[offset:offset + size]))
            offset += size
        name = "_".join(children)
        define_composite_constraints_parameters(name, element, tuple(children))
        return name

    assert isinstance(element, (FiniteElement, TensorProductElement))
    name = get_counted_variable("bctype_grid_function")
    define_bctype_grid_function(name, element, is_dirichlet[0])
    # The leaf name must be returned: callers join it into composite names and
    # dereference it in the generated constraints call.
    return name


#
# Composite constraints parameter
#
@class_member(classtag="driver_block")
def typedef_composite_constraints_parameters(name, element, gfs_tuple):
    """Alias ``name`` to CompositeConstraintsParameters over the children's types.

    ``element`` must be mixed, with one entry of ``gfs_tuple`` per sub element.
    A mixed sub element contributes its nested composite parameters type. A leaf
    sub element contributes the bctype grid function type.
    """
    assert isinstance(element, MixedElement)
    sub_elements = element.sub_elements()
    assert len(sub_elements) == len(gfs_tuple)

    def child_type(subel, subgfs):
        # Resolve the C++ type contributed by a single child.
        if isinstance(subel, MixedElement):
            return type_composite_constriants_parameters(subel, subgfs)
        assert isinstance(subel, (FiniteElement, TensorProductElement))
        return type_bctype_grid_function()

    child_types = [child_type(subel, subgfs) for subel, subgfs in zip(sub_elements, gfs_tuple)]
    return "using {} = Dune::PDELab::CompositeConstraintsParameters<{}>;".format(name,
                                                                                 ", ".join(child_types))


def type_composite_constriants_parameters(element, gfs_tuple):
    """Return the C++ type name of composite constraints parameters.

    ``gfs_tuple`` is either a single name (str) or a sequence of names. The
    result is "CompositeConstraintsParameters_" followed by the names joined
    with underscores. The alias is only registered when the number of names
    equals the number of sub elements of ``element``. Presumably, a single
    joined name stands for a nested composite whose alias is registered
    elsewhere - confirm.
    """
    names = (gfs_tuple,) if isinstance(gfs_tuple, str) else gfs_tuple
    type_name = "CompositeConstraintsParameters_" + "_".join(names)
    if len(element.sub_elements()) == len(names):
        typedef_composite_constraints_parameters(type_name, element, names)
    return type_name


@class_member(classtag="driver_block")
def declare_composite_constraints_parameter(name, element, gfs_tuple):
    """Declare the driver member ``name`` as a shared pointer to the composite parameters."""
    pointee = type_composite_constriants_parameters(element, gfs_tuple)
    return "std::shared_ptr<{}> {};".format(pointee, name)


@preamble(section="constraints", kernel="driver_block")
def define_composite_constraints_parameters(name, element, gfs_tuple):
    """Emit the construction of the composite constraints parameters member ``name``.

    Pulls in PDELab's constraintsparameters header and declares the member. It is
    then built from the dereferenced child objects named in ``gfs_tuple``.
    """
    include_file('dune/pdelab/constraints/common/constraintsparameters.hh', filetag='driver')
    declare_composite_constraints_parameter(name, element, gfs_tuple)
    composite_type = type_composite_constriants_parameters(element, gfs_tuple)
    ctor_args = ', '.join(['*{}'.format(child) for child in gfs_tuple])
    return "{} = std::make_shared<{}>({});".format(name, composite_type, ctor_args)

#
# Constraint container
#
@class_member(classtag="driver_block")
def typedef_constraintscontainer(name):
    """Alias ``name`` to the trial GFS's ConstraintsContainer type for the range type."""
    # Arguments are evaluated left to right: trial GFS type first, then range type.
    return "using {} = typename {}::template ConstraintsContainer<{}>::Type;".format(
        name, type_trial_gfs(), type_range())


def type_constraintscontainer():
    """Return the constraints container type alias (trial GFS type + "_CC"), registering it."""
    alias = "{}_CC".format(type_trial_gfs())
    typedef_constraintscontainer(alias)
    return alias


@class_member(classtag="driver_block")
def declare_constraintscontainer(name):
    """Declare the driver member ``name`` as a shared pointer to the constraints container."""
    return "std::shared_ptr<{}> {};".format(type_constraintscontainer(), name)


@preamble(section="constraints", kernel="driver_block")
def define_constraintscontainer(name):
    """Emit creation of constraints container ``name``, followed by a call to clear it."""
    declare_constraintscontainer(name)
    container_type = type_constraintscontainer()
    create = "{} = std::make_shared<{}>();".format(name, container_type)
    clear = "{}->clear();".format(name)
    return [create, clear]


def name_constraintscontainer():
    """Return the name of the trial GFS constraints container, ensuring it is defined.

    The name is the trial GFS name with a "_cc" suffix. The trial element and
    its "is_dirichlet" data used to be computed here, but were never used.
    The code that needs them (``assemble_constraints``) fetches them itself.
    """
    name = "{}_cc".format(name_trial_gfs())
    define_constraintscontainer(name)
    return name