diff --git a/python/dune/codegen/pdelab/driver/gridfunctionspace.py b/python/dune/codegen/pdelab/driver/gridfunctionspace.py index e86077a6415da80acb2d7ee2faf132bdfb5b6eb7..4649a08b2bf985466eeccfdfb50f23fb14b550dc 100644 --- a/python/dune/codegen/pdelab/driver/gridfunctionspace.py +++ b/python/dune/codegen/pdelab/driver/gridfunctionspace.py @@ -337,6 +337,7 @@ def name_gfs(element, is_dirichlet, treepath=(), root=True, main=False): main_define_gfs(element, is_dirichlet, name, root) else: define_gfs(element, is_dirichlet, name, root) + driver_block_get_gridfunctionsspace(element, is_dirichlet, root, name=name) return name @@ -546,9 +547,10 @@ def main_name_trial_subgfs(treepath): @class_member(classtag="driver_block") -def driver_block_get_gridfunctionsspace(element, is_dirichlet, root): +def driver_block_get_gridfunctionsspace(element, is_dirichlet, root, name=None): gfs_type = type_gfs(element, is_dirichlet, root=root) - name = name_gfs(element, is_dirichlet, root=root) + if not name: + name = name_gfs(element, is_dirichlet, root=root) return ["std::shared_ptr<{}> get_gridfunctionsspace(){{".format(gfs_type), " return {};".format(name), "}"] diff --git a/python/dune/codegen/pdelab/driver/gridoperator.py b/python/dune/codegen/pdelab/driver/gridoperator.py index 2780344d3ddf6f6027bee455c09f0f524ac4d3cf..b57d4a1f4a9bc8b96297a86ee839565d7079a171 100644 --- a/python/dune/codegen/pdelab/driver/gridoperator.py +++ b/python/dune/codegen/pdelab/driver/gridoperator.py @@ -78,6 +78,7 @@ def define_gridoperator(name, form_ident): def name_gridoperator(form_ident): name = "go_{}".format(form_ident) define_gridoperator(name, form_ident) + driver_block_get_gridoperator(form_ident, name=name) return name @@ -96,9 +97,10 @@ def main_type_gridoperator(form_ident): @class_member(classtag="driver_block") -def driver_block_get_gridoperator(form_ident): +def driver_block_get_gridoperator(form_ident, name=None): gridoperator_type = type_gridoperator(form_ident) - name = name_gridoperator(form_ident) + if not name: + name = name_gridoperator(form_ident) return ["std::shared_ptr<{}> get_gridoperator(){{".format(gridoperator_type), " return {};".format(name), "}"] diff --git a/python/dune/codegen/pdelab/driver/interpolate.py b/python/dune/codegen/pdelab/driver/interpolate.py index 9814d59e58f6f6f7b31164238485e2e4a3f052cd..271ec382a333850abcce05f95babfc309a73f088 100644 --- a/python/dune/codegen/pdelab/driver/interpolate.py +++ b/python/dune/codegen/pdelab/driver/interpolate.py @@ -112,7 +112,8 @@ def name_boundary_grid_function(element, func): assert isinstance(element, (FiniteElement, TensorProductElement)) name = "boundary_grid_function" define_boundary_grid_function(name, func) - return name + driver_block_get_boundarygridfunction(element, func, name=name) + return name def boundary_lambda(func): @@ -182,8 +183,9 @@ def name_boundary_function(func): @class_member(classtag="driver_block") -def driver_block_get_boundarygridfunction(element, func): - name = name_boundary_grid_function(element, func) +def driver_block_get_boundarygridfunction(element, func, name=None): + if not name: + name = name_boundary_grid_function(element, func) bgf_type = type_boundary_grid_function(func) return ["std::shared_ptr<{}> get_boundarygridfunction(){{".format(bgf_type), " return {};".format(name), @@ -204,7 +206,7 @@ def main_type_boundary_grid_function(func): @preamble(section="postprocessing", kernel="main") -def main_define_boundar_grid_function(name, element, func): +def main_define_boundary_grid_function(name, element, func): driver_block_name = name_driver_block() driver_block_get_boundarygridfunction(element, func) return "auto {} = {}.get_boundarygridfunction();".format(name, driver_block_name) @@ -213,7 +215,6 @@ def main_define_boundar_grid_function(name, element, func): @cached def main_name_boundary_grid_function(element, func): assert isinstance(func, tuple) - name = "boundary_grid_function" - main_define_boundar_grid_function(name, element, func) + main_define_boundary_grid_function(name, element, func) return name diff --git a/python/dune/codegen/pdelab/driver/solve.py b/python/dune/codegen/pdelab/driver/solve.py index 4d4f01ed33ac5ea1ac5b82dd8c3bc8d78102da16..1506cbe253584d7c8191243f9f8d47de80747215 100644 --- a/python/dune/codegen/pdelab/driver/solve.py +++ b/python/dune/codegen/pdelab/driver/solve.py @@ -105,6 +105,11 @@ def define_vector(name, form_ident): def name_vector(form_ident): name = "x_{}".format(form_ident) define_vector(name, form_ident) + + # Register get method + driver_block_get_coefficient(form_ident, name=name) + + # Interpolate dirichlet boundary condition interpolate_dirichlet_data(name) return name @@ -124,9 +129,10 @@ def main_type_vector(form_ident): @class_member(classtag="driver_block") -def driver_block_get_coefficient(form_ident): +def driver_block_get_coefficient(form_ident, name=None): vector_type = type_vector(form_ident) - name = name_vector(form_ident) + if not name: + name = name_vector(form_ident) return ["std::shared_ptr<{}> get_coefficient(){{".format(vector_type), " return {};".format(name), "}"] @@ -233,13 +239,18 @@ def define_stationarylinearproblemsolver(name): def name_stationarylinearproblemsolver(): name = "slp" define_stationarylinearproblemsolver(name) + + # Register get method + driver_block_get_solver(name=name) + return name @class_member(classtag="driver_block") -def driver_block_get_solver(): +def driver_block_get_solver(name=None): solver_type = type_stationarylinearproblemsolver() - name = name_stationarylinearproblemsolver() + if not name: + name = name_stationarylinearproblemsolver() return ["std::shared_ptr<{}> get_solver(){{".format(solver_type), " return {};".format(name), "}"]