From f873b0fa804876d5887347c3fd6ae05b3343ec95 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Mon, 22 Aug 2016 13:02:06 +0200
Subject: [PATCH] Replace _form in driver with _drive_data dictionary

---
 python/dune/perftool/compile.py       |  2 +-
 python/dune/perftool/pdelab/driver.py | 61 ++++++++++++++-------------
 2 files changed, 32 insertions(+), 31 deletions(-)

diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py
index e4a3869a..8d221141 100644
--- a/python/dune/perftool/compile.py
+++ b/python/dune/perftool/compile.py
@@ -65,7 +65,7 @@ def compile_form():
         if get_option("driver_file"):
             with cache_context('driver', delete=True):
                 from dune.perftool.pdelab.driver import generate_driver
-                generate_driver(formdatas[0].preprocessed_form, get_option("driver_file"))
+                generate_driver(formdatas, data)
 
     # In case of multiple forms: Genarate one file that includes all localoperator files
     if len(formdatas) > 1:
diff --git a/python/dune/perftool/pdelab/driver.py b/python/dune/perftool/pdelab/driver.py
index f26143f2..73361786 100644
--- a/python/dune/perftool/pdelab/driver.py
+++ b/python/dune/perftool/pdelab/driver.py
@@ -42,13 +42,12 @@ fem_metadata_dependent_symbol = generator_factory(item_tags=("symbol",),
 # Have a global variable with the entire form data. This allows functions that depend
 # deterministically on the entire data set to directly access it instead of passing it
 # through the entire generator chain.
-_form = None
+_driver_data = {}
 
 
 # Have a function access this global data structure
-def set_form(f):
-    global _form
-    _form = f
+def set_driver_data(formdatas, datas):
+    _driver_data['form'] = formdatas[0].preprocessed_form
 
 
 def is_linear(form):
@@ -134,7 +133,7 @@ def name_initree():
 
 @preamble
 def define_dimension(name):
-    return "static const int {} = {};".format(name, _form.ufl_cell().geometric_dimension())
+    return "static const int {} = {};".format(name, _driver_data['form'].ufl_cell().geometric_dimension())
 
 
 @symbol
@@ -146,11 +145,11 @@ def name_dimension():
 @preamble
 def typedef_grid(name):
     dim = name_dimension()
-    if any(_form.ufl_cell().cellname() in x for x in ["vertex", "interval", "quadrilateral", "hexahedron"]):
+    if any(_driver_data['form'].ufl_cell().cellname() in x for x in ["vertex", "interval", "quadrilateral", "hexahedron"]):
         gridt = "Dune::YaspGrid<{}>".format(dim)
         include_file("dune/grid/yaspgrid.hh", filetag="driver")
     else:
-        if any(_form.ufl_cell().cellname() in x for x in ["triangle", "tetrahedron"]):
+        if any(_driver_data['form'].ufl_cell().cellname() in x for x in ["triangle", "tetrahedron"]):
             # gridt = "Dune::UGGrid<{}>".format(dim)
             # include_file("dune/grid/uggrid.hh", filetag="driver")
             gridt = "Dune::ALUGrid<{}, {}, Dune::simplex, Dune::conforming>".format(dim, dim)
@@ -537,11 +536,11 @@ def name_gfs(expr):
 @preamble
 def define_dofestimate(name):
     # Provide a worstcase estimate for the number of entries per row based on the given gridfunction space and cell geometry
-    if isQuadrilateral(_form.coefficients()[0].ufl_element()):
+    if isQuadrilateral(_driver_data['form'].coefficients()[0].ufl_element()):
         geo_factor = "4"
     else:
         geo_factor = "6"
-    gfs = name_gfs(_form.coefficients()[0].ufl_element())
+    gfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element())
     ini = name_initree()
     return ["int generic_dof_estimate =  {} * {}.maxLocalSize();".format(geo_factor, gfs),
             "int {} = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(name, ini)]
@@ -605,8 +604,8 @@ def typedef_localoperator(name):
     # No Parameter class here, yet
     # params = type_parameters()
     # return "typedef LocalOperator<{}> {};".format(params, name)
-    ugfs = type_gfs(_form.coefficients()[0].ufl_element())
-    vgfs = type_gfs(_form.arguments()[0].ufl_element())
+    ugfs = type_gfs(_driver_data['form'].coefficients()[0].ufl_element())
+    vgfs = type_gfs(_driver_data['form'].arguments()[0].ufl_element())
     from dune.perftool.generation import get_global_context_value
     data = get_global_context_value("data")
     formdata = get_global_context_value("formdata")
@@ -639,11 +638,11 @@ def name_localoperator():
 
 @preamble
 def typedef_gridoperator(name):
-    ugfs = type_gfs(_form.coefficients()[0].ufl_element())
-    vgfs = type_gfs(_form.arguments()[0].ufl_element())
+    ugfs = type_gfs(_driver_data['form'].coefficients()[0].ufl_element())
+    vgfs = type_gfs(_driver_data['form'].arguments()[0].ufl_element())
     lop = type_localoperator()
-    ucc = type_constraintscontainer(_form.coefficients()[0].ufl_element())
-    vcc = type_constraintscontainer(_form.arguments()[0].ufl_element())
+    ucc = type_constraintscontainer(_driver_data['form'].coefficients()[0].ufl_element())
+    vcc = type_constraintscontainer(_driver_data['form'].arguments()[0].ufl_element())
     mb = type_matrixbackend()
     df = type_domainfield()
     r = type_range()
@@ -660,10 +659,10 @@ def type_gridoperator():
 @preamble
 def define_gridoperator(name):
     gotype = type_gridoperator()
-    ugfs = name_gfs(_form.coefficients()[0].ufl_element())
-    ucc = name_assembled_constraints(_form.coefficients()[0].ufl_element())
-    vgfs = name_gfs(_form.arguments()[0].ufl_element())
-    vcc = name_assembled_constraints(_form.arguments()[0].ufl_element())
+    ugfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element())
+    ucc = name_assembled_constraints(_driver_data['form'].coefficients()[0].ufl_element())
+    vgfs = name_gfs(_driver_data['form'].arguments()[0].ufl_element())
+    vcc = name_assembled_constraints(_driver_data['form'].arguments()[0].ufl_element())
     lop = name_localoperator()
     mb = name_matrixbackend()
     return ["{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb),
@@ -692,7 +691,7 @@ def type_vector():
 @preamble
 def define_vector(name):
     vtype = type_vector()
-    gfs = name_gfs(_form.coefficients()[0].ufl_element())
+    gfs = name_gfs(_driver_data['form'].coefficients()[0].ufl_element())
     return ["{} {}({});".format(vtype, name, gfs), "{} = 0.0;".format(name)]
 
 
@@ -772,7 +771,7 @@ def name_solution_function(expr):
 @preamble
 def interpolate_vector(name):
     define_vector(name)
-    element = _form.coefficients()[0].ufl_element()
+    element = _driver_data['form'].coefficients()[0].ufl_element()
     bf = name_boundary_function(element)
     gfs = name_gfs(element)
     return "Dune::PDELab::interpolate({}, {}, {});".format(bf,
@@ -784,7 +783,7 @@ def interpolate_vector(name):
 @preamble
 def interpolate_solution_expression(name):
     define_vector(name)
-    element = _form.coefficients()[0].ufl_element()
+    element = _driver_data['form'].coefficients()[0].ufl_element()
     sol = name_solution_function(element)
     gfs = name_gfs(element)
     return "Dune::PDELab::interpolate({}, {}, {});".format(sol,
@@ -794,7 +793,7 @@ def interpolate_solution_expression(name):
 
 
 def maybe_interpolate_vector(name):
-    element = _form.coefficients()[0].ufl_element()
+    element = _driver_data['form'].coefficients()[0].ufl_element()
     if has_constraints(element):
         interpolate_vector(name)
     else:
@@ -913,7 +912,7 @@ def name_stationarynonlinearproblemsolver():
 @preamble
 def dune_solve():
     # Test if form is linear in ansatzfunction
-    if is_linear(_form):
+    if is_linear(_driver_data['form']):
         if get_option("matrix_free"):
             go = name_gridoperator()
             x = name_vector()
@@ -975,7 +974,7 @@ def name_discrete_grid_function(gfs, vector_name):
 
 @preamble
 def compare_L2_squared():
-    element = _form.coefficients()[0].ufl_element()
+    element = _driver_data['form'].coefficients()[0].ufl_element()
     v = name_vector()
     gfs = name_gfs(element)
     vdgf = name_discrete_grid_function(gfs, v)
@@ -1064,7 +1063,7 @@ def name_predicate():
 
 @preamble
 def vtkoutput():
-    element = _form.coefficients()[0].ufl_element()
+    element = _driver_data['form'].coefficients()[0].ufl_element()
     define_gfs_name(element)
     include_file("dune/pdelab/gridfunctionspace/vtk.hh", filetag="driver")
     vtkwriter = name_vtkwriter()
@@ -1078,7 +1077,7 @@ def vtkoutput():
 
     if get_option("exact_solution_expression"):
         from ufl import MixedElement, VectorElement, TensorElement
-        if isinstance(_form.coefficients()[0].ufl_element(), (MixedElement, VectorElement, TensorElement)):
+        if isinstance(_driver_data['form'].coefficients()[0].ufl_element(), (MixedElement, VectorElement, TensorElement)):
             raise NotImplementedError("Comparing to exact solution is olny implemented for scalar elements.")
 
         if get_option("compare_dofs"):
@@ -1090,9 +1089,9 @@ def vtkoutput():
             "{}.write({}, Dune::VTK::ascii);".format(vtkwriter, vtkfile)]
 
 
-def generate_driver(form, filename):
-    # The driver module uses a global variable for ease of use
-    set_form(form)
+def generate_driver(formdatas, data):
+    # The driver module uses a global dictionary for storing necessary data
+    set_driver_data(formdatas, data)
 
     # The vtkoutput is the generating method that triggers all others.
     # Alternatively, one could use the `solve` method.
@@ -1104,6 +1103,8 @@ def generate_driver(form, filename):
     driver_body = Block(contents=[i for i in retrieve_cache_items("driver and preamble", make_generable=True)])
     driver = FunctionBody(driver_signature, driver_body)
 
+    filename = get_option("driver_file")
+
     from dune.perftool.file import generate_file
     generate_file(filename, "driver", [driver])
 
-- 
GitLab