from dune.perftool.generation import (include_file,
                                      preamble,
                                      )
from dune.perftool.pdelab.driver import (get_formdata,
                                         get_mass_formdata,
                                         get_trial_element,
                                         is_linear,
                                         name_initree,
                                         preprocess_leaf_data,
                                         )
from dune.perftool.pdelab.driver.gridfunctionspace import (name_gfs,
                                                           type_range,
                                                           )
from dune.perftool.pdelab.driver.gridoperator import (name_gridoperator,
                                                      name_parameters,
                                                      type_gridoperator,)
from dune.perftool.pdelab.driver.constraints import (name_bctype_function,
                                                     name_constraintscontainer,
                                                     )
from dune.perftool.pdelab.driver.interpolate import name_boundary_function
from dune.perftool.pdelab.driver.solve import (print_matrix,
                                               print_residual,
                                               name_stationarynonlinearproblemsolver,
                                               name_vector,
                                               type_stationarynonlinearproblemssolver,
                                               type_vector,
                                               )
from dune.perftool.pdelab.driver.vtk import (name_vtk_sequence_writer,
                                             visualize_initial_condition,
                                             )
from dune.perftool.options import get_option


def solve_instationary():
    """Generate the driver code that solves an instationary problem.

    Only the combination "form is linear" + "matrix based (not matrix free)
    operator evaluation" is implemented; for it the time loop is generated via
    ``time_loop()``. Afterwards ``print_residual()`` and ``print_matrix()`` are
    invoked.

    Raises:
        NotImplementedError: if the form is nonlinear or matrix free operator
            evaluation was requested.
    """
    # Test if form is linear in ansatzfunction
    linear = is_linear()

    # Test whether we want to do matrix free operator evaluation
    matrix_free = get_option('matrix_free')

    # Reject unsupported configurations explicitly. The previous `assert False`
    # is stripped under `python -O`, which would silently generate a driver
    # without any time loop.
    if not linear or matrix_free:
        raise NotImplementedError(
            "Instationary solving is only implemented for linear forms with "
            "matrix based operator evaluation (linear={}, matrix_free={})".format(linear, matrix_free))

    # Create time loop
    time_loop()

    print_residual()
    print_matrix()


@preamble
def time_loop():
    """Generate the C++ time stepping loop of the driver.

    The returned lines declare the end time ``T`` and step size ``dt``, read
    from the ini tree (keys ``instat.T`` / ``instat.dt``, defaults 1.0 / 0.1).
    They then emit a ``while`` loop that, per step:

    * sets the parameter time to ``time+dt``;
    * reassembles the constraints;
    * applies the selected one step method into a temporary vector;
    * accepts that vector and advances the time;
    * writes VTK output.
    """
    ini = name_initree()
    formdata = get_formdata()
    params = name_parameters(formdata)
    time = name_time()
    element = get_trial_element()
    is_dirichlet = preprocess_leaf_data(element, "is_dirichlet")
    bctype = name_bctype_function(element, is_dirichlet)
    gfs = name_gfs(element, is_dirichlet)
    cc = name_constraintscontainer()
    vector_type = type_vector(formdata)
    vector = name_vector(formdata)
    # C++ name of the temporary that receives the solution of the new time step
    new_vector = "{}new".format(vector)

    # Choose between explicit and implicit time stepping
    explicit = get_option('explicit_time_stepping')
    if explicit:
        osm = name_explicitonestepmethod()
        apply_call = "{}.apply({}, dt, {}, {});".format(osm, time, vector, new_vector)
    else:
        # The implicit method is additionally handed the Dirichlet boundary function
        dirichlet = preprocess_leaf_data(element, "dirichlet_expression")
        boundary = name_boundary_function(element, dirichlet)
        osm = name_onestepmethod()
        apply_call = "{}.apply({}, dt, {}, {}, {});".format(osm, time, vector, boundary, new_vector)

    # Setup visualization
    visualize_initial_condition()
    vtk_sequence_writer = name_vtk_sequence_writer()

    # The time identifier always comes from name_time() instead of being
    # hard-coded, so the generated loop stays consistent with its declaration.
    # The 1e-8 slack presumably guards against round-off accumulated by
    # repeatedly adding dt.
    return ["",
            "double T = {}.get<double>(\"instat.T\", 1.0);".format(ini),
            "double dt = {}.get<double>(\"instat.dt\", 0.1);".format(ini),
            "while ({}<T-1e-8){{".format(time),
            "  // Assemble constraints for new time step",
            "  {}.setTime({}+dt);".format(params, time),
            "  Dune::PDELab::constraints({}, {}, {});".format(bctype, gfs, cc),
            "",
            "  // Do time step",
            "  {} {}({});".format(vector_type, new_vector, vector),
            "  {}".format(apply_call),
            "",
            "  // Accept new time step",
            "  {} = {};".format(vector, new_vector),
            "  {} += dt;".format(time),
            "",
            "  // Output to VTK File",
            "  {}.write({}, Dune::VTK::appendedraw);".format(vtk_sequence_writer, time),
            "}",
            ""]


@preamble
def define_time(name):
    """Emit the declaration of the time variable ``name``, starting at 0.0."""
    declaration = "double {name} = 0.0;".format(name=name)
    return declaration


def name_time():
    """Return the C++ identifier of the time variable, declaring it on demand."""
    identifier = "time"
    define_time(identifier)
    return identifier



@preamble
def typedef_timesteppingmethod(name):
    """Emit a ``using`` alias ``name`` for the time stepping parameter class.

    ``ExplicitEulerParameter`` is chosen when the ``explicit_time_stepping``
    option is set, ``OneStepThetaParameter`` otherwise. Both are templated on
    the range type.
    """
    range_type = type_range()
    if get_option('explicit_time_stepping'):
        template = "using {} = Dune::PDELab::ExplicitEulerParameter<{}>;"
    else:
        template = "using {} = Dune::PDELab::OneStepThetaParameter<{}>;"
    return template.format(name, range_type)


def type_timesteppingmethod():
    """Return the C++ alias of the time stepping parameter type, emitting it if needed."""
    alias = "TSM"
    typedef_timesteppingmethod(alias)
    return alias


@preamble
def define_timesteppingmethod(name):
    """Emit the declaration of the time stepping parameter object ``name``.

    The explicit variant is default constructed. The implicit one is
    constructed with the argument ``1.0`` (presumably theta = 1, i.e. implicit
    Euler; confirm against PDELab's OneStepThetaParameter).
    """
    tsm_type = type_timesteppingmethod()
    ctor = "" if get_option('explicit_time_stepping') else "(1.0)"
    return "{} {}{};".format(tsm_type, name, ctor)


def name_timesteppingmethod():
    """Return the C++ variable name of the time stepping parameter object."""
    variable = "tsm"
    define_timesteppingmethod(variable)
    return variable


@preamble
def typedef_instationarygridoperator(name):
    """Emit a ``using`` alias ``name`` for ``Dune::PDELab::OneStepGridOperator``.

    The template arguments are the grid operator type of the main form and
    that of the mass form. With explicit time stepping an extra ``false``
    template argument is appended.
    """
    include_file("dune/pdelab/gridoperator/onestep.hh", filetag="driver")
    spatial_go_type = type_gridoperator(get_formdata())
    temporal_go_type = type_gridoperator(get_mass_formdata())
    template_args = [spatial_go_type, temporal_go_type]
    if get_option('explicit_time_stepping'):
        template_args.append("false")
    return "using {} = Dune::PDELab::OneStepGridOperator<{}>;".format(name, ",".join(template_args))


def type_instationarygridoperator():
    """Return the C++ alias of the instationary grid operator type, emitting it if needed."""
    alias = "IGO"
    typedef_instationarygridoperator(alias)
    return alias


@preamble
def define_instationarygridoperator(name):
    """Emit the declaration of the instationary grid operator object ``name``.

    It is constructed from the grid operator of the main form and the grid
    operator of the mass form.
    """
    igo_type = type_instationarygridoperator()
    spatial_go = name_gridoperator(get_formdata())
    temporal_go = name_gridoperator(get_mass_formdata())
    return "{t} {n}({a}, {b});".format(t=igo_type, n=name, a=spatial_go, b=temporal_go)


def name_instationarygridoperator():
    """Return the C++ variable name of the instationary grid operator."""
    variable = "igo"
    define_instationarygridoperator(variable)
    return variable


@preamble
def typedef_onestepmethod(name):
    """Emit a ``using`` alias ``name`` for ``Dune::PDELab::OneStepMethod``.

    Template arguments, in order:

    * range type;
    * instationary grid operator type;
    * nonlinear problem solver type;
    * the vector type twice (presumably trial and test vector types).
    """
    range_type = type_range()
    igo_type = type_instationarygridoperator()
    solver_type = type_stationarynonlinearproblemssolver(igo_type)
    vector_type = type_vector(get_formdata())
    template_args = ", ".join([range_type, igo_type, solver_type, vector_type, vector_type])
    return "using {} = Dune::PDELab::OneStepMethod<{}>;".format(name, template_args)


def type_onestepmethod():
    """Return the C++ alias of the implicit one step method type, emitting it if needed."""
    alias = "OSM"
    typedef_onestepmethod(alias)
    return alias


@preamble
def define_onestepmethod(name):
    """Emit the declaration of the implicit one step method object ``name``.

    It is constructed from the time stepping parameter object, the
    instationary grid operator and the nonlinear problem solver.
    """
    osm_type = type_onestepmethod()
    tsm = name_timesteppingmethod()
    igo_type = type_instationarygridoperator()
    igo = name_instationarygridoperator()
    solver = name_stationarynonlinearproblemsolver(igo_type, igo)
    ctor_args = ",".join([tsm, igo, solver])
    return "{} {}({});".format(osm_type, name, ctor_args)


def name_onestepmethod():
    """Return the C++ variable name of the implicit one step method."""
    variable = "osm"
    define_onestepmethod(variable)
    return variable


@preamble
def typedef_explicitonestepmethod(name):
    """Emit a ``using`` alias ``name`` for ``Dune::PDELab::ExplicitOneStepMethod``.

    Template arguments: range type, instationary grid operator type, linear
    solver type and vector type.
    """
    # `type_linearsolver` is neither imported nor defined in this module, so the
    # explicit time stepping path used to fail with a NameError. Import it
    # locally so a failure cannot affect the (working) implicit path.
    # NOTE(review): assumes the linear solver helpers live in the solve driver
    # module next to the other solver helpers imported at the top of this file.
    from dune.perftool.pdelab.driver.solve import type_linearsolver

    r_type = type_range()
    igo_type = type_instationarygridoperator()
    ls_type = type_linearsolver()
    vector_type = type_vector(get_formdata())
    return "using {} = Dune::PDELab::ExplicitOneStepMethod<{}, {}, {}, {}>;".format(name, r_type, igo_type, ls_type, vector_type)


def type_explicitonestepmethod():
    """Return the C++ alias of the explicit one step method type, emitting it if needed."""
    alias = "EOSM"
    typedef_explicitonestepmethod(alias)
    return alias


@preamble
def define_explicitonestepmethod(name):
    """Emit the declaration of the explicit one step method object ``name``.

    It is constructed from the time stepping parameter object, the
    instationary grid operator and the linear solver object.
    """
    # `name_linearsolver` is neither imported nor defined in this module, so the
    # explicit time stepping path used to fail with a NameError. Import it
    # locally so a failure cannot affect the (working) implicit path.
    # NOTE(review): assumes the linear solver helpers live in the solve driver
    # module next to the other solver helpers imported at the top of this file.
    from dune.perftool.pdelab.driver.solve import name_linearsolver

    eosm_type = type_explicitonestepmethod()
    tsm = name_timesteppingmethod()
    igo = name_instationarygridoperator()
    ls = name_linearsolver()
    return "{} {}({}, {}, {});".format(eosm_type, name, tsm, igo, ls)


def name_explicitonestepmethod():
    """Return the C++ variable name of the explicit one step method."""
    variable = "eosm"
    define_explicitonestepmethod(variable)
    return variable