From e513d2504794338a07dee252cb3cc19365ec9c79 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 23 Feb 2016 13:54:17 +0100 Subject: [PATCH] Introduce some more fancy generators --- python/dune/perftool/generation/generators.py | 17 +- python/dune/perftool/pdelab/__init__.py | 7 +- python/dune/perftool/pdelab/argument.py | 12 +- python/dune/perftool/pdelab/driver.py | 157 +++++++++--------- python/dune/perftool/pdelab/geometry.py | 6 +- python/dune/perftool/pdelab/localoperator.py | 23 +-- python/dune/perftool/pdelab/quadrature.py | 8 +- 7 files changed, 116 insertions(+), 114 deletions(-) diff --git a/python/dune/perftool/generation/generators.py b/python/dune/perftool/generation/generators.py index 5da34171..e9e72f8d 100644 --- a/python/dune/perftool/generation/generators.py +++ b/python/dune/perftool/generation/generators.py @@ -4,9 +4,24 @@ are commonly needed for code generation """ from dune.perftool.generation import generator_factory +from dune.perftool.cgen.clazz import AccessModifier + +symbol = generator_factory(item_tags=("symbol",)) def include_file(include, filetag="operator"): from cgen import Include gen = generator_factory(on_store=lambda i: Include(i), item_tags=(filetag, "include"), no_deco=True) - return gen(include) \ No newline at end of file + return gen(include) + + +# TODO this should perhaps already take into account a potential construction of the base class object +def base_class(baseclass, classtag=None, access=AccessModifier.PUBLIC): + assert classtag + from dune.perftool.cgen.clazz import BaseClass + gen = generator_factory(item_tags=("baseclass", classtag), on_store=lambda n: BaseClass(n, inheritance=access), counted=True, no_deco=True) + return gen(baseclass) + + +def preamble(tag): + return generator_factory(item_tags=(tag, "preamble"), counted=True) \ No newline at end of file diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py index fec0eba5..a73ae2e9 100644 --- a/python/dune/perftool/pdelab/__init__.py +++ b/python/dune/perftool/pdelab/__init__.py @@ -1,14 +1,9 @@ """ The pdelab specific parts of the code generation process """ # Define the generators that are used throughout all pdelab specific code generations. -from dune.perftool.generation import generator_factory from dune.perftool.loopy.transformer import quadrature_iname from loopy import CInstruction -dune_symbol = generator_factory(item_tags=("pdelab", "kernel", "symbol")) -dune_preamble = generator_factory(item_tags=("pdelab", "kernel", "preamble"), counted=True) -dune_include = generator_factory(on_store=lambda i: "#include<{}>".format(i), item_tags=("pdelab", "include"), no_deco=True) - def quadrature_preamble(assignees=[]): # TODO: How to enforce the order of quadrature preambles? Counted? @@ -16,6 +11,6 @@ def quadrature_preamble(assignees=[]): # Now define some commonly used generators that do not fall into a specific category -@dune_symbol +@symbol def name_index(index): return str(index._indices[0]) diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py index d402b00a..946e8731 100644 --- a/python/dune/perftool/pdelab/argument.py +++ b/python/dune/perftool/pdelab/argument.py @@ -1,10 +1,10 @@ """ Generator functions related to trial and test functions and the accumulation loop""" -from dune.perftool.pdelab import dune_symbol, dune_preamble +from dune.perftool.generation import symbol from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor -@dune_symbol +@symbol def name_testfunction(modarg): ma = ModifiedArgumentDescriptor(modarg) if len(ma.expr.element().sub_elements()) > 0: @@ -12,19 +12,19 @@ def name_testfunction(modarg): return "{}a{}".format("grad_" if ma.grad else "", ma.expr.number()) -@dune_symbol +@symbol def name_trialfunction(modarg): ma = ModifiedArgumentDescriptor(modarg) return "{}c{}".format("grad_" if ma.grad else "", ma.expr.count()) -@dune_symbol +@symbol def name_testfunctionspace(*a): # TODO return "lfsv" -@dune_symbol +@symbol def name_trialfunctionspace(*a): # TODO return "lfsu" @@ -50,6 +50,6 @@ def name_argument(modarg): assert False -@dune_symbol +@symbol def name_residual(): return "r" diff --git a/python/dune/perftool/pdelab/driver.py b/python/dune/perftool/pdelab/driver.py index 7069b30d..17d1fbac 100644 --- a/python/dune/perftool/pdelab/driver.py +++ b/python/dune/perftool/pdelab/driver.py @@ -6,10 +6,7 @@ Currently, these are hardcoded as strings. It would be possible to switch these to cgen expression. OTOH, there is not much to be gained there. """ -from dune.perftool.generation import generator_factory, include_file -from dune.perftool.pdelab import dune_symbol - -driver_preamble = generator_factory(item_tags=("driver", "preamble"), counted=True) +from dune.perftool.generation import include_file, symbol, preamble # Have a global variable with the entire form data. This allows functions that depend # deterministically on the entire data set to directly access it instead of passing it @@ -62,13 +59,13 @@ def FEM_name_mangling(fem): raise NotImplementedError("FEM NAME MANGLING") -@dune_symbol +@symbol def name_inifile(): # TODO pass some other option here. return "argv[1]" -@driver_preamble +@preamble('driver') def parse_initree(varname): include_file("dune/common/parametertree.hh", filetag="driver") include_file("dune/common/parametertreeparser.hh", filetag="driver") @@ -76,25 +73,25 @@ def parse_initree(varname): return ["Dune::ParameterTree initree;", "Dune::ParameterTreeParser::readINITree({}, {});".format(filename, varname)] -@dune_symbol +@symbol def name_initree(): parse_initree("initree") # TODO we can get some other ini file here. return "initree" -@driver_preamble +@preamble('driver') def define_dimension(name): return "static const int {} = {};".format(name, _form.cell().geometric_dimension()) -@dune_symbol +@symbol def name_dimension(): define_dimension("dim") return "dim" -@driver_preamble +@preamble('driver') def typedef_grid(name): dim = name_dimension() if any(_form.cell().cellname() in x for x in ["vertex", "interval", "quadrilateral", "hexalateral"]): @@ -109,13 +106,13 @@ def typedef_grid(name): return "typedef {} {};".format(gridt, name) -@dune_symbol +@symbol def type_grid(): typedef_grid("Grid") return "Grid" -@driver_preamble +@preamble('driver') def define_grid(name): include_file("dune/testtools/gridconstruction.hh", filetag="driver") ini = name_initree() @@ -124,63 +121,63 @@ def define_grid(name): "std::shared_ptr<{}> grid = factory.getGrid();".format(_type)] -@dune_symbol +@symbol def name_grid(): define_grid("grid") return "grid" -@driver_preamble +@preamble('driver') def typedef_leafview(name): grid = type_grid() return "typedef {}::LeafGridView {};".format(grid, name) -@dune_symbol +@symbol def type_leafview(): typedef_leafview("GV") return "GV" -@driver_preamble +@preamble('driver') def define_leafview(name): _type = type_leafview() grid = name_grid() return "{} {} = {}->leafGridView();".format(_type, name, grid) -@dune_symbol +@symbol def name_leafview(): define_leafview("gv") return "gv" -@driver_preamble +@preamble('driver') def typedef_vtkwriter(name): include_file("dune/grid/io/file/vtk/subsamplingvtkwriter.hh", filetag="driver") gv = type_leafview() return "typedef Dune::SubsamplingVTKWriter<{}> {};".format(gv, name) -@dune_symbol +@symbol def type_vtkwriter(): typedef_vtkwriter("VTKWriter") return "VTKWriter" -@driver_preamble +@preamble('driver') def define_subsamplinglevel(name): ini = name_initree() return "int {} = {}.get<int>(\"vtk.subsamplinglevel\", 0);".format(name, ini) -@dune_symbol +@symbol def name_subsamplinglevel(): define_subsamplinglevel("sublevel") return "sublevel" -@driver_preamble +@preamble('driver') def define_vtkwriter(name): _type = type_vtkwriter() gv = name_leafview() @@ -188,36 +185,36 @@ def define_vtkwriter(name): return "{} {}({}, {});".format(_type, name, gv, subsamp) -@dune_symbol +@symbol def name_vtkwriter(): define_vtkwriter("vtkwriter") return "vtkwriter" -@driver_preamble +@preamble('driver') def typedef_domainfield(name): gridt = type_grid() return "typedef {}::ctype {};".format(gridt, name) -@dune_symbol +@symbol def type_domainfield(): typedef_domainfield("DF") return "DF" -@driver_preamble +@preamble('driver') def typedef_range(name): return "typedef double {};".format(name) -@dune_symbol +@symbol def type_range(): typedef_range("R") return "R" -@driver_preamble +@preamble('driver') def typedef_fem(expr, name): gv = type_leafview() df = type_domainfield() @@ -231,84 +228,84 @@ def typedef_fem(expr, name): raise NotImplementedError("FEM not implemented in dune-perftool") -@dune_symbol +@symbol def type_fem(expr): name = "{}_FEM".format(FEM_name_mangling(expr).upper()) typedef_fem(expr, name) return name -@driver_preamble +@preamble('driver') def define_fem(expr, name): femtype = type_fem(expr) gv = name_leafview() return "{} {}({});".format(femtype, name, gv) -@dune_symbol +@symbol def name_fem(expr): name = "{}_fem".format(FEM_name_mangling(expr).lower()) define_fem(expr, name) return name -@driver_preamble +@preamble('driver') def typedef_vectorbackend(name): include_file("dune/pdelab/backend/istlvectorbackend.hh", filetag="driver") return "typedef Dune::PDELab::ISTLVectorBackend<Dune::PDELab::ISTLParameters::no_blocking, 1> {};".format(name) -@dune_symbol +@symbol def type_vectorbackend(): typedef_vectorbackend("VectorBackend") return "VectorBackend" -@dune_symbol +@symbol def type_orderingtag(): return "Dune::PDELab::LexicographicOrderingTag" -@driver_preamble +@preamble('driver') def typedef_constraintsassembler(name): include_file("dune/pdelab/constraints/conforming.hh", filetag="driver") return "typedef Dune::PDELab::ConformingDirichletConstraints {};".format(name) -@dune_symbol +@symbol def type_constraintsassembler(): typedef_constraintsassembler("ConstraintsAssembler") return "ConstraintsAssembler" -@driver_preamble +@preamble('driver') def typedef_constraintscontainer(expr, name): gfs = type_gfs(expr) r = type_range() return "typedef {}::ConstraintsContainer<{}>::Type {};".format(gfs, r, name) -@dune_symbol +@symbol def type_constraintscontainer(expr): name = "{}_cc".format(FEM_name_mangling(expr)).upper() typedef_constraintscontainer(expr, name) return name -@driver_preamble +@preamble('driver') def define_constraintscontainer(expr, name): cctype = type_constraintscontainer(expr) return ["{} {};".format(cctype, name), "{}.clear();".format(name)] -@dune_symbol +@symbol def name_constraintscontainer(expr): name = "{}_cc".format(FEM_name_mangling(expr)).lower() define_constraintscontainer(expr, name) return name -@driver_preamble +@preamble('driver') def typedef_gfs(expr, name): vb = type_vectorbackend() from ufl import FiniteElement, MixedElement, VectorElement, EnrichedElement, RestrictedElement @@ -334,14 +331,14 @@ def typedef_gfs(expr, name): raise NotImplementedError("Dune does not support restricted elements!") -@dune_symbol +@symbol def type_gfs(expr): name = "{}_GFS".format(FEM_name_mangling(expr).upper()) typedef_gfs(expr, name) return name -@driver_preamble +@preamble('driver') def define_gfs(expr, name): gfstype = type_gfs(expr) from ufl import FiniteElement, MixedElement, VectorElement, EnrichedElement, RestrictedElement @@ -362,14 +359,14 @@ def define_gfs(expr, name): raise NotImplementedError("Dune does not support restricted elements!") -@dune_symbol +@symbol def name_gfs(expr): name = "{}_gfs".format(FEM_name_mangling(expr)).lower() define_gfs(expr, name) return name -@driver_preamble +@preamble('driver') def define_dofestimate(name): # Provide a worstcase estimate for the number of entries per row based on the given gridfunction space and cell geometry if isQuadrilateral(_form.coefficients()[0].element()): @@ -382,61 +379,61 @@ def define_dofestimate(name): "int dof_estimate = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(ini)] -@dune_symbol +@symbol def name_dofestimate(): define_dofestimate("dofestimate") return "dofestimate" -@driver_preamble +@preamble('driver') def typedef_matrixbackend(name): include_file("dune/pdelab/backend/istl/bcrsmatrixbackend.hh", filetag="driver") return "typedef Dune::PDELab::istl::BCRSMatrixBackend<> {};".format(name) -@dune_symbol +@symbol def type_matrixbackend(): typedef_matrixbackend("MatrixBackend") return "MatrixBackend" -@driver_preamble +@preamble('driver') def define_matrixbackend(name): mbtype = type_matrixbackend() dof = name_dofestimate() return "{} {}({});".format(mbtype, name, dof) -@dune_symbol +@symbol def name_matrixbackend(): define_matrixbackend("mb") return "mb" -@driver_preamble +@preamble('driver') def typedef_parameters(name): return "typedef LocalOperatorParameters {};".format(name) -@dune_symbol +@symbol def type_parameters(): typedef_parameters("Params") return "Params" -@driver_preamble +@preamble('driver') def define_parameters(name): partype = type_parameters() return "{} {}();".format(partype, name) -@dune_symbol +@symbol def name_parameters(): define_parameters("params") return "params" -@driver_preamble +@preamble('driver') def typedef_localoperator(name): # No Parameter class here, yet # params = type_parameters() @@ -446,13 +443,13 @@ def typedef_localoperator(name): return "// Here in the future: typedef for the local operator with parameter class as template parameter" -@dune_symbol +@symbol def type_localoperator(): typedef_localoperator("LocalOperator") return "LocalOperator" -@driver_preamble +@preamble('driver') def define_localoperator(name): loptype = type_localoperator() ini = name_initree() @@ -460,13 +457,13 @@ def define_localoperator(name): return "{} {}({}, {});".format(loptype, name, ini, params) -@dune_symbol +@symbol def name_localoperator(): define_localoperator("lop") return "lop" -@driver_preamble +@preamble('driver') def typedef_gridoperator(name): ugfs = type_gfs(_form.coefficients()[0].element()) vgfs = type_gfs(_form.arguments()[0].element()) @@ -480,13 +477,13 @@ def typedef_gridoperator(name): return "typedef Dune::PDELab::GridOperator<{}, {}, {}, {}, {}, {}, {}, {}, {}> {};".format(ugfs, vgfs, lop, mb, df, r, r, ucc, vcc, name) -@dune_symbol +@symbol def type_gridoperator(): typedef_gridoperator("GO") return "GO" -@driver_preamble +@preamble('driver') def define_gridoperator(name): gotype = type_gridoperator() ugfs = name_gfs(_form.coefficients()[0].element()) @@ -498,74 +495,74 @@ def define_gridoperator(name): return "{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb) -@dune_symbol +@symbol def name_gridoperator(): define_gridoperator("go") return "go" -@driver_preamble +@preamble('driver') def typedef_vector(name): gotype = type_gridoperator() return "typedef {}::Traits::Domain {};".format(gotype, name) -@dune_symbol +@symbol def type_vector(): typedef_vector("V") return "V" -@driver_preamble +@preamble('driver') def define_vector(name): vtype = type_vector() gfs = name_gfs(_form.coefficients()[0].element()) return ["{} {}({});".format(vtype, name, gfs), "{} = 0.0;".format(name)] -@dune_symbol +@symbol def name_vector(): define_vector("x") return "x" -@driver_preamble +@preamble('driver') def typedef_linearsolver(name): include_file("dune/pdelab/backend/istlsolverbackend.hh", filetag="driver") return "typedef Dune::PDELab::ISTLBackend_SEQ_UMFPack {};".format(name) -@dune_symbol +@symbol def type_linearsolver(): typedef_linearsolver("LinearSolver") return "LinearSolver" -@driver_preamble +@preamble('driver') def define_linearsolver(name): lstype = type_linearsolver() return "{} {}(false);".format(lstype, name) -@dune_symbol +@symbol def name_linearsolver(): define_linearsolver("ls") return "ls" -@driver_preamble +@preamble('driver') def define_reduction(name): ini = name_initree() return "double {} = {}.get<double>(\"reduction\", 1e-12);".format(name, ini) -@dune_symbol +@symbol def name_reduction(): define_reduction("reduction") return "reduction" -@dune_symbol +@symbol def typedef_stationarylinearproblemsolver(name): include_file("dune/pdelab/stationary/linearproblem.hh", filetag="driver") gotype = type_gridoperator() @@ -574,13 +571,13 @@ def typedef_stationarylinearproblemsolver(name): return "typedef Dune::PDELab::StationaryLinearProblemSolver<{}, {}, {}> {}".format(gotype, lstype, xtype, name) -@dune_symbol +@symbol def type_stationarylinearproblemsolver(): typedef_stationarylinearproblemsolver("SLP") return "SLP" -@driver_preamble +@preamble('driver') def define_stationarylinearproblemsolver(name): slptype = type_stationarylinearproblemsolver() go = name_gridoperator() @@ -590,13 +587,13 @@ def define_stationarylinearproblemsolver(name): return "{} {}({}, {}, {}, {});".format(slptype, name, go, ls, x, red) -@dune_symbol +@symbol def name_stationarylinearproblemsolver(): define_stationarylinearproblemsolver("slp") return "slp" -@driver_preamble +@preamble('driver') def dune_solve(): from ufl.algorithms.predicates import is_multilinear # This is crap as it does check for linearity of the rank 1 form, @@ -608,20 +605,20 @@ def dune_solve(): raise NotImplementedError -@driver_preamble +@preamble('driver') def define_vtkfile(name): ini = name_initree() include_file("string", filetag="driver") return "std::string {} = {}.get<std::string>(\"vtk.filename\", \"output\");".format(name, ini) -@dune_symbol +@symbol def name_vtkfile(): define_vtkfile("vtkfile") return "vtkfile" -@driver_preamble +@preamble('driver') def vtkoutput(): include_file("dune/pdelab/gridfunctionspace/vtk.hh", filetag="driver") vtkwriter = name_vtkwriter() diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py index 020f5d87..f100aea1 100644 --- a/python/dune/perftool/pdelab/geometry.py +++ b/python/dune/perftool/pdelab/geometry.py @@ -1,13 +1,13 @@ -from dune.perftool.pdelab import dune_symbol +from dune.perftool.generation import symbol -@dune_symbol +@symbol def name_dimension(): # TODO preamble define_dimension return "dim" -@dune_symbol +@symbol def name_facetarea(): # TODO preambles return "farea" diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 49d5eda0..52274e00 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -1,16 +1,11 @@ from __future__ import absolute_import from dune.perftool.options import get_option -from dune.perftool.generation import generator_factory, include_file -from dune.perftool.pdelab import dune_symbol +from dune.perftool.generation import include_file, base_class, symbol from dune.perftool.cgen.clazz import BaseClass, ClassMember from pytools import memoize -# Define the generators used in-here -public_base_class = generator_factory(item_tags=("baseclass", "operator"), on_store=lambda n: BaseClass(n), counted=True, no_deco=True) - - @generator_factory(item_tags=("initializer", "operator"), counted=True, cache_key_generator=lambda *a: a[0]) def initializer_list(obj, params): return "{}({})".format(obj, ", ".join(params)) @@ -29,14 +24,14 @@ def constructor_parameter(_type, name): return Value(_type, name) -@dune_symbol +@symbol def name_initree_constructor(): include_file('dune/common/parametertree.hh', filetag="operator") constructor_parameter("const Dune::ParameterTree&", "iniParams") return "iniParams" -@dune_symbol +@symbol def name_initree_member(): include_file('dune/common/parametertree.hh', filetag="operator") define_private_member("const Dune::ParameterTree&", "_iniParams") @@ -45,7 +40,7 @@ def name_initree_member(): return "_iniParams" -@dune_symbol +@symbol def localoperator_type(): # TODO use something from the form here to make it unique return "LocalOperator" @@ -61,7 +56,7 @@ def measure_specific_details(measure): # Add a base class from dune.perftool.pdelab.driver import type_localoperator loptype = type_localoperator() - public_base_class("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype)) + base_class("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype)) # Add the initializer list for that base class ini = name_initree_member() @@ -69,7 +64,7 @@ def measure_specific_details(measure): ["{}.get(\"numerical_epsilon.{}\", 1e-9)".format(ini, which.lower())]) if measure == "cell": - public_base_class('Dune::PDELab::FullVolumePattern') + base_class('Dune::PDELab::FullVolumePattern') numerical_jacobian("Volume") ret["residual_signature"] = ['template<typename EG, typename LFSV0, typename X, typename LFSV1, typename R>', @@ -78,7 +73,7 @@ def measure_specific_details(measure): 'void jacobian_volume(const EG& eg, const LFSV0& lfsv0, const X& x, const LFSV1& lfsv1, J& jac) const'] if measure == "exterior_facet": - public_base_class('Dune::PDELab::FullBoundaryPattern') + base_class('Dune::PDELab::FullBoundaryPattern') numerical_jacobian("Boundary") ret["residual_signature"] = ['template<typename IG, typename LFSV0, typename X, typename LFSV1, typename R>', @@ -87,7 +82,7 @@ def measure_specific_details(measure): 'void jacobian_boundary(const IG& ig, const LFSV0& lfsv0, const X& x, const LFSV1& lfsv1, J& jac) const'] if measure == "interior_facet": - public_base_class('Dune::PDELab::FullSkeletonPattern') + base_class('Dune::PDELab::FullSkeletonPattern') numerical_jacobian("Skeleton") ret["residual_signature"] = ['template<typename IG, typename LFSV0_S, typename X, typename LFSV1_S, typename LFSV0_N, typename R, typename LFSV1_N>', @@ -176,7 +171,7 @@ def generate_localoperator_kernels(form): include_file('dune/pdelab/localoperator/pattern.hh', filetag="operator") include_file('dune/geometry/quadraturerules.hh', filetag="operator") - public_base_class('Dune::PDELab::LocalOperatorDefaultFlags') + base_class('Dune::PDELab::LocalOperatorDefaultFlags') # Have a data structure collect the generated kernels operator_kernels = {} diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index 514e94ee..882be08d 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -1,9 +1,9 @@ -from dune.perftool.generation import generator_factory from dune.perftool.loopy.transformer import quadrature_iname, loopy_temporary_variable -from dune.perftool.pdelab import dune_symbol, quadrature_preamble, dune_preamble +from dune.perftool.generation import symbol +from dune.perftool.pdelab import quadrature_preamble -@dune_symbol +@symbol def quadrature_rule(): return "rule" @@ -14,7 +14,7 @@ def define_quadrature_factor(fac): return "auto {} = {}->weight();".format(fac, rule) -@dune_symbol +@symbol def name_factor(): loopy_temporary_variable("fac") define_quadrature_factor("fac") -- GitLab