Skip to content
Snippets Groups Projects
Commit fddeee02 authored by René Heß's avatar René Heß
Browse files

Fix some bugs

parent 271f0dbc
No related branches found
No related tags found
No related merge requests found
......@@ -15,8 +15,7 @@ from dune.codegen.pdelab.driver.gridfunctionspace import (main_type_trial_gfs,
main_type_range,
main_type_subgfs,
)
from dune.codegen.pdelab.driver.interpolate import (interpolate_vector,
main_name_grid_function,
from dune.codegen.pdelab.driver.interpolate import (main_name_grid_function,
main_type_grid_function,
)
from dune.codegen.pdelab.driver.solve import (define_vector,
......@@ -58,7 +57,7 @@ def define_discrete_grid_function(gfs, vector_name, dgf_name, treepath):
if len(treepath) == 0:
gfs = '*' + gfs
return ["using {} = Dune::PDELab::DiscreteGridFunction<{}, {}>;".format(dgf_type, gfs_type, vector_type),
"{} {}({},*{});".format(dgf_type, dgf_name, gfs, vector_name)]
"{} {}({}, *{});".format(dgf_type, dgf_name, gfs, vector_name)]
def name_discrete_grid_function(gfs, vector_name, treepath):
......
......@@ -61,7 +61,7 @@ def main_type_range():
def typedef_grid(name):
dim = get_dimension()
if isQuadrilateral(get_trial_element().cell()):
range_type = type_range()
range_type = main_type_range()
if get_option("grid_unstructured"):
gridt = "Dune::UGGrid<{}>".format(dim)
include_file("dune/grid/uggrid.hh", filetag="driver")
......
......@@ -78,7 +78,7 @@ def time_loop():
# Choose between explicit and implicit time stepping
explicit = get_option('explicit_time_stepping')
if explicit:
osm = main_name_explicitonestepmethod()
osm = main_name_onestepmethod(is_implicit=False)
apply_call = "{}->apply(time, dt, *{}, {}new);".format(osm, vector, vector)
else:
osm = main_name_onestepmethod()
......@@ -288,7 +288,7 @@ def define_explicitonestepmethod(name):
tsm = name_timesteppingmethod()
igo = name_instationarygridoperator()
ls = name_linearsolver()
return "{} = std::make_shared<{}>({}, {}, {});".format(name, eosm_type, tsm, igo, ls)
return "{} = std::make_shared<{}>(*{}, *{}, *{});".format(name, eosm_type, tsm, igo, ls)
def name_explicitonestepmethod():
......@@ -310,7 +310,7 @@ def driver_block_get_onestepmethod(is_implicit=True, name=None):
else:
if not name:
name = name_explicitonestepmethod
name = name_explicitonestepmethod()
osm_type = type_explicitonestepmethod()
method_name = "getOneStepMethod"
return ["std::shared_ptr<{}> {}(){{".format(osm_type, method_name),
......
......@@ -16,7 +16,8 @@ from dune.codegen.pdelab.driver import (get_form_ident,
from dune.codegen.pdelab.driver.driverblock import (name_driver_block,
type_driver_block,
)
from dune.codegen.pdelab.driver.gridfunctionspace import (name_trial_gfs,
from dune.codegen.pdelab.driver.gridfunctionspace import (main_name_trial_gfs,
name_trial_gfs,
name_leafview,
type_domainfield,
type_leafview,
......@@ -61,7 +62,7 @@ def dune_solve():
elif not linear and not matrix_free:
go_type = type_gridoperator(form_ident)
go = name_gridoperator(form_ident)
snp = name_stationarynonlinearproblemsolver(go_type, go)
snp = main_name_stationarynonlinearproblemsolver(go_type, go)
solve = "{}->apply();".format(snp)
if get_form_option("generate_residuals"):
......@@ -311,6 +312,19 @@ def driver_block_get_nonlinear_solver(go_type, go, name=None):
"}"]
@preamble(section="solver", kernel="main")
def main_define_stationarynonlinearproblemsolver(name, go_type, go):
driver_block_name = name_driver_block()
driver_block_get_nonlinear_solver(go_type, go)
return "auto {} = {}.getSolver();".format(name, driver_block_name)
def main_name_stationarynonlinearproblemsolver(go_type, go):
name = "solver"
main_define_stationarynonlinearproblemsolver(name, go_type, go)
return name
def random_input(v):
include_file("random", system=True, filetag="driver")
return [" // Setup random input",
......@@ -324,7 +338,7 @@ def random_input(v):
def interpolate_input(v):
dim = world_dimension()
gv = name_leafview()
gfs = name_trial_gfs()
gfs = main_name_trial_gfs()
expr = []
for i in range(dim):
expr.append("x[{}]*x[{}]".format(i, i))
......@@ -334,7 +348,7 @@ def interpolate_input(v):
" return std::exp({});".format(expr),
" };",
" auto interpolate = Dune::PDELab::makeGridFunctionFromCallable({}, interpolate_lambda);".format(gv),
" Dune::PDELab::interpolate(interpolate,{},*{});".format(gfs, v),
" Dune::PDELab::interpolate(interpolate, *{}, *{});".format(gfs, v),
]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment