From 056ac183a1225119ab36508ecb3920aa4f100de1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Wed, 21 Aug 2019 10:12:59 +0200 Subject: [PATCH] [skip ci] Make sure to register get_* methods These methods are part of the driver block class interface and should also be generated if they are not used from the main function. --- .../pdelab/driver/gridfunctionspace.py | 6 ++++-- .../codegen/pdelab/driver/gridoperator.py | 6 ++++-- .../dune/codegen/pdelab/driver/interpolate.py | 13 +++++++------ python/dune/codegen/pdelab/driver/solve.py | 19 +++++++++++++++---- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/python/dune/codegen/pdelab/driver/gridfunctionspace.py b/python/dune/codegen/pdelab/driver/gridfunctionspace.py index e86077a6..4649a08b 100644 --- a/python/dune/codegen/pdelab/driver/gridfunctionspace.py +++ b/python/dune/codegen/pdelab/driver/gridfunctionspace.py @@ -337,6 +337,7 @@ def name_gfs(element, is_dirichlet, treepath=(), root=True, main=False): main_define_gfs(element, is_dirichlet, name, root) else: define_gfs(element, is_dirichlet, name, root) + driver_block_get_gridfunctionsspace(element, is_dirichlet, root, name=name) return name @@ -546,9 +547,10 @@ def main_name_trial_subgfs(treepath): @class_member(classtag="driver_block") -def driver_block_get_gridfunctionsspace(element, is_dirichlet, root): +def driver_block_get_gridfunctionsspace(element, is_dirichlet, root, name=None): gfs_type = type_gfs(element, is_dirichlet, root=root) - name = name_gfs(element, is_dirichlet, root=root) + if not name: + name = name_gfs(element, is_dirichlet, root=root) return ["std::shared_ptr<{}> get_gridfunctionsspace(){{".format(gfs_type), " return {};".format(name), "}"] diff --git a/python/dune/codegen/pdelab/driver/gridoperator.py b/python/dune/codegen/pdelab/driver/gridoperator.py index 2780344d..b57d4a1f 100644 --- a/python/dune/codegen/pdelab/driver/gridoperator.py +++ b/python/dune/codegen/pdelab/driver/gridoperator.py @@ -78,6 +78,7 @@ def define_gridoperator(name, form_ident): def name_gridoperator(form_ident): name = "go_{}".format(form_ident) define_gridoperator(name, form_ident) + driver_block_get_gridoperator(form_ident, name=name) return name @@ -96,9 +97,10 @@ def main_type_gridoperator(form_ident): @class_member(classtag="driver_block") -def driver_block_get_gridoperator(form_ident): +def driver_block_get_gridoperator(form_ident, name=None): gridoperator_type = type_gridoperator(form_ident) - name = name_gridoperator(form_ident) + if not name: + name = name_gridoperator(form_ident) return ["std::shared_ptr<{}> get_gridoperator(){{".format(gridoperator_type), " return {};".format(name), "}"] diff --git a/python/dune/codegen/pdelab/driver/interpolate.py b/python/dune/codegen/pdelab/driver/interpolate.py index 9814d59e..271ec382 100644 --- a/python/dune/codegen/pdelab/driver/interpolate.py +++ b/python/dune/codegen/pdelab/driver/interpolate.py @@ -112,7 +112,8 @@ def name_boundary_grid_function(element, func): assert isinstance(element, (FiniteElement, TensorProductElement)) name = "boundary_grid_function" define_boundary_grid_function(name, func) - return name + driver_block_get_boundarygridfunction(element, func, name=name) + return name def boundary_lambda(func): @@ -182,8 +183,9 @@ def name_boundary_function(func): @class_member(classtag="driver_block") -def driver_block_get_boundarygridfunction(element, func): - name = name_boundary_grid_function(element, func) +def driver_block_get_boundarygridfunction(element, func, name=None): + if not name: + name = name_boundary_grid_function(element, func) bgf_type = type_boundary_grid_function(func) return ["std::shared_ptr<{}> get_boundarygridfunction(){{".format(bgf_type), " return {};".format(name), @@ -204,7 +206,7 @@ def main_type_boundary_grid_function(func): @preamble(section="postprocessing", kernel="main") -def main_define_boundar_grid_function(name, element, func): +def main_define_boundary_grid_function(name, element, func): driver_block_name = name_driver_block() driver_block_get_boundarygridfunction(element, func) return "auto {} = {}.get_boundarygridfunction();".format(name, driver_block_name) @@ -213,7 +215,6 @@ def main_define_boundar_grid_function(name, element, func): @cached def main_name_boundary_grid_function(element, func): assert isinstance(func, tuple) - name = "boundary_grid_function" - main_define_boundar_grid_function(name, element, func) + main_define_boundary_grid_function(name, element, func) return name diff --git a/python/dune/codegen/pdelab/driver/solve.py b/python/dune/codegen/pdelab/driver/solve.py index 4d4f01ed..1506cbe2 100644 --- a/python/dune/codegen/pdelab/driver/solve.py +++ b/python/dune/codegen/pdelab/driver/solve.py @@ -105,6 +105,11 @@ def define_vector(name, form_ident): def name_vector(form_ident): name = "x_{}".format(form_ident) define_vector(name, form_ident) + + # Register get method + driver_block_get_coefficient(form_ident, name=name) + + # Interpolate dirichlet boundary condition interpolate_dirichlet_data(name) return name @@ -124,9 +129,10 @@ def main_type_vector(form_ident): @class_member(classtag="driver_block") -def driver_block_get_coefficient(form_ident): +def driver_block_get_coefficient(form_ident, name=None): vector_type = type_vector(form_ident) - name = name_vector(form_ident) + if not name: + name = name_vector(form_ident) return ["std::shared_ptr<{}> get_coefficient(){{".format(vector_type), " return {};".format(name), "}"] @@ -233,13 +239,18 @@ def define_stationarylinearproblemsolver(name): def name_stationarylinearproblemsolver(): name = "slp" define_stationarylinearproblemsolver(name) + + # Register get method + driver_block_get_solver(name=name) + return name @class_member(classtag="driver_block") -def driver_block_get_solver(): +def driver_block_get_solver(name=None): solver_type = type_stationarylinearproblemsolver() - name = name_stationarylinearproblemsolver() + if not name: + name = name_stationarylinearproblemsolver() return ["std::shared_ptr<{}> get_solver(){{".format(solver_type), " return {};".format(name), "}"] -- GitLab