Skip to content
Snippets Groups Projects
switch.py 5.28 KiB
Newer Older
""" boundary and skeleton integrals come in variants in sum factorization - implement the switch! """

from dune.perftool.generation import (backend,
                                      get_global_context_value,
                                      global_context,
                                      )
from dune.perftool.pdelab.geometry import world_dimension
from dune.perftool.pdelab.localoperator import generate_kernel
from dune.perftool.pdelab.signatures import (assembly_routine_args,
                                             assembly_routine_signature,
                                             kernel_name,
                                             )
from dune.perftool.cgen.clazz import ClassMember


@backend(interface="generate_kernels_per_integral", name="sumfact")
def generate_kernels_per_integral(integrals):
    """Yield the sum factorization kernels needed for ``integrals``.

    The integral type is read from the global generation context:

    * ``"cell"``: a single kernel.
    * ``"exterior_facet"``: one kernel per ``(facedir_s, facemod_s)`` pair,
      each generated with that pair installed in the global context, then
      the exterior facet switch member.
    * ``"interior_facet"``: one kernel per
      ``(facedir_s, facemod_s, facedir_n, facemod_n)`` combination, then the
      interior facet switch member.

    ``facedir`` runs over ``range(dim)`` (``dim`` is the form's geometric
    dimension) and ``facemod`` over ``0, 1``. Any other integral type yields
    nothing.
    """
    geo_dim = get_global_context_value("formdata").geometric_dimension
    integral_type = get_global_context_value("integral_type")

    # Every (direction, side) combination of a single element, in the order
    # facedir-major / facemod-minor.
    facet_variants = [(direction, side) for direction in range(geo_dim) for side in range(2)]

    if integral_type == "cell":
        yield generate_kernel(integrals)
    elif integral_type == "exterior_facet":
        for direction, side in facet_variants:
            with global_context(facedir_s=direction, facemod_s=side):
                yield generate_kernel(integrals)
        # Dispatch member selecting one of the kernels above at runtime
        yield generate_exterior_facet_switch()
    elif integral_type == "interior_facet":
        for direction_s, side_s in facet_variants:
            for direction_n, side_n in facet_variants:
                with global_context(facedir_s=direction_s,
                                    facemod_s=side_s,
                                    facedir_n=direction_n,
                                    facemod_n=side_n,
                                    ):
                    yield generate_kernel(integrals)
        # Dispatch member selecting one of the kernels above at runtime
        yield generate_interior_facet_switch()


def get_kernel_name(facedir_s=None, facemod_s=None, facedir_n=None, facemod_n=None):
    """Return the kernel name for one facet variant.

    The four variant values are placed into the global generation context
    only while ``kernel_name()`` is evaluated; the context is restored before
    returning.
    """
    variant = dict(facedir_s=facedir_s,
                   facemod_s=facemod_s,
                   facedir_n=facedir_n,
                   facemod_n=facemod_n,
                   )
    with global_context(**variant):
        name = kernel_name()
    return name


def generate_exterior_facet_switch():
    """Build the class member that dispatches an exterior facet to its kernel.

    One kernel exists per ``(facedir_s, facemod_s)`` pair, with ``facedir_s``
    in ``range(dim)`` and ``facemod_s`` in ``(0, 1)``. The generated C++ body
    switches on ``ig.indexInInside()`` and calls the matching kernel with the
    assembly routine arguments.

    DUNE numbers the facets of a cube reference element so that facets
    ``2 * i`` and ``2 * i + 1`` are orthogonal to direction ``i``. The facet
    index of a variant is therefore ``2 * facedir + facemod``, which covers
    ``0 .. 2 * dim - 1`` without gaps.

    Returns:
        A ``ClassMember`` holding the assembly routine signature followed by
        the generated switch body.
    """
    signature = assembly_routine_signature()
    args = ", ".join(a for _, a in assembly_routine_args())
    dim = world_dimension()

    block = ["{",
             "  size_t variant = ig.indexInInside();",
             "  switch(variant)",
             "  {",
             ]

    for facedir_s in range(dim):
        for facemod_s in range(2):
            # Cube facet numbering: stride 2 per direction. A stride of ``dim``
            # would leave gaps for dim == 3, so some facets would match no case.
            variant = 2 * facedir_s + facemod_s
            name = get_kernel_name(facedir_s=facedir_s, facemod_s=facemod_s)
            block.append("    case {}: {}({}); break;".format(variant, name, args))

    block.append("  }")
    block.append("}")

    return ClassMember(signature + block)


def generate_interior_facet_switch():
    """Build the class member that dispatches an interior facet to its kernel.

    One kernel exists per combination of ``(facedir_s, facemod_s)`` for the
    inside element and ``(facedir_n, facemod_n)`` for the outside element.
    The generated C++ body computes

        variant = ig.indexInOutside() + (2 * dim) * ig.indexInInside()

    and switches on it.

    DUNE numbers the facets of a cube reference element so that facets
    ``2 * i`` and ``2 * i + 1`` are orthogonal to direction ``i``. A facet
    index is therefore ``2 * facedir + facemod``, and each element has
    ``2 * dim`` facets. The runtime ``variant`` and the case labels use this
    same encoding, so every reachable ``variant`` has exactly one case.

    Returns:
        A ``ClassMember`` holding the assembly routine signature followed by
        the generated switch body.
    """
    signature = assembly_routine_signature()
    args = ", ".join(a for _, a in assembly_routine_args())
    dim = world_dimension()
    # Number of facets of a dim-dimensional cube. This is the stride of the
    # inside facet index; it must not be hard-coded (6 is only right in 3D).
    num_facets = 2 * dim

    block = ["{",
             "  size_t variant = ig.indexInOutside() + {} * ig.indexInInside();".format(num_facets),
             "  switch(variant)",
             "  {",
             ]

    for facedir_s in range(dim):
        for facemod_s in range(2):
            for facedir_n in range(dim):
                for facemod_n in range(2):
                    inside = 2 * facedir_s + facemod_s
                    outside = 2 * facedir_n + facemod_n
                    variant = inside * num_facets + outside
                    name = get_kernel_name(facedir_s=facedir_s,
                                           facemod_s=facemod_s,
                                           facedir_n=facedir_n,
                                           facemod_n=facemod_n,
                                           )
                    block.append("    case {}: {}({}); break;".format(variant, name, args))

    block.append("  }")
    block.append("}")

    return ClassMember(signature + block)


def _get_facet_context_value(restriction, quantity):
    """Return the ``<quantity>_s`` or ``<quantity>_n`` global-context value for ``restriction``.

    ``quantity`` is ``"facedir"`` or ``"facemod"``.

    * ``Restriction.NEGATIVE``, and any restriction inside an exterior facet
      integral, reads ``<quantity>_s``. Exterior facet kernels only set the
      ``_s`` values.
    * ``Restriction.POSITIVE`` reads ``<quantity>_n``.
    * ``Restriction.NONE`` (outside exterior facet integrals) yields ``None``.

    Raises:
        AssertionError: for any other restriction value. It is raised
            explicitly rather than via ``assert`` so the check survives
            ``python -O``.
    """
    # Imported locally, as in the original functions
    from dune.perftool.pdelab.restriction import Restriction
    if restriction == Restriction.NEGATIVE or get_global_context_value("integral_type") == "exterior_facet":
        return get_global_context_value(quantity + "_s")
    if restriction == Restriction.POSITIVE:
        return get_global_context_value(quantity + "_n")
    if restriction == Restriction.NONE:
        return None
    raise AssertionError("Unknown restriction {!r}".format(restriction))


def get_facedir(restriction):
    """Return the facet direction of the current kernel variant for ``restriction``.

    Returns ``None`` for ``Restriction.NONE`` outside exterior facet integrals.
    Raises AssertionError for an unknown restriction.
    """
    return _get_facet_context_value(restriction, "facedir")


def get_facemod(restriction):
    """Return the facet side (0 or 1) of the current kernel variant for ``restriction``.

    Returns ``None`` for ``Restriction.NONE`` outside exterior facet integrals.
    Raises AssertionError for an unknown restriction.
    """
    return _get_facet_context_value(restriction, "facemod")