Skip to content
Snippets Groups Projects
Commit 056ac183 authored by René Heß's avatar René Heß
Browse files

[skip ci] Make sure to register get_* methods

These methods are part of the driver block class interface and should also be
generated if they are not used from the main function.
parent c60b7385
No related branches found
No related tags found
No related merge requests found
...@@ -337,6 +337,7 @@ def name_gfs(element, is_dirichlet, treepath=(), root=True, main=False): ...@@ -337,6 +337,7 @@ def name_gfs(element, is_dirichlet, treepath=(), root=True, main=False):
main_define_gfs(element, is_dirichlet, name, root) main_define_gfs(element, is_dirichlet, name, root)
else: else:
define_gfs(element, is_dirichlet, name, root) define_gfs(element, is_dirichlet, name, root)
driver_block_get_gridfunctionsspace(element, is_dirichlet, root, name=name)
return name return name
...@@ -546,9 +547,10 @@ def main_name_trial_subgfs(treepath): ...@@ -546,9 +547,10 @@ def main_name_trial_subgfs(treepath):
@class_member(classtag="driver_block") @class_member(classtag="driver_block")
def driver_block_get_gridfunctionsspace(element, is_dirichlet, root): def driver_block_get_gridfunctionsspace(element, is_dirichlet, root, name=None):
gfs_type = type_gfs(element, is_dirichlet, root=root) gfs_type = type_gfs(element, is_dirichlet, root=root)
name = name_gfs(element, is_dirichlet, root=root) if not name:
name = name_gfs(element, is_dirichlet, root=root)
return ["std::shared_ptr<{}> get_gridfunctionsspace(){{".format(gfs_type), return ["std::shared_ptr<{}> get_gridfunctionsspace(){{".format(gfs_type),
" return {};".format(name), " return {};".format(name),
"}"] "}"]
......
...@@ -78,6 +78,7 @@ def define_gridoperator(name, form_ident): ...@@ -78,6 +78,7 @@ def define_gridoperator(name, form_ident):
def name_gridoperator(form_ident): def name_gridoperator(form_ident):
name = "go_{}".format(form_ident) name = "go_{}".format(form_ident)
define_gridoperator(name, form_ident) define_gridoperator(name, form_ident)
driver_block_get_gridoperator(form_ident, name=name)
return name return name
...@@ -96,9 +97,10 @@ def main_type_gridoperator(form_ident): ...@@ -96,9 +97,10 @@ def main_type_gridoperator(form_ident):
@class_member(classtag="driver_block") @class_member(classtag="driver_block")
def driver_block_get_gridoperator(form_ident): def driver_block_get_gridoperator(form_ident, name=None):
gridoperator_type = type_gridoperator(form_ident) gridoperator_type = type_gridoperator(form_ident)
name = name_gridoperator(form_ident) if not name:
name = name_gridoperator(form_ident)
return ["std::shared_ptr<{}> get_gridoperator(){{".format(gridoperator_type), return ["std::shared_ptr<{}> get_gridoperator(){{".format(gridoperator_type),
" return {};".format(name), " return {};".format(name),
"}"] "}"]
......
...@@ -112,7 +112,8 @@ def name_boundary_grid_function(element, func): ...@@ -112,7 +112,8 @@ def name_boundary_grid_function(element, func):
assert isinstance(element, (FiniteElement, TensorProductElement)) assert isinstance(element, (FiniteElement, TensorProductElement))
name = "boundary_grid_function" name = "boundary_grid_function"
define_boundary_grid_function(name, func) define_boundary_grid_function(name, func)
return name driver_block_get_boundarygridfunction(element, func, name=name)
return name
def boundary_lambda(func): def boundary_lambda(func):
...@@ -182,8 +183,9 @@ def name_boundary_function(func): ...@@ -182,8 +183,9 @@ def name_boundary_function(func):
@class_member(classtag="driver_block") @class_member(classtag="driver_block")
def driver_block_get_boundarygridfunction(element, func): def driver_block_get_boundarygridfunction(element, func, name=None):
name = name_boundary_grid_function(element, func) if not name:
name = name_boundary_grid_function(element, func)
bgf_type = type_boundary_grid_function(func) bgf_type = type_boundary_grid_function(func)
return ["std::shared_ptr<{}> get_boundarygridfunction(){{".format(bgf_type), return ["std::shared_ptr<{}> get_boundarygridfunction(){{".format(bgf_type),
" return {};".format(name), " return {};".format(name),
...@@ -204,7 +206,7 @@ def main_type_boundary_grid_function(func): ...@@ -204,7 +206,7 @@ def main_type_boundary_grid_function(func):
@preamble(section="postprocessing", kernel="main") @preamble(section="postprocessing", kernel="main")
def main_define_boundar_grid_function(name, element, func): def main_define_boundary_grid_function(name, element, func):
driver_block_name = name_driver_block() driver_block_name = name_driver_block()
driver_block_get_boundarygridfunction(element, func) driver_block_get_boundarygridfunction(element, func)
return "auto {} = {}.get_boundarygridfunction();".format(name, driver_block_name) return "auto {} = {}.get_boundarygridfunction();".format(name, driver_block_name)
...@@ -213,7 +215,6 @@ def main_define_boundar_grid_function(name, element, func): ...@@ -213,7 +215,6 @@ def main_define_boundar_grid_function(name, element, func):
@cached @cached
def main_name_boundary_grid_function(element, func): def main_name_boundary_grid_function(element, func):
assert isinstance(func, tuple) assert isinstance(func, tuple)
name = "boundary_grid_function" name = "boundary_grid_function"
main_define_boundar_grid_function(name, element, func) main_define_boundary_grid_function(name, element, func)
return name return name
...@@ -105,6 +105,11 @@ def define_vector(name, form_ident): ...@@ -105,6 +105,11 @@ def define_vector(name, form_ident):
def name_vector(form_ident): def name_vector(form_ident):
name = "x_{}".format(form_ident) name = "x_{}".format(form_ident)
define_vector(name, form_ident) define_vector(name, form_ident)
# Register get method
driver_block_get_coefficient(form_ident, name=name)
# Interpolate dirichlet boundary condition
interpolate_dirichlet_data(name) interpolate_dirichlet_data(name)
return name return name
...@@ -124,9 +129,10 @@ def main_type_vector(form_ident): ...@@ -124,9 +129,10 @@ def main_type_vector(form_ident):
@class_member(classtag="driver_block") @class_member(classtag="driver_block")
def driver_block_get_coefficient(form_ident): def driver_block_get_coefficient(form_ident, name=None):
vector_type = type_vector(form_ident) vector_type = type_vector(form_ident)
name = name_vector(form_ident) if not name:
name = name_vector(form_ident)
return ["std::shared_ptr<{}> get_coefficient(){{".format(vector_type), return ["std::shared_ptr<{}> get_coefficient(){{".format(vector_type),
" return {};".format(name), " return {};".format(name),
"}"] "}"]
...@@ -233,13 +239,18 @@ def define_stationarylinearproblemsolver(name): ...@@ -233,13 +239,18 @@ def define_stationarylinearproblemsolver(name):
def name_stationarylinearproblemsolver(): def name_stationarylinearproblemsolver():
name = "slp" name = "slp"
define_stationarylinearproblemsolver(name) define_stationarylinearproblemsolver(name)
# Register get method
driver_block_get_solver(name=name)
return name return name
@class_member(classtag="driver_block") @class_member(classtag="driver_block")
def driver_block_get_solver(): def driver_block_get_solver(name=None):
solver_type = type_stationarylinearproblemsolver() solver_type = type_stationarylinearproblemsolver()
name = name_stationarylinearproblemsolver() if not name:
name = name_stationarylinearproblemsolver()
return ["std::shared_ptr<{}> get_solver(){{".format(solver_type), return ["std::shared_ptr<{}> get_solver(){{".format(solver_type),
" return {};".format(name), " return {};".format(name),
"}"] "}"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment