Skip to content
Snippets Groups Projects
Commit f873b0fa authored by René Heß's avatar René Heß
Browse files

Replace _form in driver with _drive_data dictionary

parent 1ba6f550
No related branches found
No related tags found
No related merge requests found
...@@ -65,7 +65,7 @@ def compile_form(): ...@@ -65,7 +65,7 @@ def compile_form():
if get_option("driver_file"): if get_option("driver_file"):
with cache_context('driver', delete=True): with cache_context('driver', delete=True):
from dune.perftool.pdelab.driver import generate_driver from dune.perftool.pdelab.driver import generate_driver
generate_driver(formdatas[0].preprocessed_form, get_option("driver_file")) generate_driver(formdatas, data)
# In case of multiple forms: Genarate one file that includes all localoperator files # In case of multiple forms: Genarate one file that includes all localoperator files
if len(formdatas) > 1: if len(formdatas) > 1:
......
...@@ -42,13 +42,12 @@ fem_metadata_dependent_symbol = generator_factory(item_tags=("symbol",), ...@@ -42,13 +42,12 @@ fem_metadata_dependent_symbol = generator_factory(item_tags=("symbol",),
# Have a global variable with the entire form data. This allows functions that depend # Have a global variable with the entire form data. This allows functions that depend
# deterministically on the entire data set to directly access it instead of passing it # deterministically on the entire data set to directly access it instead of passing it
# through the entire generator chain. # through the entire generator chain.
_form = None _driver_data = {}
# Have a function access this global data structure # Have a function access this global data structure
def set_form(f): def set_driver_data(formdatas, datas):
global _form _driver_data['form'] = formdatas[0].preprocessed_form
_form = f
def is_linear(form): def is_linear(form):
...@@ -134,7 +133,7 @@ def name_initree(): ...@@ -134,7 +133,7 @@ def name_initree():
@preamble @preamble
def define_dimension(name): def define_dimension(name):
return "static const int {} = {};".format(name, _form.ufl_cell().geometric_dimension()) return "static const int {} = {};".format(name, _driver_data['form'].ufl_cell().geometric_dimension())
@symbol @symbol
...@@ -146,11 +145,11 @@ def name_dimension(): ...@@ -146,11 +145,11 @@ def name_dimension():
@preamble @preamble
def typedef_grid(name): def typedef_grid(name):
dim = name_dimension() dim = name_dimension()
if any(_form.ufl_cell().cellname() in x for x in ["vertex", "interval", "quadrilateral", "hexahedron"]): if any(_driver_data['form'].ufl_cell().cellname() in x for x in ["vertex", "interval", "quadrilateral", "hexahedron"]):
gridt = "Dune::YaspGrid<{}>".format(dim) gridt = "Dune::YaspGrid<{}>".format(dim)
include_file("dune/grid/yaspgrid.hh", filetag="driver") include_file("dune/grid/yaspgrid.hh", filetag="driver")
else: else:
if any(_form.ufl_cell().cellname() in x for x in ["triangle", "tetrahedron"]): if any(_driver_data['form'].ufl_cell().cellname() in x for x in ["triangle", "tetrahedron"]):
# gridt = "Dune::UGGrid<{}>".format(dim) # gridt = "Dune::UGGrid<{}>".format(dim)
# include_file("dune/grid/uggrid.hh", filetag="driver") # include_file("dune/grid/uggrid.hh", filetag="driver")
gridt = "Dune::ALUGrid<{}, {}, Dune::simplex, Dune::conforming>".format(dim, dim) gridt = "Dune::ALUGrid<{}, {}, Dune::simplex, Dune::conforming>".format(dim, dim)
...@@ -537,11 +536,11 @@ def name_gfs(expr): ...@@ -537,11 +536,11 @@ def name_gfs(expr):
@preamble @preamble
def define_dofestimate(name): def define_dofestimate(name):
# Provide a worstcase estimate for the number of entries per row based on the given gridfunction space and cell geometry # Provide a worstcase estimate for the number of entries per row based on the given gridfunction space and cell geometry
if isQuadrilateral(_form.coefficients()[0].ufl_element()): if isQuadrilateral(_driver_data['form'].coefficients()[0].ufl_element()):
geo_factor = "4" geo_factor = "4"
else: else:
geo_factor = "6" geo_factor = "6"
gfs = name_gfs(_form.coefficients()[0].ufl_element()) gfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element())
ini = name_initree() ini = name_initree()
return ["int generic_dof_estimate = {} * {}.maxLocalSize();".format(geo_factor, gfs), return ["int generic_dof_estimate = {} * {}.maxLocalSize();".format(geo_factor, gfs),
"int {} = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(name, ini)] "int {} = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(name, ini)]
...@@ -605,8 +604,8 @@ def typedef_localoperator(name): ...@@ -605,8 +604,8 @@ def typedef_localoperator(name):
# No Parameter class here, yet # No Parameter class here, yet
# params = type_parameters() # params = type_parameters()
# return "typedef LocalOperator<{}> {};".format(params, name) # return "typedef LocalOperator<{}> {};".format(params, name)
ugfs = type_gfs(_form.coefficients()[0].ufl_element()) ugfs = type_gfs(_driver_data['form'].coefficients()[0].ufl_element())
vgfs = type_gfs(_form.arguments()[0].ufl_element()) vgfs = type_gfs(_driver_data['form'].arguments()[0].ufl_element())
from dune.perftool.generation import get_global_context_value from dune.perftool.generation import get_global_context_value
data = get_global_context_value("data") data = get_global_context_value("data")
formdata = get_global_context_value("formdata") formdata = get_global_context_value("formdata")
...@@ -639,11 +638,11 @@ def name_localoperator(): ...@@ -639,11 +638,11 @@ def name_localoperator():
@preamble @preamble
def typedef_gridoperator(name): def typedef_gridoperator(name):
ugfs = type_gfs(_form.coefficients()[0].ufl_element()) ugfs = type_gfs(_driver_data['form'].coefficients()[0].ufl_element())
vgfs = type_gfs(_form.arguments()[0].ufl_element()) vgfs = type_gfs(_driver_data['form'].arguments()[0].ufl_element())
lop = type_localoperator() lop = type_localoperator()
ucc = type_constraintscontainer(_form.coefficients()[0].ufl_element()) ucc = type_constraintscontainer(_driver_data['form'].coefficients()[0].ufl_element())
vcc = type_constraintscontainer(_form.arguments()[0].ufl_element()) vcc = type_constraintscontainer(_driver_data['form'].arguments()[0].ufl_element())
mb = type_matrixbackend() mb = type_matrixbackend()
df = type_domainfield() df = type_domainfield()
r = type_range() r = type_range()
...@@ -660,10 +659,10 @@ def type_gridoperator(): ...@@ -660,10 +659,10 @@ def type_gridoperator():
@preamble @preamble
def define_gridoperator(name): def define_gridoperator(name):
gotype = type_gridoperator() gotype = type_gridoperator()
ugfs = name_gfs(_form.coefficients()[0].ufl_element()) ugfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element())
ucc = name_assembled_constraints(_form.coefficients()[0].ufl_element()) ucc = name_assembled_constraints(_driver_data['form'].coefficients()[0].ufl_element())
vgfs = name_gfs(_form.arguments()[0].ufl_element()) vgfs = name_gfs(_driver_data['form'].arguments()[0].ufl_element())
vcc = name_assembled_constraints(_form.arguments()[0].ufl_element()) vcc = name_assembled_constraints(_driver_data['form'].arguments()[0].ufl_element())
lop = name_localoperator() lop = name_localoperator()
mb = name_matrixbackend() mb = name_matrixbackend()
return ["{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb), return ["{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb),
...@@ -692,7 +691,7 @@ def type_vector(): ...@@ -692,7 +691,7 @@ def type_vector():
@preamble @preamble
def define_vector(name): def define_vector(name):
vtype = type_vector() vtype = type_vector()
gfs = name_gfs(_form.coefficients()[0].ufl_element()) gfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element())
return ["{} {}({});".format(vtype, name, gfs), "{} = 0.0;".format(name)] return ["{} {}({});".format(vtype, name, gfs), "{} = 0.0;".format(name)]
...@@ -772,7 +771,7 @@ def name_solution_function(expr): ...@@ -772,7 +771,7 @@ def name_solution_function(expr):
@preamble @preamble
def interpolate_vector(name): def interpolate_vector(name):
define_vector(name) define_vector(name)
element = _form.coefficients()[0].ufl_element() element = _driver_data['form'].coefficients()[0].ufl_element()
bf = name_boundary_function(element) bf = name_boundary_function(element)
gfs = name_gfs(element) gfs = name_gfs(element)
return "Dune::PDELab::interpolate({}, {}, {});".format(bf, return "Dune::PDELab::interpolate({}, {}, {});".format(bf,
...@@ -784,7 +783,7 @@ def interpolate_vector(name): ...@@ -784,7 +783,7 @@ def interpolate_vector(name):
@preamble @preamble
def interpolate_solution_expression(name): def interpolate_solution_expression(name):
define_vector(name) define_vector(name)
element = _form.coefficients()[0].ufl_element() element = _driver_data['form'].coefficients()[0].ufl_element()
sol = name_solution_function(element) sol = name_solution_function(element)
gfs = name_gfs(element) gfs = name_gfs(element)
return "Dune::PDELab::interpolate({}, {}, {});".format(sol, return "Dune::PDELab::interpolate({}, {}, {});".format(sol,
...@@ -794,7 +793,7 @@ def interpolate_solution_expression(name): ...@@ -794,7 +793,7 @@ def interpolate_solution_expression(name):
def maybe_interpolate_vector(name): def maybe_interpolate_vector(name):
element = _form.coefficients()[0].ufl_element() element = _driver_data['form'].coefficients()[0].ufl_element()
if has_constraints(element): if has_constraints(element):
interpolate_vector(name) interpolate_vector(name)
else: else:
...@@ -913,7 +912,7 @@ def name_stationarynonlinearproblemsolver(): ...@@ -913,7 +912,7 @@ def name_stationarynonlinearproblemsolver():
@preamble @preamble
def dune_solve(): def dune_solve():
# Test if form is linear in ansatzfunction # Test if form is linear in ansatzfunction
if is_linear(_form): if is_linear(_driver_data['form']):
if get_option("matrix_free"): if get_option("matrix_free"):
go = name_gridoperator() go = name_gridoperator()
x = name_vector() x = name_vector()
...@@ -975,7 +974,7 @@ def name_discrete_grid_function(gfs, vector_name): ...@@ -975,7 +974,7 @@ def name_discrete_grid_function(gfs, vector_name):
@preamble @preamble
def compare_L2_squared(): def compare_L2_squared():
element = _form.coefficients()[0].ufl_element() element = _driver_data['form'].coefficients()[0].ufl_element()
v = name_vector() v = name_vector()
gfs = name_gfs(element) gfs = name_gfs(element)
vdgf = name_discrete_grid_function(gfs, v) vdgf = name_discrete_grid_function(gfs, v)
...@@ -1064,7 +1063,7 @@ def name_predicate(): ...@@ -1064,7 +1063,7 @@ def name_predicate():
@preamble @preamble
def vtkoutput(): def vtkoutput():
element = _form.coefficients()[0].ufl_element() element = _driver_data['form'].coefficients()[0].ufl_element()
define_gfs_name(element) define_gfs_name(element)
include_file("dune/pdelab/gridfunctionspace/vtk.hh", filetag="driver") include_file("dune/pdelab/gridfunctionspace/vtk.hh", filetag="driver")
vtkwriter = name_vtkwriter() vtkwriter = name_vtkwriter()
...@@ -1078,7 +1077,7 @@ def vtkoutput(): ...@@ -1078,7 +1077,7 @@ def vtkoutput():
if get_option("exact_solution_expression"): if get_option("exact_solution_expression"):
from ufl import MixedElement, VectorElement, TensorElement from ufl import MixedElement, VectorElement, TensorElement
if isinstance(_form.coefficients()[0].ufl_element(), (MixedElement, VectorElement, TensorElement)): if isinstance(_driver_data['form'].coefficients()[0].ufl_element(), (MixedElement, VectorElement, TensorElement)):
raise NotImplementedError("Comparing to exact solution is olny implemented for scalar elements.") raise NotImplementedError("Comparing to exact solution is olny implemented for scalar elements.")
if get_option("compare_dofs"): if get_option("compare_dofs"):
...@@ -1090,9 +1089,9 @@ def vtkoutput(): ...@@ -1090,9 +1089,9 @@ def vtkoutput():
"{}.write({}, Dune::VTK::ascii);".format(vtkwriter, vtkfile)] "{}.write({}, Dune::VTK::ascii);".format(vtkwriter, vtkfile)]
def generate_driver(form, filename): def generate_driver(formdatas, data):
# The driver module uses a global variable for ease of use # The driver module uses a global dictionary for storing necessary data
set_form(form) set_driver_data(formdatas, data)
# The vtkoutput is the generating method that triggers all others. # The vtkoutput is the generating method that triggers all others.
# Alternatively, one could use the `solve` method. # Alternatively, one could use the `solve` method.
...@@ -1104,6 +1103,8 @@ def generate_driver(form, filename): ...@@ -1104,6 +1103,8 @@ def generate_driver(form, filename):
driver_body = Block(contents=[i for i in retrieve_cache_items("driver and preamble", make_generable=True)]) driver_body = Block(contents=[i for i in retrieve_cache_items("driver and preamble", make_generable=True)])
driver = FunctionBody(driver_signature, driver_body) driver = FunctionBody(driver_signature, driver_body)
filename = get_option("driver_file")
from dune.perftool.file import generate_file from dune.perftool.file import generate_file
generate_file(filename, "driver", [driver]) generate_file(filename, "driver", [driver])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment