From c44ec24a8c69a6df18bcc896645fc1e990935f88 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.r.kempf@gmail.com>
Date: Wed, 2 Sep 2015 11:49:28 +0200
Subject: [PATCH] Adapt the driver generation to the new infrastructure.

---
 python/dune/perftool/driver.py | 547 +++++++++++++++++++++++++++++++++
 1 file changed, 547 insertions(+)
 create mode 100644 python/dune/perftool/driver.py

diff --git a/python/dune/perftool/driver.py b/python/dune/perftool/driver.py
new file mode 100644
index 00000000..8df2f262
--- /dev/null
+++ b/python/dune/perftool/driver.py
@@ -0,0 +1,547 @@
+""" The module that drives the generation of the pdelab driver """
+
+from dune.perftool.generation import dune_include, dune_preamble, dune_symbol
+
+# Have a global variable with the entire form data. This allows functions that depend
+# deterministically on the entire data set to directly access it instead of passing it
+# through the entire generator chain.
+_formdata = None
+
+
+def isLagrange(fem):
+    return fem._short_name is 'CG'
+
+def isSimplical(fem):
+    return any(fem._cell._cellname in x for x in ["triangle", "tetrahedron"])
+
+def isQuadrilateral(fem):
+    return any(fem._cell._cellname in x for x in ["vertex", "interval", "quadrilateral", "hexalateral"])
+
+def isPk(fem):
+    return isLagrange(fem) and isSimplical(fem)
+
+def isQk(fem):
+    return isLagrange(fem) and isQuadrilateral(fem)
+
+def FEM_name_mangling(fem):
+    from ufl import MixedElement, VectorElement, FiniteElement
+    if isinstance(fem, MixedElement):
+        name = ""
+        for elem in fem._sub_elements:
+            if name is not "":
+                name = name + "_"
+            name = name + FEM_name_mangling(elem)
+        return name
+    if isinstance(fem, VectorElement):
+        return FEM_name_mangling(fem._sub_elements[0]) + "_" + str(fem._cell._geometric_dimension)
+    if isinstance(fem, FiniteElement):
+        if isPk(fem):
+            return "P" + str(fem._degree)
+        if isQk(fem):
+            return "Q" + str(fem._degree)
+    raise NotImplementedError("FEM NAME MANGLING")
+
+@dune_symbol
+def name_inifile():
+    # TODO pass some other option here.
+    return "argv[1]"
+
+@dune_preamble
+def parse_initree(varname):
+    dune_include("dune/common/parametertree.hh")
+    dune_include("dune/common/parametertreeparser.hh")
+    filename = name_inifile()
+    return ["Dune::ParameterTree initree;", "Dune::ParameterTreeParser::readINITree({}, {});".format(filename, varname)]
+
+@dune_symbol
+def name_initree():
+    parse_initree("initree")
+    #TODO we can get some other ini file here.
+    return "initree"
+
+@dune_preamble
+def define_dimension(name):
+    return "static const int {} = {};".format(name, _formdata.geometric_dimension)
+
+@dune_symbol
+def name_dimension():
+    define_dimension("dim")
+    return "dim"
+
+@dune_preamble
+def typedef_grid(name):
+    dim = name_dimension()
+    if any(_formdata.unique_elements[0].cell().cellname() in x for x in ["vertex", "interval", "quadrilateral", "hexalateral"]):
+        gridt = "Dune::YaspGrid<{}>".format(dim)
+        dune_include("dune/grid/yaspgrid.hh")
+    else:
+        if any(_formdata.unique_elements[0].cell().cellname() in x for x in ["triangle", "tetrahedron"]):
+            gridt = "Dune::UGGrid<{}>".format(dim)
+            dune_include("dune/grid/uggrid.hh")
+        else:
+            raise ValueError("Cant match your geometry with a DUNE grid. Please report bug.")
+    return "typedef {} {};".format(gridt, name)
+
+@dune_symbol
+def type_grid():
+    typedef_grid("Grid")
+    return "Grid"
+
+@dune_preamble
+def define_grid(name):
+    dune_include("dune/testtools/gridconstruction.hh")
+    ini = name_initree()
+    _type = type_grid()
+    return ["IniGridFactory<{}> factory({});".format(_type, ini),
+            "std::shared_ptr<{}> grid = factory.getGrid();".format(_type)]
+
+@dune_symbol
+def name_grid():
+    define_grid("grid")
+    return "grid"
+
+@dune_preamble
+def typedef_leafview(name):
+    grid = type_grid()
+    return "typedef {}::LeafGridView {};".format(grid, name)
+
+@dune_symbol
+def type_leafview():
+    typedef_leafview("GV")
+    return "GV"
+
+@dune_preamble
+def define_leafview(name):
+    _type = type_leafview()
+    grid = name_grid()
+    return "{} {} = {}->leafView();".format(_type, name, grid)
+
+@dune_symbol
+def name_leafview():
+    define_leafview("gv")
+    return "gv"
+
+@dune_preamble
+def typedef_vtkwriter(name):
+    dune_include("dune/grid/io/file/vtk/subsamplingvtkwriter.hh")
+    gv = type_leafview()
+    return "typedef Dune::SubsamplingVTKWriter<{}> {}".format(gv, name)
+
+@dune_symbol
+def type_vtkwriter():
+    typedef_vtkwriter("VTKWriter")
+    return "VTKWriter"
+
+@dune_preamble
+def define_subsamplinglevel(name):
+    ini = name_initree()
+    return "int {} = {}.get<int>(\"vtk.subsamplinglevel\", 0);".format(name, ini)
+
+@dune_symbol
+def name_subsamplinglevel():
+    define_subsamplinglevel("sublevel")
+    return "sublevel"
+
+@dune_preamble
+def define_vtkwriter(name):
+    _type = type_vtkwriter()
+    gv = name_leafview()
+    subsamp = name_subsamplinglevel()
+    return "{} {}({}, {});".format(_type, name, gv, subsamp)
+
+@dune_symbol
+def name_vtkwriter():
+    define_vtkwriter("vtkwriter")
+    return "vtkwriter"
+
+@dune_preamble
+def typedef_domainfield(name):
+    gridt = type_grid()
+    return "typedef {}::ctype {};".format(gridt, name)
+
+@dune_symbol
+def type_domainfield():
+    typedef_domainfield("DF")
+    return "DF"
+
+@dune_preamble
+def typedef_range(name):
+    return "typedef double {};".format(name)
+
+@dune_symbol
+def type_range():
+    typedef_range("R")
+    return "R"
+
+@dune_preamble
+def typedef_fem(expr, name):
+    gv = type_leafview()
+    df = type_domainfield()
+    r = type_range()
+    if isPk(expr):
+        dune_include("dune/pdelab/finiteelementmap/pkfem.hh")
+        return "typedef Dune::PDELab::PkLocalFiniteElementMap<{}, {}, {}, {}> {};".format(gv, df, r, expr._degree, name)
+    if isQk(generator._kwargs['expr']):
+        dune_include("dune/pdelab/finiteelementmap/qkfem.hh")
+        return "typedef Dune::PDELab::QkLocalFiniteElementMap<{}, {}, {}, {}> {};".format(gv, df, r, expr._degree, name)
+    raise NotImplementedError("FEM not implemented in dune-perftool")
+
+@dune_symbol
+def type_fem(expr):
+    name = FEM_name_mangling(expr).upper()
+    typedef_fem(expr, name)
+    return "{}_fem".format(name)
+
+@dune_preamble
+def define_fem(expr, name):
+    femtype = type_fem(expr)
+    gv = name_leafview()
+    return "{} {}({});".format(femtype, name, gv)
+
+@dune_symbol
+def name_fem(expr):
+    name = FEM_name_mangling(expr).lower()
+    define_fem(expr, name)
+    return "{}_fem".format(name)
+
+@dune_preamble
+def typedef_vectorbackend(name):
+    dune_include("dune/pdelab/backend/istlvectorbackend.hh")
+    return "typedef Dune::PDELab::ISTLVectorBacken<Dune::PDELab::ISTLParameters::no_blocking, 1> {}".format(name)
+
+@dune_symbol
+def type_vectorbackend():
+    typedef_vectorbackend("VectorBackend")
+    return "VectorBackend"
+
+@dune_symbol
+def type_orderingtag():
+    return "Dune::PDELab::LexicographicOrderingTag"
+
+@dune_preamble
+def typedef_constraintsassembler(name):
+    dune_include("dune/pdelab/constraints/conforming.hh")
+    return "typedef Dune::PDELab::ConformingDirichletConstraints {}".format(name)
+
+@dune_symbol
+def type_constraintsassembler():
+    typedef_constraintsassembler("ConstraintsAssembler")
+    return "ConstraintsAssembler"
+
+@dune_preamble
+def typedef_constraintscontainer(expr, name):
+    gfs = type_gfs(expr)
+    r = type_range()
+    return "typedef {}::ConstraintsContainer<{}>::Type {}".format(gfs,r, name)
+
+@dune_symbol
+def type_constraintscontainer(expr):
+    name = "{}_cc".format(FEM_name_mangling(expr)).upper()
+    typedef_constraintscontainer(expr, name)
+    return name
+
+@dune_preamble
+def define_constraintscontainer(expr, name):
+    cctype = type_constraintscontainer(expr)
+    return ["{} {};".format(cctype, name), "{}.clear();".format(name)]
+
+@dune_symbol
+def name_constraintscontainer(expr):
+    name = "{}_cc".format(FEM_name_mangling(expr)).lower()
+    define_constraintscontainer(expr, name)
+    return name
+
+@dune_preamble
+def typedef_gfs(expr, name):
+    vb = type_vectorbackend()
+    from ufl import FiniteElement, MixedElement, VectorElement, EnrichedElement, RestrictedElement
+    if isinstance(expr, FiniteElement):
+        gv = type_leafview()
+        fem = type_fem(expr)
+        cass = type_constraintsassembler()
+        return "typedef Dune::PDELab::GridFunctionSpace<{}, {}, {}, {}> {};".format(gv, fem, cass, vb, name)
+    if isinstance(expr, MixedElement):
+        ot = type_orderingtag()
+        args = ", ".join(type_gfs(e) for e in expr._sub_elements)
+        return "typedef Dune::PDELab::CompositeGridFunctionSpace<{}, {}, {}> {}".format(vb, ot, args, name)
+    if isinstance(expr, VectorElement):
+        dune_include("dune/pdelab/gridfunctionspace/vectorgridfunctionspace.hh")
+        gv = type_leafview()
+        fem = type_fem(expr._sub_elements[0])
+        dim = name_dimension()
+        cass = type_constraintsassembler()
+        return "typedef Dune::PDELab::GridFunctionSpace<{}, {}, {}, {}, {}, {}> {};".format(gv, fem, dim, vb, vb, cass, name)
+    if isinstance(expr, EnrichedElement):
+        raise NotImplementedError("Dune does not support enriched elements!")
+    if isinstance(expr, RestrictedElement):
+        raise NotImplementedError("Dune does not support restricted elements!")
+
+@dune_symbol
+def type_gfs(expr):
+    name = FEM_name_mangling(expr).upper()
+    typedef_gfs(expr, name)
+    return "{}_GFS".format(name)
+
+@dune_preamble
+def define_gfs(expr, name):
+    gfstype = type_gfs(expr)
+    from ufl import FiniteElement, MixedElement, VectorElement, EnrichedElement, RestrictedElement
+    if isinstance(expr, FiniteElement):
+        gv = name_leafview()
+        fem = name_fem(expr)
+        return "{} {}({}, {});".format(gfstype, name, gv, fem)
+    if isinstance(expr, MixedElement):
+        args = ", ".join(name_gfs(childgfs) for childgfs in expr._sub_elements)
+        return "{} {}({});".format(gfstype, name, args)
+    if isinstance(expr, VectorElement):
+        gv = name_leafview()
+        fem = name_fem(expr._sub_elements[0])
+        return "{} {}({}, {});".format(gfstype, name, gv, fem)
+    if isinstance(expr, EnrichedElement):
+        raise NotImplementedError("Dune does not support enriched elements!")
+    if isinstance(expr, RestrictedElement):
+        raise NotImplementedError("Dune does not support restricted elements!")
+
+
+@dune_symbol
+def name_gfs(expr):
+    name = "{}_gfs".format(FEM_name_mangling(expr)).lower()
+    define_gfs(expr, name)
+    return name
+
+
+@dune_preamble
+def define_dofestimate(name):
+    # Provide a worstcase estimate for the number of entries per row based on the given gridfunction space and cell geometry
+    if isQuadrilateral(_formdata.coefficient_elements[0]):
+        geo_factor = "4"
+    else:
+        geo_factor = "6"
+    gfs = name_gfs(_formdata.coefficient_elements[0])
+    ini = name_initree()
+    return ["int generic_dof_estimate =  {} * {}.maxLocalSize();".format(geo_factor, gfs),
+            "int dof_estimate = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(ini)]
+
+@dune_symbol
+def name_dofestimate():
+    define_dofestimate("dofestimate")
+    return "dofestimate"
+
+@dune_preamble
+def typedef_matrixbackend(name):
+    dune_include("dune/pdelab/backend/istl/bcrsmatrixbackend.hh")
+    return "typedef Dune::PDELab::istl::BCRSMatrixBackend<> {};".format(name)
+
+@dune_symbol
+def type_matrixbackend():
+    typedef_matrixbackend("MatrixBackend")
+    return "MatrixBackend"
+
+@dune_preamble
+def define_matrixbackend(name):
+    mbtype = type_matrixbackend()
+    dof = name_dofestimate()
+    return "{} {}({});".format(mbtype, name, dof)
+
+@dune_symbol
+def name_matrixbackend():
+    define_matrixbackend("mb")
+    return "mb"
+
+@dune_preamble
+def typedef_parameters(name):
+    return "typedef LocalOperatorParameters {}".format(name)
+
+@dune_symbol
+def type_parameters():
+    typedef_parameters("Params")
+    return "Params"
+
+@dune_preamble
+def define_parameters(name):
+    partype = type_parameters()
+    return "{} {}();".format(partype, name)
+
+@dune_symbol
+def name_parameters():
+    define_parameters("params")
+    return "params"
+
+@dune_preamble
+def typedef_localoperator(name):
+    params = type_parameters()
+    return "typedef LocalOperator<{}> {}".format(params, name)
+
+@dune_symbol
+def type_localoperator():
+    typedef_localoperator("LocalOperator")
+    return "LocalOperator"
+
+@dune_preamble
+def define_localoperator(name):
+    loptype = type_localoperator()
+    ini = name_initree()
+    params = name_parameters()
+    return "{} {}({}, {})".format(loptype, name, ini, params)
+
+@dune_symbol
+def name_localoperator():
+    define_localoperator("lop")
+    return "lop"
+
+@dune_preamble
+def typedef_gridoperator(name):
+    ugfs = type_gfs(_formdata.coefficient_elements[0])
+    vgfs = type_gfs(_formdata.argument_elements[0])
+    lop = type_localoperator()
+    ucc = type_constraintscontainer(_formdata.coefficient_elements[0])
+    vcc = type_constraintscontainer(_formdata.argument_elements[0])
+    mb = type_matrixbackend()
+    df = type_domainfield()
+    r = type_range()
+    dune_include("dune/pdelab/gridoperator/gridoperator.hh")
+    return "typedef Dune::PDELab::GridOperator<{}, {}, {}, {}, {}, {}, {}, {}, {}> {}".format(ugfs, vgfs, lop, mb, df, r, r, ucc, vcc, name)
+
+@dune_symbol
+def type_gridoperator():
+    typedef_gridoperator("GO")
+    return "GO"
+
+@dune_preamble
+def define_gridoperator(name):
+    gotype = type_gridoperator()
+    ugfs = name_gfs(_formdata.coefficient_elements[0])
+    ucc = name_constraintscontainer(_formdata.coefficient_elements[0])
+    vgfs = name_gfs(_formdata.argument_elements[0])
+    vcc = name_constraintscontainer(_formdata.argument_elements[0])
+    lop = name_localoperator()
+    mb = name_matrixbackend()
+    return "{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb)
+
+@dune_symbol
+def name_gridoperator():
+    define_gridoperator("go")
+    return "go"
+
+@dune_preamble
+def typedef_vector(name):
+    gotype = type_gridoperator()
+    return "typedef {}::Traits::Domain {}".format(gotype, name)
+
+@dune_symbol
+def type_vector():
+    typedef_vector("V")
+    return "V"
+
+@dune_preamble
+def define_vector(name):
+    vtype = type_vector()
+    gfs = name_gfs(_formdata.coefficient_elements[0])
+    return ["{} {}({});".format(vtype, name, gfs), "{} = 0.0;".format(name)]
+
+@dune_symbol
+def name_vector():
+    define_vector("x")
+    return "x"
+
+@dune_preamble
+def typedef_linearsolver(name):
+    dune_include("dune/pdelab/backend/istlsolverbackend.hh")
+    return "typedef Dune::PDELab::ISTLBackend_SEQ_UMFPack {};".format(name)
+
+@dune_symbol
+def type_linearsolver():
+    typedef_linearsolver("LinearSolver")
+    return "LinearSolver"
+
+@dune_preamble
+def define_linearsolver(name):
+    lstype = type_linearsolver()
+    return "{} {}(false);".format(lstype, name)
+
+@dune_symbol
+def name_linearsolver():
+    define_linearsolver("ls")
+    return "ls"
+
+@dune_preamble
+def define_reduction(name):
+    ini = name_initree()
+    return "double {} = {}.get<double>(\"reduction\", 1e-12);".format(name, ini)
+
+@dune_symbol
+def name_reduction():
+    define_reduction("reduction")
+    return "reduction"
+
+@dune_symbol
+def typedef_stationarylinearproblemsolver(name):
+    dune_include("dune/pdelab/stationary/linearproblem.hh")
+    gotype = type_gridoperator()
+    lstype = type_linearsolver()
+    xtype = type_vector()
+    return "typedef Dune::PDELab::StationaryLinearProblemSolver<{}, {}, {}> {}".format(gotype, lstype, xtype, name)
+
+@dune_symbol
+def type_stationarylinearproblemsolver():
+    typedef_stationarylinearproblemsolver("SLP")
+    return "SLP"
+
+@dune_preamble
+def define_stationarylinearproblemsolver(name):
+    slptype = type_stationarylinearproblemsolver()
+    go = name_gridoperator()
+    ls = name_linearsolver()
+    x = name_vector()
+    red = name_reduction()
+    return "{} {}({}, {}, {}, {});".format(slptype, name, go, ls, x, red)
+
+@dune_symbol
+def name_stationarylinearproblemsolver():
+    define_stationarylinearproblemsolver("slp")
+    return "slp"
+
+@dune_preamble
+def dune_solve():
+    from ufl.algorithms.predicates import is_multilinear
+    if is_multilinear(_formdata.preprocessed_form):
+        slp = name_stationarylinearproblemsolver()
+        return "{}.apply();".format(slp)
+    else:
+        pass
+
+@dune_preamble
+def define_vtkfile(name):
+    ini = name_initree()
+    dune_include("string")
+    return "std::string {} = {}.get<std::string>(\"vtk.filename\", \"output\");".format(name, ini)
+
+@dune_symbol
+def name_vtkfile():
+    define_vtkfile("vtkfile")
+    return "vtkfile"
+
+@dune_preamble
+def vtkoutput():
+    dune_include("dune/pdelab/gridfunctionspace/vtk.hh")
+    vtkwriter = name_vtkwriter()
+    gfs = name_gfs(_formdata.coefficient_elements[0])
+    vec = name_vector()
+    vtkfile = name_vtkfile()
+    dune_solve()
+    return ["Dune::PDELab::addSolutionToVTKWriter({}, {}, {});".format(vtkwriter, gfs, vec),
+            "{}.write({}, Dune::VTK::appendedraw);".format(vtkwriter, vtkfile)]
+
+
+def generate_driver(formdata):
+    # Set the global data:
+    global _formdata
+    _formdata = formdata
+
+    # This should trigger everything IMO
+    vtkoutput()
+
+    # Print the results:
+    from dune.perftool.generation import cache_preambles
+    for p in sorted(cache_preambles(), key=lambda x : x[0]):
+        print p
-- 
GitLab