From 46498bbbe997f205d827606c7e1ec6ebf0f8ef04 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 3 Aug 2017 17:08:52 +0200
Subject: [PATCH] Fix the instationary driver stuff

---
 .../dune/perftool/pdelab/driver/__init__.py   |  4 ++
 .../perftool/pdelab/driver/instationary.py    | 56 ++++++++++++-------
 python/dune/perftool/pdelab/driver/solve.py   |  6 +-
 python/dune/perftool/pdelab/driver/vtk.py     |  2 +-
 4 files changed, 45 insertions(+), 23 deletions(-)

diff --git a/python/dune/perftool/pdelab/driver/__init__.py b/python/dune/perftool/pdelab/driver/__init__.py
index b441722e..b5efebf3 100644
--- a/python/dune/perftool/pdelab/driver/__init__.py
+++ b/python/dune/perftool/pdelab/driver/__init__.py
@@ -63,6 +63,10 @@ def get_formdata():
     return _driver_data['formdata']
 
 
+def get_mass_formdata():
+    return _driver_data["mass_formdata"]
+
+
 def is_stationary():
     return 'mass_form' not in _driver_data
 
diff --git a/python/dune/perftool/pdelab/driver/instationary.py b/python/dune/perftool/pdelab/driver/instationary.py
index bf8759e7..8b5ebc7f 100644
--- a/python/dune/perftool/pdelab/driver/instationary.py
+++ b/python/dune/perftool/pdelab/driver/instationary.py
@@ -1,11 +1,33 @@
-from dune.perftool.generation import preamble
+from dune.perftool.generation import (include_file,
+                                      preamble,
+                                      )
 from dune.perftool.pdelab.driver import (get_formdata,
+                                         get_mass_formdata,
+                                         get_trial_element,
                                          is_linear,
+                                         name_initree,
+                                         preprocess_leaf_data,
                                          )
-from dune.perftool.pdelab.driver.gridoperator import (type_gridoperator,)
+from dune.perftool.pdelab.driver.gridfunctionspace import (name_gfs,
+                                                           type_range,
+                                                           )
+from dune.perftool.pdelab.driver.gridoperator import (name_gridoperator,
+                                                      name_parameters,
+                                                      type_gridoperator,)
+from dune.perftool.pdelab.driver.constraints import (name_bctype_function,
+                                                     name_constraintscontainer,
+                                                     )
+from dune.perftool.pdelab.driver.interpolate import name_boundary_function
 from dune.perftool.pdelab.driver.solve import (print_matrix,
                                                print_residual,
+                                               name_stationarynonlinearproblemsolver,
+                                               name_vector,
+                                               type_stationarynonlinearproblemssolver,
+                                               type_vector,
                                                )
+from dune.perftool.pdelab.driver.vtk import (name_vtk_sequence_writer,
+                                             visualize_initial_condition,
+                                             )
 from dune.perftool.options import get_option
 
 
@@ -28,12 +50,6 @@ def solve_instationary():
 
     print_residual()
     print_matrix()
-    from dune.perftool.pdelab.driver.error import compare_dofs, compare_L2_squared
-    if get_option("exact_solution_expression"):
-        if get_option("compare_dofs"):
-            compare_dofs()
-        if get_option("compare_l2errorsquared"):
-            compare_L2_squared()
 
 
 @preamble
@@ -42,10 +58,11 @@ def time_loop():
     formdata = get_formdata()
     params = name_parameters(formdata)
     time = name_time()
-    expr = get_trial_element()
-    bctype = name_bctype_function(expr)
-    gfs = name_gfs(expr)
-    cc = name_constraintscontainer(expr)
+    element = get_trial_element()
+    is_dirichlet = preprocess_leaf_data(element, "is_dirichlet")
+    bctype = name_bctype_function(element, is_dirichlet)
+    gfs = name_gfs(element, is_dirichlet)
+    cc = name_constraintscontainer()
     vector_type = type_vector(formdata)
     vector = name_vector(formdata)
 
@@ -55,7 +72,8 @@ def time_loop():
         osm = name_explicitonestepmethod()
         apply_call = "{}.apply(time, dt, {}, {}new);".format(osm, vector, vector)
     else:
-        boundary = name_boundary_function(expr)
+        dirichlet = preprocess_leaf_data(element, "dirichlet_expression")
+        boundary = name_boundary_function(element, dirichlet)
         osm = name_onestepmethod()
         apply_call = "{}.apply(time, dt, {}, {}, {}new);".format(osm, vector, boundary, vector)
 
@@ -129,8 +147,8 @@ def name_timesteppingmethod():
 @preamble
 def typedef_instationarygridoperator(name):
     include_file("dune/pdelab/gridoperator/onestep.hh", filetag="driver")
-    go_type = type_gridoperator(_driver_data['formdata'])
-    mass_go_type = type_gridoperator(_driver_data['mass_formdata'])
+    go_type = type_gridoperator(get_formdata())
+    mass_go_type = type_gridoperator(get_mass_formdata())
     explicit = get_option('explicit_time_stepping')
     if explicit:
         return "using {} = Dune::PDELab::OneStepGridOperator<{},{},false>;".format(name, go_type, mass_go_type)
@@ -146,8 +164,8 @@ def type_instationarygridoperator():
 @preamble
 def define_instationarygridoperator(name):
     igo_type = type_instationarygridoperator()
-    go = name_gridoperator(_driver_data['formdata'])
-    mass_go = name_gridoperator(_driver_data['mass_formdata'])
+    go = name_gridoperator(get_formdata())
+    mass_go = name_gridoperator(get_mass_formdata())
     return "{} {}({}, {});".format(igo_type, name, go, mass_go)
 
 
@@ -161,7 +179,7 @@ def typedef_onestepmethod(name):
     r_type = type_range()
     igo_type = type_instationarygridoperator()
     snp_type = type_stationarynonlinearproblemssolver(igo_type)
-    vector_type = type_vector(_driver_data['formdata'])
+    vector_type = type_vector(get_formdata())
     return "using {} = Dune::PDELab::OneStepMethod<{}, {}, {}, {}, {}>;".format(name, r_type, igo_type, snp_type, vector_type, vector_type)
 
 
@@ -190,7 +208,7 @@ def typedef_explicitonestepmethod(name):
     r_type = type_range()
     igo_type = type_instationarygridoperator()
     ls_type = type_linearsolver()
-    vector_type = type_vector(_driver_data['formdata'])
+    vector_type = type_vector(get_formdata())
     return "using {} = Dune::PDELab::ExplicitOneStepMethod<{}, {}, {}, {}>;".format(name, r_type, igo_type, ls_type, vector_type)
 
 
diff --git a/python/dune/perftool/pdelab/driver/solve.py b/python/dune/perftool/pdelab/driver/solve.py
index 0179fe68..5abbac27 100644
--- a/python/dune/perftool/pdelab/driver/solve.py
+++ b/python/dune/perftool/pdelab/driver/solve.py
@@ -40,8 +40,8 @@ def dune_solve():
         include_file("dune/perftool/matrixfree.hh", filetag="driver")
         solve = "solveNonlinearMatrixFree({},{});".format(go, x)
     elif not linear and not matrix_free:
-        go_type = type_gridoperator(_driver_data['formdata'])
-        go = name_gridoperator(_driver_data['formdata'])
+        go_type = type_gridoperator(get_formdata())
+        go = name_gridoperator(get_formdata())
         snp = name_stationarynonlinearproblemsolver(go_type, go)
         solve = "{}.apply();".format(snp)
 
@@ -172,7 +172,7 @@ def name_stationarylinearproblemsolver():
 def typedef_stationarynonlinearproblemsolver(name, go_type):
     include_file("dune/pdelab/newton/newton.hh", filetag="driver")
     ls_type = type_linearsolver()
-    x_type = type_vector(_driver_data['formdata'])
+    x_type = type_vector(get_formdata())
     return "using {} = Dune::PDELab::Newton<{}, {}, {}>;".format(name, go_type, ls_type, x_type)
 
 
diff --git a/python/dune/perftool/pdelab/driver/vtk.py b/python/dune/perftool/pdelab/driver/vtk.py
index 2d2fa20a..71783adc 100644
--- a/python/dune/perftool/pdelab/driver/vtk.py
+++ b/python/dune/perftool/pdelab/driver/vtk.py
@@ -125,7 +125,7 @@ def visualize_initial_condition():
     vtkwriter = name_vtk_sequence_writer()
     element = get_trial_element()
     define_gfs_names(element)
-    gfs = name_gfs(element)
+    gfs = name_trial_gfs()
     vector = name_vector(get_formdata())
     predicate = name_predicate()
     from dune.perftool.pdelab.driver.instationary import name_time
-- 
GitLab