Skip to content
Snippets Groups Projects
error.py 7.29 KiB
Newer Older
""" Generator functions to calculate errors in the driver """

Dominic Kempf's avatar
Dominic Kempf committed
from ufl import MixedElement, TensorElement, VectorElement

from dune.codegen.generation import (cached,
                                     include_file,
                                     preamble,
                                     )
from dune.codegen.options import get_option
from dune.codegen.pdelab.driver import (get_form_ident,
                                        get_trial_element,
                                        preprocess_leaf_data,
                                        )
from dune.codegen.pdelab.driver.gridfunctionspace import (main_type_trial_gfs,
                                                          name_leafview,
                                                          name_trial_subgfs,
                                                          type_range,
                                                          )
from dune.codegen.pdelab.driver.interpolate import (interpolate_vector,
                                                    name_boundary_function,
                                                    )
from dune.codegen.pdelab.driver.solve import (define_vector,
                                              dune_solve,
                                              main_name_vector,
                                              main_type_vector,
                                              )


@preamble(section="error", kernel="main")
def define_test_fail_variable(name):
    """Emit the C++ declaration of the driver's test-failure flag.

    :param name: C++ identifier to declare.
    :returns: A ``bool`` declaration initialised to ``false``.
    """
    return 'bool {flag}(false);'.format(flag=name)


def name_test_fail_variable():
    """Request the declaration of the test-failure flag and return its name.

    :returns: The C++ identifier ``"testfail"``.
    """
    flag = "testfail"
    define_test_fail_variable(flag)
    return flag


def name_exact_solution_gridfunction(treepath):
    """Return the name of the exact-solution function for one tree path.

    The exact solution comes from the ``"exact_solution"`` leaf data of the
    trial element. For a mixed element, only the entry of the flat component
    addressed by ``treepath`` is used, together with the corresponding
    sub-element. For any other element, ``treepath`` is ignored.

    :param treepath: Tuple of child indices into the trial element tree.
    :returns: Whatever ``name_boundary_function`` returns for that data.
    """
    trial = get_trial_element()
    exact = preprocess_leaf_data(trial, "exact_solution")
    if not isinstance(trial, MixedElement):
        return name_boundary_function(trial, exact)

    component = treepath_to_index(trial, treepath)
    selected = (exact[component],)
    subelement = trial.extract_component(component)[1]
    return name_boundary_function(subelement, selected)


def type_discrete_grid_function(gfs):
    """Return the C++ alias name for the discrete grid function over ``gfs``.

    The alias is the upper-cased function-space name plus a ``_DGF`` suffix,
    e.g. ``"gfs_0"`` -> ``"GFS_0_DGF"``.
    """
    return "%s_DGF" % (gfs.upper(),)


@preamble(section="error", kernel="main")
def define_discrete_grid_function(gfs, vector_name, dgf_name):
    """Emit the alias and construction of a ``Dune::PDELab::DiscreteGridFunction``.

    :param gfs: Name of the grid function space variable. It is dereferenced
        with ``*`` in the emitted code, so it is presumably a pointer-like object.
    :param vector_name: Name of the coefficient vector variable (also dereferenced).
    :param dgf_name: Name of the grid function object to declare.
    :returns: Two C++ lines: the ``using`` alias and the object declaration.
    """
    alias = type_discrete_grid_function(gfs)
    space_type = main_type_trial_gfs()
    coefficient_type = main_type_vector(get_form_ident())

    using_line = "using {} = Dune::PDELab::DiscreteGridFunction<{}, {}>;".format(
        alias, space_type, coefficient_type)
    object_line = "{} {}(*{},*{});".format(alias, dgf_name, gfs, vector_name)
    return [using_line, object_line]


def name_discrete_grid_function(gfs, vector_name):
    """Request the discrete grid function for ``gfs``/``vector_name`` and return its name.

    The object is named after the function space with a ``_dgf`` suffix.
    """
    dgf = "%s_dgf" % (gfs,)
    define_discrete_grid_function(gfs, vector_name, dgf)
    return dgf


@preamble(section="error", kernel="main")
def typedef_difference_squared_adapter(name, treepath):
    """Emit a C++ alias for the DifferenceSquaredAdapter of one tree path.

    The adapter compares the exact-solution function for ``treepath`` with the
    discrete grid function built on the matching trial sub-space.

    :param name: Alias name to declare.
    :param treepath: Tuple of child indices selecting the trial sub-space.
    :returns: One ``using`` line.
    """
    sol = name_exact_solution_gridfunction(treepath)
    # The discrete solution must live on the trial sub-space addressed by
    # ``treepath``; previously ``gfs`` was an undefined name here.
    gfs = name_trial_subgfs(treepath)
    vector = main_name_vector(get_form_ident())
    # Called for its side effect: requests the declaration of the DGF object.
    name_discrete_grid_function(gfs, vector)
    dgf_type = type_discrete_grid_function(gfs)
    return 'using {} = Dune::PDELab::DifferenceSquaredAdapter<decltype({}), {}>;'.format(name, sol, dgf_type)


def type_difference_squared_adapter(treepath):
    """Request the DifferenceSquaredAdapter alias for ``treepath`` and return its name.

    The alias name encodes the tree path, e.g. ``(1, 0)`` ->
    ``"DifferenceSquaredAdapter_1_0"``.
    """
    suffix = "_".join(map(str, treepath))
    alias = 'DifferenceSquaredAdapter_' + suffix
    typedef_difference_squared_adapter(alias, treepath)
    return alias


@preamble(section="error", kernel="main")
def define_difference_squared_adapter(name, treepath):
    """Emit the construction of the DifferenceSquaredAdapter for one tree path.

    :param name: Name of the adapter object to declare.
    :param treepath: Tuple of child indices selecting the trial sub-space.
    :returns: One C++ declaration line taking the exact-solution function and
        the discrete grid function as constructor arguments.
    """
    t = type_difference_squared_adapter(treepath)
    sol = name_exact_solution_gridfunction(treepath)
    # Same sub-space as used for the adapter alias; ``gfs`` was previously
    # an undefined name here.
    gfs = name_trial_subgfs(treepath)
    vector = main_name_vector(get_form_ident())
    dgf = name_discrete_grid_function(gfs, vector)

    return '{} {}({}, {});'.format(t, name, sol, dgf)


def name_difference_squared_adapter(treepath):
    """Request the DifferenceSquaredAdapter object for ``treepath`` and return its name.

    :param treepath: Tuple of child indices, e.g. ``(1, 0)`` -> ``"dsa_1_0"``.
    :returns: The C++ variable name of the adapter.
    """
    name = 'dsa_{}'.format("_".join(str(t) for t in treepath))
    define_difference_squared_adapter(name, treepath)
    # The caller formats this name into generated C++; without the return it
    # would emit the literal text "None".
    return name


@preamble(section="error", kernel="main")
def _accumulate_L2_squared(treepath):
    """Emit C++ that adds the squared L2 error of one tree path to the running total.

    The generated ``{ ... }`` scope does the following:
    - integrates the DifferenceSquaredAdapter for ``treepath``;
    - when the ``parallel`` option is set, sums the result over all ranks;
    - adds it to the accumulated L2 error variable;
    - prints the per-treepath value on rank 0.

    :param treepath: Tuple of child indices selecting the trial sub-space.
    :returns: List of C++ source lines forming one scope.
    """
    dsa = name_difference_squared_adapter(treepath)
    accum_error = name_accumulated_L2_error()

    include_file("dune/pdelab/gridfunctionspace/gridfunctionadapter.hh", filetag="driver")
    include_file("dune/pdelab/common/functionutilities.hh", filetag="driver")

    strtp = ", ".join(str(t) for t in treepath)

    gv = name_leafview()
    sum_error_over_ranks = ""
    if get_option("parallel"):
        # Combine the per-rank contributions before accumulating.
        sum_error_over_ranks = "  err = {}.comm().sum(err);".format(gv)
    # The literal 10 passed to integrateGridFunction is presumably the
    # quadrature order -- TODO confirm against PDELab.
    return ["{",
            "  // L2 error squared of difference between numerical",
            "  // solution and the interpolation of exact solution",
            "  // for treepath ({})".format(strtp),
            "  typename decltype({})::Traits::RangeType err(0.0);".format(dsa),
            "  Dune::PDELab::integrateGridFunction({}, err, 10);".format(dsa),
            sum_error_over_ranks,
            # Without this line the accumulated error stays 0 and the final
            # threshold comparison can never fail.
            "  {} += err;".format(accum_error),
            "  if ({}.comm().rank() == 0){{".format(gv),
            "    std::cout << \"L2 Error for treepath {}: \" << err << std::endl;".format(strtp),
            "  }",
            "}",
            ]


def get_treepath(element, index):
    """Map a flat value-component index to a tree path through ``element``.

    Vector and tensor elements map directly to ``(index,)``. Other mixed
    elements are descended recursively into the sub-element owning the
    component. Any other element yields the empty path ``()``.
    """
    # Checked before MixedElement so that vector/tensor elements (presumably
    # MixedElement subclasses in UFL) are not descended into.
    if isinstance(element, (VectorElement, TensorElement)):
        return (index,)
    if not isinstance(element, MixedElement):
        return ()

    position, _ = element.extract_subelement_component(index)
    children = element.sub_elements()
    preceding = sum(children[i].value_size() for i in range(position))
    return (position,) + get_treepath(children[position], index - preceding)


def treepath_to_index(element, treepath, offset=0):
    """Convert a tree path into a flat value-component index.

    At each level, the value sizes of all siblings before the chosen child
    are added to ``offset``; the walk then continues into that child.

    :param element: Root element; needs ``sub_elements()`` and ``value_size()``.
    :param treepath: Sequence of child indices.
    :param offset: Index to start counting from.
    :returns: The flat component index.
    """
    node = element
    for child in treepath:
        siblings = node.sub_elements()
        offset += sum(siblings[i].value_size() for i in range(child))
        node = siblings[child]
    return offset


def accumulate_L2_squared():
    """Generate the L2 error accumulation for the components being checked.

    For a mixed trial element, one accumulation is generated per flat value
    component. The optional ``l2error_tree_path`` option selects components:
    a comma-separated list of integer flags, one per component, where 0
    skips that component. For any other element, the error is accumulated
    once on the empty tree path.

    :raises ValueError: if ``l2error_tree_path`` does not contain exactly
        one entry per value component.
    """
    element = get_trial_element()
    if not isinstance(element, MixedElement):
        # Without this branch a non-mixed trial element never contributes to
        # the accumulated error, and the threshold check trivially passes.
        _accumulate_L2_squared(())
        return

    n_components = element.value_size()
    tree_pathes = (True,) * n_components
    selection = get_option("l2error_tree_path")
    if selection is not None:
        tree_pathes = [int(flag) for flag in selection.split(',')]
        # Explicit check instead of ``assert``: this validates user config and
        # must not disappear under ``python -O``.
        if len(tree_pathes) != n_components:
            raise ValueError("l2error_tree_path has {} entries, but the trial element "
                             "has {} value components".format(len(tree_pathes), n_components))
    for i, path in enumerate(tree_pathes):
        if path:
            _accumulate_L2_squared(get_treepath(element, i))


@preamble(section="error", kernel="main")
def define_accumulated_L2_error(name):
    """Emit the declaration of the accumulated squared L2 error variable.

    :param name: C++ identifier to declare.
    :returns: A one-entry ``Dune::FieldVector`` declaration initialised to 0.
    """
    # The field type is the driver's range type; previously ``t`` was an
    # undefined name here.
    t = type_range()
    return "Dune::FieldVector<{}, 1> {}(0.0);".format(t, name)


def name_accumulated_L2_error():
    """Request the declaration of the accumulated L2 error and return its name.

    :returns: The C++ identifier ``'l2error'``.
    """
    variable = 'l2error'
    define_accumulated_L2_error(variable)
    return variable


@preamble(section="error", kernel="main")
def compare_L2_squared():
    """Emit the check of the accumulated squared L2 error against a threshold.

    First triggers generation of the per-component error accumulation. Then
    emits C++ that prints the accumulated error on rank 0 and sets the
    test-failure flag if the error is NaN, or if its absolute value exceeds
    the ``compare_l2errorsquared`` option.

    :returns: List of C++ source lines.
    """
    accumulate_L2_squared()
    leafview = name_leafview()
    error = name_accumulated_L2_error()
    failed = name_test_fail_variable()
    threshold = get_option("compare_l2errorsquared")

    lines = ["using std::abs;", "using std::isnan;"]
    lines.append("if ({}.comm().rank() == 0){{".format(leafview))
    lines.append("  std::cout << \"\\nl2errorsquared: \" << {} << std::endl << std::endl;".format(error))
    lines.append("}")
    lines.append("if (isnan({0}[0]) or abs({0}[0])>{1})".format(error, threshold))
    lines.append("  {} = true;".format(failed))
    return lines


@preamble(section="return_stmt", kernel="main")
def return_statement():
    """Emit the ``main`` return statement, returning the test-failure flag."""
    return "return {flag};".format(flag=name_test_fail_variable())