From c44ec24a8c69a6df18bcc896645fc1e990935f88 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.r.kempf@gmail.com> Date: Wed, 2 Sep 2015 11:49:28 +0200 Subject: [PATCH] Adapt the driver generation to the new infrastructure. --- python/dune/perftool/driver.py | 547 +++++++++++++++++++++++++++++++++ 1 file changed, 547 insertions(+) create mode 100644 python/dune/perftool/driver.py diff --git a/python/dune/perftool/driver.py b/python/dune/perftool/driver.py new file mode 100644 index 00000000..8df2f262 --- /dev/null +++ b/python/dune/perftool/driver.py @@ -0,0 +1,547 @@ +""" The module that drives the generation of the pdelab driver """ + +from dune.perftool.generation import dune_include, dune_preamble, dune_symbol + +# Have a global variable with the entire form data. This allows functions that depend +# deterministically on the entire data set to directly access it instead of passing it +# through the entire generator chain. +_formdata = None + + +def isLagrange(fem): + return fem._short_name is 'CG' + +def isSimplical(fem): + return any(fem._cell._cellname in x for x in ["triangle", "tetrahedron"]) + +def isQuadrilateral(fem): + return any(fem._cell._cellname in x for x in ["vertex", "interval", "quadrilateral", "hexalateral"]) + +def isPk(fem): + return isLagrange(fem) and isSimplical(fem) + +def isQk(fem): + return isLagrange(fem) and isQuadrilateral(fem) + +def FEM_name_mangling(fem): + from ufl import MixedElement, VectorElement, FiniteElement + if isinstance(fem, MixedElement): + name = "" + for elem in fem._sub_elements: + if name is not "": + name = name + "_" + name = name + FEM_name_mangling(elem) + return name + if isinstance(fem, VectorElement): + return FEM_name_mangling(fem._sub_elements[0]) + "_" + str(fem._cell._geometric_dimension) + if isinstance(fem, FiniteElement): + if isPk(fem): + return "P" + str(fem._degree) + if isQk(fem): + return "Q" + str(fem._degree) + raise NotImplementedError("FEM NAME MANGLING") + +@dune_symbol +def name_inifile(): + # TODO pass some other option here. + return "argv[1]" + +@dune_preamble +def parse_initree(varname): + dune_include("dune/common/parametertree.hh") + dune_include("dune/common/parametertreeparser.hh") + filename = name_inifile() + return ["Dune::ParameterTree initree;", "Dune::ParameterTreeParser::readINITree({}, {});".format(filename, varname)] + +@dune_symbol +def name_initree(): + parse_initree("initree") + #TODO we can get some other ini file here. + return "initree" + +@dune_preamble +def define_dimension(name): + return "static const int {} = {};".format(name, _formdata.geometric_dimension) + +@dune_symbol +def name_dimension(): + define_dimension("dim") + return "dim" + +@dune_preamble +def typedef_grid(name): + dim = name_dimension() + if any(_formdata.unique_elements[0].cell().cellname() in x for x in ["vertex", "interval", "quadrilateral", "hexalateral"]): + gridt = "Dune::YaspGrid<{}>".format(dim) + dune_include("dune/grid/yaspgrid.hh") + else: + if any(_formdata.unique_elements[0].cell().cellname() in x for x in ["triangle", "tetrahedron"]): + gridt = "Dune::UGGrid<{}>".format(dim) + dune_include("dune/grid/uggrid.hh") + else: + raise ValueError("Cant match your geometry with a DUNE grid. Please report bug.") + return "typedef {} {};".format(gridt, name) + +@dune_symbol +def type_grid(): + typedef_grid("Grid") + return "Grid" + +@dune_preamble +def define_grid(name): + dune_include("dune/testtools/gridconstruction.hh") + ini = name_initree() + _type = type_grid() + return ["IniGridFactory<{}> factory({});".format(_type, ini), + "std::shared_ptr<{}> grid = factory.getGrid();".format(_type)] + +@dune_symbol +def name_grid(): + define_grid("grid") + return "grid" + +@dune_preamble +def typedef_leafview(name): + grid = type_grid() + return "typedef {}::LeafGridView {};".format(grid, name) + +@dune_symbol +def type_leafview(): + typedef_leafview("GV") + return "GV" + +@dune_preamble +def define_leafview(name): + _type = type_leafview() + grid = name_grid() + return "{} {} = {}->leafView();".format(_type, name, grid) + +@dune_symbol +def name_leafview(): + define_leafview("gv") + return "gv" + +@dune_preamble +def typedef_vtkwriter(name): + dune_include("dune/grid/io/file/vtk/subsamplingvtkwriter.hh") + gv = type_leafview() + return "typedef Dune::SubsamplingVTKWriter<{}> {}".format(gv, name) + +@dune_symbol +def type_vtkwriter(): + typedef_vtkwriter("VTKWriter") + return "VTKWriter" + +@dune_preamble +def define_subsamplinglevel(name): + ini = name_initree() + return "int {} = {}.get<int>(\"vtk.subsamplinglevel\", 0);".format(name, ini) + +@dune_symbol +def name_subsamplinglevel(): + define_subsamplinglevel("sublevel") + return "sublevel" + +@dune_preamble +def define_vtkwriter(name): + _type = type_vtkwriter() + gv = name_leafview() + subsamp = name_subsamplinglevel() + return "{} {}({}, {});".format(_type, name, gv, subsamp) + +@dune_symbol +def name_vtkwriter(): + define_vtkwriter("vtkwriter") + return "vtkwriter" + +@dune_preamble +def typedef_domainfield(name): + gridt = type_grid() + return "typedef {}::ctype {};".format(gridt, name) + +@dune_symbol +def type_domainfield(): + typedef_domainfield("DF") + return "DF" + +@dune_preamble +def typedef_range(name): + return "typedef double {};".format(name) + +@dune_symbol +def type_range(): + typedef_range("R") + return "R" + +@dune_preamble +def typedef_fem(expr, name): + gv = type_leafview() + df = type_domainfield() + r = type_range() + if isPk(expr): + dune_include("dune/pdelab/finiteelementmap/pkfem.hh") + return "typedef Dune::PDELab::PkLocalFiniteElementMap<{}, {}, {}, {}> {};".format(gv, df, r, expr._degree, name) + if isQk(generator._kwargs['expr']): + dune_include("dune/pdelab/finiteelementmap/qkfem.hh") + return "typedef Dune::PDELab::QkLocalFiniteElementMap<{}, {}, {}, {}> {};".format(gv, df, r, expr._degree, name) + raise NotImplementedError("FEM not implemented in dune-perftool") + +@dune_symbol +def type_fem(expr): + name = FEM_name_mangling(expr).upper() + typedef_fem(expr, name) + return "{}_fem".format(name) + +@dune_preamble +def define_fem(expr, name): + femtype = type_fem(expr) + gv = name_leafview() + return "{} {}({});".format(femtype, name, gv) + +@dune_symbol +def name_fem(expr): + name = FEM_name_mangling(expr).lower() + define_fem(expr, name) + return "{}_fem".format(name) + +@dune_preamble +def typedef_vectorbackend(name): + dune_include("dune/pdelab/backend/istlvectorbackend.hh") + return "typedef Dune::PDELab::ISTLVectorBacken<Dune::PDELab::ISTLParameters::no_blocking, 1> {}".format(name) + +@dune_symbol +def type_vectorbackend(): + typedef_vectorbackend("VectorBackend") + return "VectorBackend" + +@dune_symbol +def type_orderingtag(): + return "Dune::PDELab::LexicographicOrderingTag" + +@dune_preamble +def typedef_constraintsassembler(name): + dune_include("dune/pdelab/constraints/conforming.hh") + return "typedef Dune::PDELab::ConformingDirichletConstraints {}".format(name) + +@dune_symbol +def type_constraintsassembler(): + typedef_constraintsassembler("ConstraintsAssembler") + return "ConstraintsAssembler" + +@dune_preamble +def typedef_constraintscontainer(expr, name): + gfs = type_gfs(expr) + r = type_range() + return "typedef {}::ConstraintsContainer<{}>::Type {}".format(gfs,r, name) + +@dune_symbol +def type_constraintscontainer(expr): + name = "{}_cc".format(FEM_name_mangling(expr)).upper() + typedef_constraintscontainer(expr, name) + return name + +@dune_preamble +def define_constraintscontainer(expr, name): + cctype = type_constraintscontainer(expr) + return ["{} {};".format(cctype, name), "{}.clear();".format(name)] + +@dune_symbol +def name_constraintscontainer(expr): + name = "{}_cc".format(FEM_name_mangling(expr)).lower() + define_constraintscontainer(expr, name) + return name + +@dune_preamble +def typedef_gfs(expr, name): + vb = type_vectorbackend() + from ufl import FiniteElement, MixedElement, VectorElement, EnrichedElement, RestrictedElement + if isinstance(expr, FiniteElement): + gv = type_leafview() + fem = type_fem(expr) + cass = type_constraintsassembler() + return "typedef Dune::PDELab::GridFunctionSpace<{}, {}, {}, {}> {};".format(gv, fem, cass, vb, name) + if isinstance(expr, MixedElement): + ot = type_orderingtag() + args = ", ".join(type_gfs(e) for e in expr._sub_elements) + return "typedef Dune::PDELab::CompositeGridFunctionSpace<{}, {}, {}> {}".format(vb, ot, args, name) + if isinstance(expr, VectorElement): + dune_include("dune/pdelab/gridfunctionspace/vectorgridfunctionspace.hh") + gv = type_leafview() + fem = type_fem(expr._sub_elements[0]) + dim = name_dimension() + cass = type_constraintsassembler() + return "typedef Dune::PDELab::GridFunctionSpace<{}, {}, {}, {}, {}, {}> {};".format(gv, fem, dim, vb, vb, cass, name) + if isinstance(expr, EnrichedElement): + raise NotImplementedError("Dune does not support enriched elements!") + if isinstance(expr, RestrictedElement): + raise NotImplementedError("Dune does not support restricted elements!") + +@dune_symbol +def type_gfs(expr): + name = FEM_name_mangling(expr).upper() + typedef_gfs(expr, name) + return "{}_GFS".format(name) + +@dune_preamble +def define_gfs(expr, name): + gfstype = type_gfs(expr) + from ufl import FiniteElement, MixedElement, VectorElement, EnrichedElement, RestrictedElement + if isinstance(expr, FiniteElement): + gv = name_leafview() + fem = name_fem(expr) + return "{} {}({}, {});".format(gfstype, name, gv, fem) + if isinstance(expr, MixedElement): + args = ", ".join(name_gfs(childgfs) for childgfs in expr._sub_elements) + return "{} {}({});".format(gfstype, name, args) + if isinstance(expr, VectorElement): + gv = name_leafview() + fem = name_fem(expr._sub_elements[0]) + return "{} {}({}, {});".format(gfstype, name, gv, fem) + if isinstance(expr, EnrichedElement): + raise NotImplementedError("Dune does not support enriched elements!") + if isinstance(expr, RestrictedElement): + raise NotImplementedError("Dune does not support restricted elements!") + + +@dune_symbol +def name_gfs(expr): + name = "{}_gfs".format(FEM_name_mangling(expr)).lower() + define_gfs(expr, name) + return name + + +@dune_preamble +def define_dofestimate(name): + # Provide a worstcase estimate for the number of entries per row based on the given gridfunction space and cell geometry + if isQuadrilateral(_formdata.coefficient_elements[0]): + geo_factor = "4" + else: + geo_factor = "6" + gfs = name_gfs(_formdata.coefficient_elements[0]) + ini = name_initree() + return ["int generic_dof_estimate = {} * {}.maxLocalSize();".format(geo_factor, gfs), + "int dof_estimate = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(ini)] + +@dune_symbol +def name_dofestimate(): + define_dofestimate("dofestimate") + return "dofestimate" + +@dune_preamble +def typedef_matrixbackend(name): + dune_include("dune/pdelab/backend/istl/bcrsmatrixbackend.hh") + return "typedef Dune::PDELab::istl::BCRSMatrixBackend<> {};".format(name) + +@dune_symbol +def type_matrixbackend(): + typedef_matrixbackend("MatrixBackend") + return "MatrixBackend" + +@dune_preamble +def define_matrixbackend(name): + mbtype = type_matrixbackend() + dof = name_dofestimate() + return "{} {}({});".format(mbtype, name, dof) + +@dune_symbol +def name_matrixbackend(): + define_matrixbackend("mb") + return "mb" + +@dune_preamble +def typedef_parameters(name): + return "typedef LocalOperatorParameters {}".format(name) + +@dune_symbol +def type_parameters(): + typedef_parameters("Params") + return "Params" + +@dune_preamble +def define_parameters(name): + partype = type_parameters() + return "{} {}();".format(partype, name) + +@dune_symbol +def name_parameters(): + define_parameters("params") + return "params" + +@dune_preamble +def typedef_localoperator(name): + params = type_parameters() + return "typedef LocalOperator<{}> {}".format(params, name) + +@dune_symbol +def type_localoperator(): + typedef_localoperator("LocalOperator") + return "LocalOperator" + +@dune_preamble +def define_localoperator(name): + loptype = type_localoperator() + ini = name_initree() + params = name_parameters() + return "{} {}({}, {})".format(loptype, name, ini, params) + +@dune_symbol +def name_localoperator(): + define_localoperator("lop") + return "lop" + +@dune_preamble +def typedef_gridoperator(name): + ugfs = type_gfs(_formdata.coefficient_elements[0]) + vgfs = type_gfs(_formdata.argument_elements[0]) + lop = type_localoperator() + ucc = type_constraintscontainer(_formdata.coefficient_elements[0]) + vcc = type_constraintscontainer(_formdata.argument_elements[0]) + mb = type_matrixbackend() + df = type_domainfield() + r = type_range() + dune_include("dune/pdelab/gridoperator/gridoperator.hh") + return "typedef Dune::PDELab::GridOperator<{}, {}, {}, {}, {}, {}, {}, {}, {}> {}".format(ugfs, vgfs, lop, mb, df, r, r, ucc, vcc, name) + +@dune_symbol +def type_gridoperator(): + typedef_gridoperator("GO") + return "GO" + +@dune_preamble +def define_gridoperator(name): + gotype = type_gridoperator() + ugfs = name_gfs(_formdata.coefficient_elements[0]) + ucc = name_constraintscontainer(_formdata.coefficient_elements[0]) + vgfs = name_gfs(_formdata.argument_elements[0]) + vcc = name_constraintscontainer(_formdata.argument_elements[0]) + lop = name_localoperator() + mb = name_matrixbackend() + return "{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb) + +@dune_symbol +def name_gridoperator(): + define_gridoperator("go") + return "go" + +@dune_preamble +def typedef_vector(name): + gotype = type_gridoperator() + return "typedef {}::Traits::Domain {}".format(gotype, name) + +@dune_symbol +def type_vector(): + typedef_vector("V") + return "V" + +@dune_preamble +def define_vector(name): + vtype = type_vector() + gfs = name_gfs(_formdata.coefficient_elements[0]) + return ["{} {}({});".format(vtype, name, gfs), "{} = 0.0;".format(name)] + +@dune_symbol +def name_vector(): + define_vector("x") + return "x" + +@dune_preamble +def typedef_linearsolver(name): + dune_include("dune/pdelab/backend/istlsolverbackend.hh") + return "typedef Dune::PDELab::ISTLBackend_SEQ_UMFPack {};".format(name) + +@dune_symbol +def type_linearsolver(): + typedef_linearsolver("LinearSolver") + return "LinearSolver" + +@dune_preamble +def define_linearsolver(name): + lstype = type_linearsolver() + return "{} {}(false);".format(lstype, name) + +@dune_symbol +def name_linearsolver(): + define_linearsolver("ls") + return "ls" + +@dune_preamble +def define_reduction(name): + ini = name_initree() + return "double {} = {}.get<double>(\"reduction\", 1e-12);".format(name, ini) + +@dune_symbol +def name_reduction(): + define_reduction("reduction") + return "reduction" + +@dune_symbol +def typedef_stationarylinearproblemsolver(name): + dune_include("dune/pdelab/stationary/linearproblem.hh") + gotype = type_gridoperator() + lstype = type_linearsolver() + xtype = type_vector() + return "typedef Dune::PDELab::StationaryLinearProblemSolver<{}, {}, {}> {}".format(gotype, lstype, xtype, name) + +@dune_symbol +def type_stationarylinearproblemsolver(): + typedef_stationarylinearproblemsolver("SLP") + return "SLP" + +@dune_preamble +def define_stationarylinearproblemsolver(name): + slptype = type_stationarylinearproblemsolver() + go = name_gridoperator() + ls = name_linearsolver() + x = name_vector() + red = name_reduction() + return "{} {}({}, {}, {}, {});".format(slptype, name, go, ls, x, red) + +@dune_symbol +def name_stationarylinearproblemsolver(): + define_stationarylinearproblemsolver("slp") + return "slp" + +@dune_preamble +def dune_solve(): + from ufl.algorithms.predicates import is_multilinear + if is_multilinear(_formdata.preprocessed_form): + slp = name_stationarylinearproblemsolver() + return "{}.apply();".format(slp) + else: + pass + +@dune_preamble +def define_vtkfile(name): + ini = name_initree() + dune_include("string") + return "std::string {} = {}.get<std::string>(\"vtk.filename\", \"output\");".format(name, ini) + +@dune_symbol +def name_vtkfile(): + define_vtkfile("vtkfile") + return "vtkfile" + +@dune_preamble +def vtkoutput(): + dune_include("dune/pdelab/gridfunctionspace/vtk.hh") + vtkwriter = name_vtkwriter() + gfs = name_gfs(_formdata.coefficient_elements[0]) + vec = name_vector() + vtkfile = name_vtkfile() + dune_solve() + return ["Dune::PDELab::addSolutionToVTKWriter({}, {}, {});".format(vtkwriter, gfs, vec), + "{}.write({}, Dune::VTK::appendedraw);".format(vtkwriter, vtkfile)] + + +def generate_driver(formdata): + # Set the global data: + global _formdata + _formdata = formdata + + # This should trigger everything IMO + vtkoutput() + + # Print the results: + from dune.perftool.generation import cache_preambles + for p in sorted(cache_preambles(), key=lambda x : x[0]): + print p -- GitLab