From f873b0fa804876d5887347c3fd6ae05b3343ec95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Mon, 22 Aug 2016 13:02:06 +0200 Subject: [PATCH] Replace _form in driver with _drive_data dictionary --- python/dune/perftool/compile.py | 2 +- python/dune/perftool/pdelab/driver.py | 61 ++++++++++++++------------- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py index e4a3869a..8d221141 100644 --- a/python/dune/perftool/compile.py +++ b/python/dune/perftool/compile.py @@ -65,7 +65,7 @@ def compile_form(): if get_option("driver_file"): with cache_context('driver', delete=True): from dune.perftool.pdelab.driver import generate_driver - generate_driver(formdatas[0].preprocessed_form, get_option("driver_file")) + generate_driver(formdatas, data) # In case of multiple forms: Genarate one file that includes all localoperator files if len(formdatas) > 1: diff --git a/python/dune/perftool/pdelab/driver.py b/python/dune/perftool/pdelab/driver.py index f26143f2..73361786 100644 --- a/python/dune/perftool/pdelab/driver.py +++ b/python/dune/perftool/pdelab/driver.py @@ -42,13 +42,12 @@ fem_metadata_dependent_symbol = generator_factory(item_tags=("symbol",), # Have a global variable with the entire form data. This allows functions that depend # deterministically on the entire data set to directly access it instead of passing it # through the entire generator chain. -_form = None +_driver_data = {} # Have a function access this global data structure -def set_form(f): - global _form - _form = f +def set_driver_data(formdatas, datas): + _driver_data['form'] = formdatas[0].preprocessed_form def is_linear(form): @@ -134,7 +133,7 @@ def name_initree(): @preamble def define_dimension(name): - return "static const int {} = {};".format(name, _form.ufl_cell().geometric_dimension()) + return "static const int {} = {};".format(name, _driver_data['form'].ufl_cell().geometric_dimension()) @symbol @@ -146,11 +145,11 @@ def name_dimension(): @preamble def typedef_grid(name): dim = name_dimension() - if any(_form.ufl_cell().cellname() in x for x in ["vertex", "interval", "quadrilateral", "hexahedron"]): + if any(_driver_data['form'].ufl_cell().cellname() in x for x in ["vertex", "interval", "quadrilateral", "hexahedron"]): gridt = "Dune::YaspGrid<{}>".format(dim) include_file("dune/grid/yaspgrid.hh", filetag="driver") else: - if any(_form.ufl_cell().cellname() in x for x in ["triangle", "tetrahedron"]): + if any(_driver_data['form'].ufl_cell().cellname() in x for x in ["triangle", "tetrahedron"]): # gridt = "Dune::UGGrid<{}>".format(dim) # include_file("dune/grid/uggrid.hh", filetag="driver") gridt = "Dune::ALUGrid<{}, {}, Dune::simplex, Dune::conforming>".format(dim, dim) @@ -537,11 +536,11 @@ def name_gfs(expr): @preamble def define_dofestimate(name): # Provide a worstcase estimate for the number of entries per row based on the given gridfunction space and cell geometry - if isQuadrilateral(_form.coefficients()[0].ufl_element()): + if isQuadrilateral(_driver_data['form'].coefficients()[0].ufl_element()): geo_factor = "4" else: geo_factor = "6" - gfs = name_gfs(_form.coefficients()[0].ufl_element()) + gfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element()) ini = name_initree() return ["int generic_dof_estimate = {} * {}.maxLocalSize();".format(geo_factor, gfs), "int {} = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(name, ini)] @@ -605,8 +604,8 @@ def typedef_localoperator(name): # No Parameter class here, yet # params = type_parameters() # return "typedef LocalOperator<{}> {};".format(params, name) - ugfs = type_gfs(_form.coefficients()[0].ufl_element()) - vgfs = type_gfs(_form.arguments()[0].ufl_element()) + ugfs = type_gfs(_driver_data['form'].coefficients()[0].ufl_element()) + vgfs = type_gfs(_driver_data['form'].arguments()[0].ufl_element()) from dune.perftool.generation import get_global_context_value data = get_global_context_value("data") formdata = get_global_context_value("formdata") @@ -639,11 +638,11 @@ def name_localoperator(): @preamble def typedef_gridoperator(name): - ugfs = type_gfs(_form.coefficients()[0].ufl_element()) - vgfs = type_gfs(_form.arguments()[0].ufl_element()) + ugfs = type_gfs(_driver_data['form'].coefficients()[0].ufl_element()) + vgfs = type_gfs(_driver_data['form'].arguments()[0].ufl_element()) lop = type_localoperator() - ucc = type_constraintscontainer(_form.coefficients()[0].ufl_element()) - vcc = type_constraintscontainer(_form.arguments()[0].ufl_element()) + ucc = type_constraintscontainer(_driver_data['form'].coefficients()[0].ufl_element()) + vcc = type_constraintscontainer(_driver_data['form'].arguments()[0].ufl_element()) mb = type_matrixbackend() df = type_domainfield() r = type_range() @@ -660,10 +659,10 @@ def type_gridoperator(): @preamble def define_gridoperator(name): gotype = type_gridoperator() - ugfs = name_gfs(_form.coefficients()[0].ufl_element()) - ucc = name_assembled_constraints(_form.coefficients()[0].ufl_element()) - vgfs = name_gfs(_form.arguments()[0].ufl_element()) - vcc = name_assembled_constraints(_form.arguments()[0].ufl_element()) + ugfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element()) + ucc = name_assembled_constraints(_driver_data['form'].coefficients()[0].ufl_element()) + vgfs = name_gfs(_driver_data['form'].arguments()[0].ufl_element()) + vcc = name_assembled_constraints(_driver_data['form'].arguments()[0].ufl_element()) lop = name_localoperator() mb = name_matrixbackend() return ["{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb), @@ -692,7 +691,7 @@ def type_vector(): @preamble def define_vector(name): vtype = type_vector() - gfs = name_gfs(_form.coefficients()[0].ufl_element()) + gfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element()) return ["{} {}({});".format(vtype, name, gfs), "{} = 0.0;".format(name)] @@ -772,7 +771,7 @@ def name_solution_function(expr): @preamble def interpolate_vector(name): define_vector(name) - element = _form.coefficients()[0].ufl_element() + element = _driver_data['form'].coefficients()[0].ufl_element() bf = name_boundary_function(element) gfs = name_gfs(element) return "Dune::PDELab::interpolate({}, {}, {});".format(bf, @@ -784,7 +783,7 @@ def interpolate_vector(name): @preamble def interpolate_solution_expression(name): define_vector(name) - element = _form.coefficients()[0].ufl_element() + element = _driver_data['form'].coefficients()[0].ufl_element() sol = name_solution_function(element) gfs = name_gfs(element) return "Dune::PDELab::interpolate({}, {}, {});".format(sol, @@ -794,7 +793,7 @@ def interpolate_solution_expression(name): def maybe_interpolate_vector(name): - element = _form.coefficients()[0].ufl_element() + element = _driver_data['form'].coefficients()[0].ufl_element() if has_constraints(element): interpolate_vector(name) else: @@ -913,7 +912,7 @@ def name_stationarynonlinearproblemsolver(): @preamble def dune_solve(): # Test if form is linear in ansatzfunction - if is_linear(_form): + if is_linear(_driver_data['form']): if get_option("matrix_free"): go = name_gridoperator() x = name_vector() @@ -975,7 +974,7 @@ def name_discrete_grid_function(gfs, vector_name): @preamble def compare_L2_squared(): - element = _form.coefficients()[0].ufl_element() + element = _driver_data['form'].coefficients()[0].ufl_element() v = name_vector() gfs = name_gfs(element) vdgf = name_discrete_grid_function(gfs, v) @@ -1064,7 +1063,7 @@ def name_predicate(): @preamble def vtkoutput(): - element = _form.coefficients()[0].ufl_element() + element = _driver_data['form'].coefficients()[0].ufl_element() define_gfs_name(element) include_file("dune/pdelab/gridfunctionspace/vtk.hh", filetag="driver") vtkwriter = name_vtkwriter() @@ -1078,7 +1077,7 @@ def vtkoutput(): if get_option("exact_solution_expression"): from ufl import MixedElement, VectorElement, TensorElement - if isinstance(_form.coefficients()[0].ufl_element(), (MixedElement, VectorElement, TensorElement)): + if isinstance(_driver_data['form'].coefficients()[0].ufl_element(), (MixedElement, VectorElement, TensorElement)): raise NotImplementedError("Comparing to exact solution is olny implemented for scalar elements.") if get_option("compare_dofs"): @@ -1090,9 +1089,9 @@ def vtkoutput(): "{}.write({}, Dune::VTK::ascii);".format(vtkwriter, vtkfile)] -def generate_driver(form, filename): - # The driver module uses a global variable for ease of use - set_form(form) +def generate_driver(formdatas, data): + # The driver module uses a global dictionary for storing necessary data + set_driver_data(formdatas, data) # The vtkoutput is the generating method that triggers all others. # Alternatively, one could use the `solve` method. @@ -1104,6 +1103,8 @@ def generate_driver(form, filename): driver_body = Block(contents=[i for i in retrieve_cache_items("driver and preamble", make_generable=True)]) driver = FunctionBody(driver_signature, driver_body) + filename = get_option("driver_file") + from dune.perftool.file import generate_file generate_file(filename, "driver", [driver]) -- GitLab