Skip to content
Snippets Groups Projects
gridoperator.py 7.93 KiB
Newer Older
from dune.codegen.generation import (class_member,
                                     get_global_context_value,
Dominic Kempf's avatar
Dominic Kempf committed
                                     include_file,
                                     preamble,
                                     )
Dominic Kempf's avatar
Dominic Kempf committed
from dune.codegen.pdelab.driver import (get_cell,
Dominic Kempf's avatar
Dominic Kempf committed
                                        get_dimension,
                                        get_test_element,
                                        get_trial_element,
                                        isQuadrilateral,
                                        name_initree,
                                        )
Dominic Kempf's avatar
Dominic Kempf committed
from dune.codegen.pdelab.driver.constraints import (name_assembled_constraints,
Dominic Kempf's avatar
Dominic Kempf committed
                                                    type_constraintscontainer,
                                                    )
from dune.codegen.pdelab.driver.driverblock import (name_driver_block,
                                                    type_driver_block,
                                                    )
Dominic Kempf's avatar
Dominic Kempf committed
from dune.codegen.pdelab.driver.gridfunctionspace import (name_test_gfs,
Dominic Kempf's avatar
Dominic Kempf committed
                                                          name_trial_gfs,
                                                          preprocess_leaf_data,
                                                          type_domainfield,
                                                          type_range,
                                                          type_test_gfs,
                                                          type_trial_gfs,
                                                          )
Dominic Kempf's avatar
Dominic Kempf committed
from dune.codegen.pdelab.localoperator import localoperator_basename
from dune.codegen.options import get_form_option
@class_member(classtag="driver_block")
def typedef_gridoperator(name, form_ident):
    """Return the C++ ``using`` declaration aliasing the grid operator type.

    Depending on the ``fastdg`` form option the alias refers to either
    ``Dune::PDELab::FastDGGridOperator`` or ``Dune::PDELab::GridOperator``;
    the matching header is requested through ``include_file``.

    :param name: alias name to declare
    :param form_ident: identifier of the form, forwarded to
        ``type_localoperator``
    :returns: the C++ alias declaration as a string
    :raises CodegenCodegenError: if ``fastdg`` is set while ``sumfact`` is not
    """
    ugfs = type_trial_gfs()
    vgfs = type_test_gfs()
    lop = type_localoperator(form_ident)
    cc = type_constraintscontainer()
    mb = type_matrixbackend()
    df = type_domainfield()
    r = type_range()

    if get_form_option("fastdg"):
        if not get_form_option("sumfact"):
            # The exception class is not among the imports visible at the top
            # of this file, so bring it into scope right where it is needed.
            # NOTE(review): assumes dune.codegen.error provides this class.
            from dune.codegen.error import CodegenCodegenError
            raise CodegenCodegenError("FastDGGridOperator is only implemented for sumfactorization.")
        header = "dune/pdelab/gridoperator/fastdg.hh"
        go_class = "FastDGGridOperator"
    else:
        header = "dune/pdelab/gridoperator/gridoperator.hh"
        go_class = "GridOperator"

    include_file(header, filetag="driver")
    # Both grid operator variants take the same template argument list; the
    # range type and the constraints container type are each passed twice.
    return "using {} = Dune::PDELab::{}<{}, {}, {}, {}, {}, {}, {}, {}, {}>;".format(
        name, go_class, ugfs, vgfs, lop, mb, df, r, r, cc, cc)


def type_gridoperator(form_ident):
    """Return the grid operator type alias ``GO_<form_ident>``.

    ``typedef_gridoperator`` is invoked for the alias before it is returned;
    its return value is not used here.
    """
    alias = "GO_%s" % (form_ident,)
    typedef_gridoperator(alias, form_ident)
    return alias


@class_member(classtag="driver_block")
def declare_gridoperator(name, form_ident):
    """Return the C++ declaration of a shared pointer to the grid operator.

    :param name: name of the declared variable
    :param form_ident: identifier of the form, used to look up the type alias
    """
    return "std::shared_ptr<{go}> {var};".format(
        go=type_gridoperator(form_ident),
        var=name,
    )


@preamble(section="gridoperator", kernel="driver_block")
def define_gridoperator(name, form_ident):
    """Return the C++ statements that construct the grid operator.

    The list holds the ``std::make_shared`` assignment followed by two
    ``std::cout`` lines printing the sizes of the trial function space and of
    the assembled constraints.

    :param name: name of the grid operator variable
    :param form_ident: identifier of the form
    :raises NotImplementedError: if trial and test function space names differ
    """
    declare_gridoperator(name, form_ident)
    go_type = type_gridoperator(form_ident)

    trial_space = name_trial_gfs()
    test_space = name_test_gfs()
    if not trial_space == test_space:
        raise NotImplementedError("Non-Galerkin methods currently not supported!")

    constraints = name_assembled_constraints()
    local_op = name_localoperator(form_ident)
    backend = name_matrixbackend()

    # The constraints object is passed twice: after the trial space and after
    # the test space.
    construct = "{} = std::make_shared<{}>(*{}, *{}, *{}, *{}, *{}, *{});".format(
        name, go_type, trial_space, constraints, test_space, constraints, local_op, backend)
    report_gfs = 'std::cout << "gfs with " << {}->size() << " dofs generated  "<< std::endl;'.format(trial_space)
    report_cc = 'std::cout << "cc with " << {}->size() << " dofs generated  "<< std::endl;'.format(constraints)
    return [construct, report_gfs, report_cc]
def name_gridoperator(form_ident):
    """Return the grid operator variable name ``go_<form_ident>``.

    Before returning, the definition and the ``getGridOperator`` accessor are
    requested for that name; their return values are not used here.
    """
    variable = "go_%s" % (form_ident,)
    define_gridoperator(variable, form_ident)
    # Passing the name explicitly keeps the accessor generator from calling
    # back into this function.
    driver_block_get_gridoperator(form_ident, name=variable)
    return variable


@preamble(section="postprocessing", kernel="main")
def main_typedef_gridoperator(name, form_ident):
    """Return a C++ alias for the grid operator type nested in the driver block.

    :param name: alias name to declare
    :param form_ident: identifier of the form
    :returns: ``using <name> = <driver block type>::<grid operator type>;``
    """
    block_type = type_driver_block()
    nested_type = type_gridoperator(form_ident)
    return "using {alias} = {outer}::{inner};".format(
        alias=name, outer=block_type, inner=nested_type)


def main_type_gridoperator(form_ident):
    """Return the fixed alias name ``GridOperator``.

    ``main_typedef_gridoperator`` is invoked for that alias first; its return
    value is not used here.
    """
    alias = "GridOperator"
    main_typedef_gridoperator(alias, form_ident)
    return alias


@class_member(classtag="driver_block")
def driver_block_get_gridoperator(form_ident, name=None):
    """Return the C++ lines of the ``getGridOperator()`` accessor.

    :param form_ident: identifier of the form
    :param name: name of the grid operator variable to return; when falsy it
        is obtained from ``name_gridoperator``
    :returns: list of three strings forming the accessor definition
    """
    go_type = type_gridoperator(form_ident)
    member = name if name else name_gridoperator(form_ident)

    signature = "std::shared_ptr<{}> getGridOperator(){{".format(go_type)
    body = "  return {};".format(member)
    return [signature, body, "}"]


@preamble(section="postprocessing", kernel="main")
def main_define_gridoperator(name, form_ident):
    """Return the C++ statement fetching the grid operator from the driver block.

    :param name: name of the variable that receives the grid operator
    :param form_ident: identifier of the form
    """
    block_instance = name_driver_block()
    # Called for its effect only; the returned accessor lines are not used here.
    driver_block_get_gridoperator(form_ident)
    return "auto {var} = {block}.getGridOperator();".format(
        var=name, block=block_instance)


def main_name_gridoperator(form_ident):
    """Return the fixed variable name ``gridOperator``.

    ``main_define_gridoperator`` is invoked for that name first; its return
    value is not used here.
    """
    variable = "gridOperator"
    main_define_gridoperator(variable, form_ident)
    return variable


@class_member(classtag="driver_block")
def typedef_localoperator(name, form_ident):
    """Return the C++ alias declaration for the local operator type.

    The file named by the form option ``filename`` is requested through
    ``include_file``; the aliased class template is instantiated with the
    trial and test grid function space types.

    :param name: alias name to declare
    :param form_ident: identifier of the form
    """
    trial_type = type_trial_gfs()
    test_type = type_test_gfs()
    include_file(get_form_option("filename", form_ident), filetag="driver")
    template = localoperator_basename(form_ident)
    return "using {alias} = {tmpl}<{u}, {v}>;".format(
        alias=name, tmpl=template, u=trial_type, v=test_type)
def type_localoperator(form_ident):
    """Return the local operator type alias ``LOP_<FORM_IDENT>``.

    The form identifier is upper-cased for the alias. ``typedef_localoperator``
    is invoked for the alias before it is returned; its return value is not
    used here.

    :param form_ident: identifier of the form (must support ``.upper()``)
    :returns: the alias name
    """
    name = "LOP_{}".format(form_ident.upper())
    typedef_localoperator(name, form_ident)
    # Callers format this value into C++ code, so the alias must be returned
    # (as the other type_* functions in this file do).
    return name
@class_member(classtag="driver_block")
def declare_localoperator(name, form_ident):
    """Return the C++ declaration of a shared pointer to the local operator.

    :param name: name of the declared variable
    :param form_ident: identifier of the form, used to look up the type alias
    """
    return "std::shared_ptr<{lop}> {var};".format(
        lop=type_localoperator(form_ident),
        var=name,
    )


@preamble(section="gridoperator", kernel="driver_block")
def define_localoperator(name, form_ident):
    """Return the C++ statement that constructs the local operator.

    The operator is created with ``std::make_shared`` from the dereferenced
    trial and test function spaces and the ini tree.

    :param name: name of the local operator variable
    :param form_ident: identifier of the form
    """
    declare_localoperator(name, form_ident)
    trial_space = name_trial_gfs()
    test_space = name_test_gfs()
    lop_type = type_localoperator(form_ident)
    config = name_initree()
    return "{var} = std::make_shared<{lop}>(*{u}, *{v}, {ini});".format(
        var=name, lop=lop_type, u=trial_space, v=test_space, ini=config)
def name_localoperator(form_ident):
    """Return the local operator variable name ``lop_<form_ident>``.

    ``define_localoperator`` is invoked for that name before it is returned;
    its return value is not used here.

    :param form_ident: identifier of the form
    :returns: the variable name
    """
    name = "lop_{}".format(form_ident)
    define_localoperator(name, form_ident)
    # define_gridoperator formats this value into the constructor call, so the
    # name must be returned (as the other name_* functions in this file do).
    return name
@preamble(section="gridoperator", kernel="driver_block")
def define_dofestimate(name):
    """Return the C++ statements computing the matrix entries-per-row estimate.

    The generated code updates the trial grid function space, multiplies its
    ``maxLocalSize()`` by a geometry-dependent factor and then reads
    ``istl.number_of_nnz`` from the ini tree, using the computed value as the
    default.

    :param name: name of the C++ ``int`` variable holding the estimate
    :returns: list of three C++ statements
    """
    # Provide a worstcase estimate for the number of entries per row based
    # on the given gridfunction space and cell geometry
    if isQuadrilateral(get_cell()):
        geo_factor = 2**get_dimension()
    elif get_dimension() < 3:
        geo_factor = 3 * get_dimension()
    else:
        # TODO no idea how a generic estimate for 3D simplices looks like
        geo_factor = 12

    gfs = name_trial_gfs()
    ini = name_initree()

    return ["{}->update();".format(gfs),
            "int generic_dof_estimate =  {} * {}->maxLocalSize();".format(geo_factor, gfs),
            "int {} = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(name, ini)]


def name_dofestimate():
    """Return the fixed variable name ``dofestimate``.

    ``define_dofestimate`` is invoked for that name first; its return value is
    not used here.
    """
    variable = "dofestimate"
    define_dofestimate(variable)
    return variable


@class_member(classtag="driver_block")
def typedef_matrixbackend(name):
    """Return the C++ alias declaration for the matrix backend type.

    Requests ``dune/pdelab/backend/istl.hh`` through ``include_file`` and
    aliases ``Dune::PDELab::ISTL::BCRSMatrixBackend<>``.

    :param name: alias name to declare
    :returns: the C++ alias declaration as a string
    """
    include_file("dune/pdelab/backend/istl.hh", filetag="driver")
    return "using {} = Dune::PDELab::ISTL::BCRSMatrixBackend<>;".format(name)


def type_matrixbackend():
    """Return the fixed type alias name ``MatrixBackend``.

    ``typedef_matrixbackend`` is invoked for that alias first; its return
    value is not used here.
    """
    alias = "MatrixBackend"
    typedef_matrixbackend(alias)
    return alias


@class_member(classtag="driver_block")
def declare_matrixbackend(name):
    """Return the C++ declaration of a shared pointer to the matrix backend.

    :param name: name of the declared variable
    """
    return "std::shared_ptr<{backend}> {var};".format(
        backend=type_matrixbackend(),
        var=name,
    )


@preamble(section="gridoperator", kernel="driver_block")
def define_matrixbackend(name):
    """Return the C++ statement that constructs the matrix backend.

    The backend is created with ``std::make_shared`` and receives the
    variable named by ``name_dofestimate`` as its only constructor argument.

    :param name: name of the matrix backend variable
    """
    declare_matrixbackend(name)
    backend_type = type_matrixbackend()
    estimate = name_dofestimate()
    return "{var} = std::make_shared<{backend}>({arg});".format(
        var=name, backend=backend_type, arg=estimate)


def name_matrixbackend():
    """Return the fixed variable name ``mb``.

    ``define_matrixbackend`` is invoked for that name first; its return value
    is not used here.
    """
    variable = "mb"
    define_matrixbackend(variable)
    return variable