Skip to content
Snippets Groups Projects
Commit 4f1eaa30 authored by René Heß's avatar René Heß
Browse files

[skip ci] Close to working poisson example

parent 09f92c13
No related branches found
No related tags found
No related merge requests found
......@@ -196,7 +196,7 @@ def name_inifile():
return "argv[1]"
@preamble(section="init")
@preamble(section="init", kernel="main")
def parse_initree(varname):
include_file("dune/common/parametertree.hh", filetag="driver")
include_file("dune/common/parametertreeparser.hh", filetag="driver")
......@@ -210,7 +210,7 @@ def name_initree():
return "initree"
@preamble(section="init")
@preamble(section="init", kernel="main")
def define_mpihelper(name):
include_file("dune/common/parallel/mpihelper.hh", filetag="driver")
if get_option("with_mpi"):
......@@ -225,7 +225,7 @@ def name_mpihelper():
return name
@preamble(section="grid")
@preamble(section="grid", kernel="main")
def check_parallel_execution():
from dune.codegen.pdelab.driver.gridfunctionspace import name_leafview
gv = name_leafview()
......@@ -294,7 +294,7 @@ def generate_driver():
'']
def add_section(tag, comment):
tagcontents = [i for i in retrieve_cache_items("preamble and {}".format(tag), make_generable=True)]
tagcontents = [i for i in retrieve_cache_items("preamble and main and {}".format(tag), make_generable=True)]
if tagcontents:
contents.append(LineComment(comment))
contents.append(Line("\n"))
......@@ -315,6 +315,7 @@ def generate_driver():
add_section("vector", "Set up solution vectors...")
add_section("timings", "Maybe take performance measurements...")
add_section("solver", "Set up (non)linear solvers...")
add_section("postprocessing", "Preparations for postprocessing (eg output, error,...)")
add_section("vtk", "Do visualization...")
add_section("instat", "Set up instationary stuff...")
add_section("printing", "Maybe print residuals and matrices to stdout...")
......@@ -346,10 +347,14 @@ def generate_driver():
driver_body = Block([TryCatchBlock(driver_body, catch_blocks)])
driver = FunctionBody(driver_signature, driver_body)
# Generate driver block
from dune.codegen.pdelab.localoperator import cgen_class_from_cache
driver_block = cgen_class_from_cache("driver_block")
filename = get_option("driver_file")
from dune.codegen.file import generate_file
generate_file(filename, "driver", [driver], headerguard=False)
generate_file(filename, "driver", [driver_block, driver], headerguard=False)
# Reset the caching data structure
from dune.codegen.generation import delete_cache_items
......
from dune.codegen.generation import (get_counted_variable,
from dune.codegen.generation import (class_member,
get_counted_variable,
global_context,
include_file,
preamble,
......@@ -35,21 +36,19 @@ def has_dirichlet_constraints(is_dirichlet):
return bool(is_dirichlet)
@preamble(section="constraints")
@preamble(section="constraints", kernel="driver_block")
def assemble_constraints(name):
element = get_trial_element()
gfs = name_trial_gfs()
is_dirichlet = preprocess_leaf_data(element, "is_dirichlet")
if has_dirichlet_constraints(is_dirichlet):
bctype_function = name_bctype_function(element, is_dirichlet)
return "Dune::PDELab::constraints({}, {}, {});".format(bctype_function,
gfs,
name,
)
return "Dune::PDELab::constraints({}, *{}, *{});".format(bctype_function,
gfs,
name,)
else:
return "Dune::PDELab::constraints({}, {});".format(gfs,
name,
)
return "Dune::PDELab::constraints(*{}, *{});".format(gfs,
name,)
def name_bctype_function(element, is_dirichlet):
......@@ -73,7 +72,7 @@ def name_bctype_function(element, is_dirichlet):
return name
@preamble(section="constraints")
@preamble(section="constraints", kernel="driver_block")
def define_bctype_function(element, is_dirichlet, name):
gv = name_leafview()
bctype_lambda = name_bctype_lambda(name, is_dirichlet)
......@@ -84,7 +83,7 @@ def define_bctype_function(element, is_dirichlet, name):
)
@preamble(section="constraints")
@preamble(section="constraints", kernel="driver_block")
def define_composite_bctype_function(element, is_dirichlet, name, subgfs):
include_file('dune/pdelab/constraints/common/constraintsparameters.hh', filetag='driver')
return "Dune::PDELab::CompositeConstraintsParameters<{}> {}({});".format(', '.join('decltype({})'.format(c) for c in subgfs),
......@@ -99,7 +98,7 @@ def name_bctype_lambda(name, func):
return name
@preamble(section="constraints")
@preamble(section="constraints", kernel="driver_block")
def define_intersection_lambda(name, func):
from ufl.classes import Expr
if func is None:
......@@ -113,11 +112,11 @@ def define_intersection_lambda(name, func):
raise ValueError("Expression not understood")
@preamble(section="constraints")
@class_member(classtag="driver_block")
def typedef_constraintscontainer(name):
gfs = type_trial_gfs()
r = type_range()
return "using {} = {}::ConstraintsContainer<{}>::Type;".format(name, gfs, r)
return "using {} = typename {}::template ConstraintsContainer<{}>::Type;".format(name, gfs, r)
def type_constraintscontainer():
......@@ -126,10 +125,18 @@ def type_constraintscontainer():
return name
@preamble(section="constraints")
@class_member(classtag="driver_block")
def declare_constraintscontainer(name):
cctype = type_constraintscontainer()
return "std::shared_ptr<{}> {};".format(cctype, name)
@preamble(section="constraints", kernel="driver_block")
def define_constraintscontainer(name):
declare_constraintscontainer(name)
cctype = type_constraintscontainer()
return ["{} {};".format(cctype, name), "{}.clear();".format(name)]
return ["{} = std::make_shared<{}>();".format(name, cctype),
"{}->clear();".format(name)]
def name_constraintscontainer():
......
from dune.codegen.generation import (class_basename,
class_member,
constructor_parameter,
initializer_list,
preamble,
template_parameter,
)
from dune.codegen.pdelab.driver import (get_form_ident,
name_initree,
)
from dune.codegen.pdelab.driver.gridfunctionspace import (name_leafview,
type_leafview,
)
@template_parameter(classtag="driver_block")
def driver_block_template_parameter():
return type_leafview()
@class_member(classtag="driver_block")
def driver_block_grid_view():
return "{} {};".format(type_leafview(), name_leafview())
def init_driver_block():
driver_block_template_parameter()
gridview_argument = "_{}".format(name_leafview())
constructor_parameter("{}&".format(type_leafview()), gridview_argument, classtag="driver_block")
constructor_parameter("Dune::ParameterTree", name_initree(), classtag="driver_block")
initializer_list(name_leafview(), [gridview_argument], classtag="driver_block")
driver_block_grid_view()
@class_basename(classtag="driver_block")
def driver_block_basename(name):
init_driver_block()
return "DriverBlock{}".format(name.upper())
def type_driver_block():
# palpo TODO: driver_ident!
form_ident = get_form_ident()
return "{}<{}>".format(driver_block_basename(form_ident), type_leafview())
@preamble(section="solver", kernel="main")
def define_driver_block(name):
driver_block_type = type_driver_block()
return "{} {}({}, {});".format(driver_block_type, name, name_leafview(), name_initree())
def name_driver_block():
# palpo TODO: driver_ident!
form_ident = get_form_ident()
name = "driverBlock{}".format(form_ident.capitalize())
define_driver_block(name)
return name
......@@ -24,7 +24,7 @@ from dune.codegen.pdelab.driver.solve import (define_vector,
from ufl import MixedElement, TensorElement, VectorElement
@preamble(section="error")
@preamble(section="error", kernel="main")
def define_test_fail_variable(name):
return 'bool {}(false);'.format(name)
......@@ -49,7 +49,7 @@ def type_discrete_grid_function(gfs):
return "{}_DGF".format(gfs.upper())
@preamble(section="error")
@preamble(section="error", kernel="main")
def define_discrete_grid_function(gfs, vector_name, dgf_name):
dgf_type = type_discrete_grid_function(gfs)
return ["using {} = Dune::PDELab::DiscreteGridFunction<decltype({}),decltype({})>;".format(dgf_type, gfs, vector_name),
......@@ -62,7 +62,7 @@ def name_discrete_grid_function(gfs, vector_name):
return dgf_name
@preamble(section="error")
@preamble(section="error", kernel="main")
def typedef_difference_squared_adapter(name, treepath):
sol = name_exact_solution_gridfunction(treepath)
vector = name_vector(get_form_ident())
......@@ -78,7 +78,7 @@ def type_difference_squared_adapter(treepath):
return name
@preamble(section="error")
@preamble(section="error", kernel="main")
def define_difference_squared_adapter(name, treepath):
t = type_difference_squared_adapter(treepath)
sol = name_exact_solution_gridfunction(treepath)
......@@ -95,7 +95,7 @@ def name_difference_squared_adapter(treepath):
return name
@preamble(section="error")
@preamble(section="error", kernel="main")
def _accumulate_L2_squared(treepath):
dsa = name_difference_squared_adapter(treepath)
accum_error = name_accumulated_L2_error()
......@@ -158,7 +158,7 @@ def accumulate_L2_squared():
_accumulate_L2_squared(())
@preamble(section="error")
@preamble(section="error", kernel="main")
def define_accumulated_L2_error(name):
t = type_range()
return "Dune::FieldVector<{}, 1> {}(0.0);".format(t, name)
......@@ -170,7 +170,7 @@ def name_accumulated_L2_error():
return name
@preamble(section="error")
@preamble(section="error", kernel="main")
def compare_L2_squared():
accumulate_L2_squared()
gv = name_leafview()
......@@ -186,7 +186,7 @@ def compare_L2_squared():
" {} = true;".format(fail)]
@preamble(section="return_stmt")
@preamble(section="return_stmt", kernel="main")
def return_statement():
fail = name_test_fail_variable()
return "return {};".format(fail)
from dune.codegen.error import CodegenUnsupportedFiniteElementError
from dune.codegen.generation import (include_file,
from dune.codegen.generation import (class_basename,
class_member,
include_file,
preamble,
)
from dune.codegen.options import (get_form_option,
......@@ -20,10 +22,10 @@ from dune.codegen.loopy.target import type_floatingpoint
from ufl import FiniteElement, MixedElement, TensorElement, VectorElement, TensorProductElement, TensorProductCell
@preamble(section="grid")
@class_member(classtag="driver_block")
def typedef_domainfield(name):
gridt = type_grid()
return "using {} = {}::ctype;".format(name, gridt)
gridt = type_leafview()
return "using {} = typename {}::ctype;".format(name, gridt)
def type_domainfield():
......@@ -31,7 +33,7 @@ def type_domainfield():
return "DF"
@preamble(section="init")
@class_member(classtag="driver_block")
def typedef_range(name):
return "using {} = {};".format(name, type_floatingpoint())
......@@ -42,7 +44,7 @@ def type_range():
return name
@preamble(section="grid")
@preamble(section="grid", kernel="main")
def typedef_grid(name):
dim = get_dimension()
if isQuadrilateral(get_trial_element().cell()):
......@@ -72,7 +74,7 @@ def type_grid():
return name
@preamble(section="grid")
@preamble(section="grid", kernel="main")
def define_grid(name):
include_file("dune/testtools/gridconstruction.hh", filetag="driver")
ini = name_initree()
......@@ -95,7 +97,7 @@ def name_grid():
return name
@preamble(section="grid")
@preamble(section="grid", kernel="main")
def typedef_leafview(name):
grid = type_grid()
return "using {} = {}::LeafGridView;".format(name, grid)
......@@ -107,7 +109,7 @@ def type_leafview():
return name
@preamble(section="grid")
@preamble(section="grid", kernel="main")
def define_leafview(name):
_type = type_leafview()
grid = name_grid()
......@@ -128,7 +130,7 @@ def get_short_name(element):
return element._short_name
@preamble(section="fem")
@class_member(classtag="driver_block")
def typedef_fem(element, name):
gv = type_leafview()
df = type_domainfield()
......@@ -220,16 +222,23 @@ def type_fem(element):
return name
@preamble(section="fem")
@class_member(classtag="driver_block")
def declare_fem(element, name):
fem_type = type_fem(element)
return "std::shared_ptr<{}> {};".format(fem_type, name)
@preamble(section="fem", kernel="driver_block")
def define_fem(element, name):
declare_fem(element, name)
femtype = type_fem(element)
# Determine whether the FEM is grid-dependent - currently on the Lagrange elements are
if get_short_name(element) == "CG":
gv = name_leafview()
return "{} {}({});".format(femtype, name, gv)
return "{} = std::make_shared<{}>({});".format(name, femtype, gv)
else:
return "{} {};".format(femtype, name)
return "{} = std::make_shared<{}>();".format(name, femtype)
def name_fem(element):
......@@ -277,6 +286,7 @@ def name_gfs(element, is_dirichlet, treepath=(), root=True):
name = "{}{}_gfs_{}".format(FEM_name_mangling(element).lower(),
"_dirichlet" if is_dirichlet[0] else "",
"_".join(str(t) for t in treepath))
define_gfs(element, is_dirichlet, name, root)
return name
......@@ -320,16 +330,23 @@ def type_gfs(element, is_dirichlet, root=True):
return name
@preamble(section="gfs")
@class_member(classtag="driver_block")
def declare_gfs(element, is_dirichlet, name, root):
gfstype = type_gfs(element, is_dirichlet, root=root)
return "std::shared_ptr<{}> {};".format(gfstype, name)
@preamble(section="gfs", kernel="driver_block")
def define_gfs(element, is_dirichlet, name, root):
declare_gfs(element, is_dirichlet, name, root)
gfstype = type_gfs(element, is_dirichlet, root=root)
gv = name_leafview()
fem = name_fem(element)
return ["{} {}({}, {});".format(gfstype, name, gv, fem),
"{}.name(\"{}\");".format(name, name)]
return ["{} = std::make_shared<{}>({}, *{});".format(name, gfstype, gv, fem),
"{}->name(\"{}\");".format(name, name)]
@preamble(section="gfs")
@preamble(section="gfs", kernel="driver_block")
def define_power_gfs(element, is_dirichlet, name, subgfs, root):
gfstype = type_gfs(element, is_dirichlet, root=root)
names = ["using namespace Dune::Indices;"]
......@@ -337,14 +354,14 @@ def define_power_gfs(element, is_dirichlet, name, subgfs, root):
return ["{} {}({});".format(gfstype, name, subgfs)] + names
@preamble(section="gfs")
@preamble(section="gfs", kernel="driver_block")
def define_composite_gfs(element, is_dirichlet, name, subgfs, root):
gfstype = type_gfs(element, is_dirichlet, root=root)
return ["{} {}({});".format(gfstype, name, ", ".join(subgfs)),
"{}.update();".format(name)]
@preamble(section="gfs")
@class_member(classtag="driver_block")
def typedef_gfs(element, is_dirichlet, name, root):
vb = type_vectorbackend(element, root)
gv = type_leafview()
......@@ -354,7 +371,7 @@ def typedef_gfs(element, is_dirichlet, name, root):
return "using {} = Dune::PDELab::GridFunctionSpace<{}, {}, {}, {}>;".format(name, gv, fem, cass, vb)
@preamble(section="gfs")
@class_member(classtag="driver_block")
def typedef_power_gfs(element, is_dirichlet, name, subgfs, root):
include_file("dune/pdelab/gridfunctionspace/powergridfunctionspace.hh", filetag="driver")
vb = type_vectorbackend(element, root)
......@@ -363,7 +380,7 @@ def typedef_power_gfs(element, is_dirichlet, name, subgfs, root):
return "using {} = Dune::PDELab::PowerGridFunctionSpace<{}, {}, {}, {}>;".format(name, subgfs, element.num_sub_elements(), vb, ot)
@preamble(section="gfs")
@class_member(classtag="driver_block")
def typedef_composite_gfs(element, name, subgfs, root):
vb = type_vectorbackend(element, root)
ot = type_orderingtag(isinstance(element, FiniteElement))
......@@ -371,7 +388,7 @@ def typedef_composite_gfs(element, name, subgfs, root):
return "using {} = Dune::PDELab::CompositeGridFunctionSpace<{}, {}, {}>;".format(name, vb, ot, args)
@preamble(section="gfs")
@class_member(classtag="driver_block")
def typedef_vectorbackend(name, element, root):
include_file("dune/pdelab/backend/istl.hh", filetag="driver")
if get_form_option("fastdg") and root:
......@@ -401,25 +418,25 @@ def type_orderingtag(leaf):
return "Dune::PDELab::EntityBlockedOrderingTag"
@preamble(section="gfs")
@class_member(classtag="driver_block")
def typedef_overlapping_dirichlet_constraintsassembler(name):
include_file("dune/pdelab/constraints/conforming.hh", filetag="driver")
return "using {} = Dune::PDELab::ConformingDirichletConstraints;".format(name)
@preamble(section="gfs")
@class_member(classtag="driver_block")
def typedef_p0parallel_constraintsassembler(name):
include_file("dune/pdelab/constraints/p0.hh", filetag="driver")
return "using {} = Dune::PDELab::P0ParallelConstraints;".format(name)
@preamble(section="gfs")
@class_member(classtag="driver_block")
def typedef_dirichlet_constraintsassembler(name):
include_file("dune/pdelab/constraints/conforming.hh", filetag="driver")
return "using {} = Dune::PDELab::ConformingDirichletConstraints;".format(name)
@preamble(section="gfs")
@class_member(classtag="driver_block")
def typedef_no_constraintsassembler(name):
return "using {} = Dune::PDELab::NoConstraints;".format(name)
......@@ -457,7 +474,7 @@ def name_subgfs(treepath):
return name
@preamble(section="vtk")
@preamble(section="vtk", kernel="driver_block")
def define_subgfs(name, treepath):
t = type_subgfs(treepath)
gfs = name_trial_gfs()
......
from dune.codegen.generation import (get_global_context_value,
from dune.codegen.generation import (class_member,
get_global_context_value,
include_file,
preamble,
)
......@@ -12,6 +13,9 @@ from dune.codegen.pdelab.driver import (get_cell,
from dune.codegen.pdelab.driver.constraints import (name_assembled_constraints,
type_constraintscontainer,
)
from dune.codegen.pdelab.driver.driverblock import (name_driver_block,
type_driver_block,
)
from dune.codegen.pdelab.driver.gridfunctionspace import (name_test_gfs,
name_trial_gfs,
preprocess_leaf_data,
......@@ -24,7 +28,7 @@ from dune.codegen.pdelab.localoperator import localoperator_basename
from dune.codegen.options import get_form_option
@preamble(section="gridoperator")
@class_member(classtag="driver_block")
def typedef_gridoperator(name, form_ident):
ugfs = type_trial_gfs()
vgfs = type_test_gfs()
......@@ -49,8 +53,15 @@ def type_gridoperator(form_ident):
return name
@preamble(section="gridoperator")
@class_member(classtag="driver_block")
def declare_gridoperator(name, form_ident):
gotype = type_gridoperator(form_ident)
return "std::shared_ptr<{}> {};".format(gotype, name)
@preamble(section="gridoperator", kernel="driver_block")
def define_gridoperator(name, form_ident):
declare_gridoperator(name, form_ident)
gotype = type_gridoperator(form_ident)
ugfs = name_trial_gfs()
vgfs = name_test_gfs()
......@@ -59,9 +70,9 @@ def define_gridoperator(name, form_ident):
cc = name_assembled_constraints()
lop = name_localoperator(form_ident)
mb = name_matrixbackend()
return ["{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, cc, vgfs, cc, lop, mb),
"std::cout << \"gfs with \" << {}.size() << \" dofs generated \"<< std::endl;".format(ugfs),
"std::cout << \"cc with \" << {}.size() << \" dofs generated \"<< std::endl;".format(cc)]
return ["{} = std::make_shared<{}>(*{}, *{}, *{}, *{}, *{}, *{});".format(name, gotype, ugfs, cc, vgfs, cc, lop, mb),
"std::cout << \"gfs with \" << {}->size() << \" dofs generated \"<< std::endl;".format(ugfs),
"std::cout << \"cc with \" << {}->size() << \" dofs generated \"<< std::endl;".format(cc)]
def name_gridoperator(form_ident):
......@@ -70,7 +81,43 @@ def name_gridoperator(form_ident):
return name
@preamble(section="gridoperator")
@class_member(classtag="driver_block")
def driver_block_get_gridoperator(form_ident):
gridoperator_type = type_gridoperator(form_ident)
name = name_gridoperator(form_ident)
return ["std::shared_ptr<{}> get_gridoperator(){{".format(gridoperator_type),
" return {};".format(name),
"}"]
@preamble(section="postprocessing", kernel="main")
def main_define_gridoperator(name, form_ident):
driver_block_name = name_driver_block()
driver_block_get_gridoperator(form_ident)
return "auto {} = {}.get_gridoperator();".format(name, driver_block_name)
@preamble(section="postprocessing", kernel="main")
def main_typedef_gridoperator(name, form_ident):
driver_block_type = type_driver_block()
db_gridoperator_type = type_gridoperator(form_ident)
gridoperator_type = "using {} = {}::{};".format(name, driver_block_type, db_gridoperator_type)
return gridoperator_type
def main_type_gridoperator(form_ident):
name = "GO_{}".format(form_ident)
main_typedef_gridoperator(name, form_ident)
return name
def main_name_gridoperator(form_ident):
name = "go_{}".format(form_ident)
main_define_gridoperator(name, form_ident)
return name
@class_member(classtag="driver_block")
def typedef_localoperator(name, form_ident):
ugfs = type_trial_gfs()
vgfs = type_test_gfs()
......@@ -86,13 +133,20 @@ def type_localoperator(form_ident):
return name
@preamble(section="gridoperator")
@class_member(classtag="driver_block")
def declare_localoperator(name, form_ident):
loptype = type_localoperator(form_ident)
return "std::shared_ptr<{}> {};".format(loptype, name)
@preamble(section="gridoperator", kernel="driver_block")
def define_localoperator(name, form_ident):
declare_localoperator(name, form_ident)
trial_gfs = name_trial_gfs()
test_gfs = name_test_gfs()
loptype = type_localoperator(form_ident)
ini = name_initree()
return "{} {}({}, {}, {});".format(loptype, name, trial_gfs, test_gfs, ini)
return "{} = std::make_shared<{}>(*{}, *{}, {});".format(name, loptype, trial_gfs, test_gfs, ini)
def name_localoperator(form_ident):
......@@ -101,7 +155,7 @@ def name_localoperator(form_ident):
return name
@preamble(section="gridoperator")
@preamble(section="gridoperator", kernel="driver_block")
def define_dofestimate(name):
# Provide a worstcase estimate for the number of entries per row based
# on the given gridfunction space and cell geometry
......@@ -117,8 +171,8 @@ def define_dofestimate(name):
gfs = name_trial_gfs()
ini = name_initree()
return ["{}.update();".format(gfs),
"int generic_dof_estimate = {} * {}.maxLocalSize();".format(geo_factor, gfs),
return ["{}->update();".format(gfs),
"int generic_dof_estimate = {} * {}->maxLocalSize();".format(geo_factor, gfs),
"int {} = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(name, ini)]
......@@ -128,7 +182,7 @@ def name_dofestimate():
return name
@preamble(section="gridoperator")
@class_member(classtag="driver_block")
def typedef_matrixbackend(name):
include_file("dune/pdelab/backend/istl.hh", filetag="driver")
return "using {} = Dune::PDELab::ISTL::BCRSMatrixBackend<>;".format(name)
......@@ -140,11 +194,18 @@ def type_matrixbackend():
return name
@preamble(section="gridoperator")
@class_member(classtag="driver_block")
def declare_matrixbackend(name):
mbtype = type_matrixbackend()
return "std::shared_ptr<{}> {};".format(mbtype, name)
@preamble(section="gridoperator", kernel="driver_block")
def define_matrixbackend(name):
declare_matrixbackend(name)
mbtype = type_matrixbackend()
dof = name_dofestimate()
return "{} {}({});".format(mbtype, name, dof)
return "{} = std::make_shared<{}>({});".format(name, mbtype, dof)
def name_matrixbackend():
......
......@@ -128,7 +128,7 @@ def name_time():
return "time"
@preamble(section="instat")
@class_member(classtag="driver_block")
def typedef_timesteppingmethod(name):
r_type = type_range()
explicit = get_option('explicit_time_stepping')
......@@ -178,7 +178,7 @@ def name_timesteppingmethod():
return "tsm"
@preamble(section="gridoperator")
@class_member(classtag="driver_block")
def typedef_instationarygridoperator(name):
include_file("dune/pdelab/gridoperator/onestep.hh", filetag="driver")
go_type = type_gridoperator(get_form_ident())
......@@ -208,7 +208,7 @@ def name_instationarygridoperator():
return "igo"
@preamble(section="instat")
@class_member(classtag="driver_block")
def typedef_onestepmethod(name):
r_type = type_range()
igo_type = type_instationarygridoperator()
......@@ -237,7 +237,7 @@ def name_onestepmethod():
return "osm"
@preamble(section="instat")
@class_member(classtag="driver_block")
def typedef_explicitonestepmethod(name):
r_type = type_range()
igo_type = type_instationarygridoperator()
......
......@@ -25,11 +25,11 @@ def interpolate_dirichlet_data(name):
interpolate_vector(bf, gfs, name)
@preamble(section="vector")
@preamble(section="vector", kernel="driver_block")
def interpolate_vector(func, gfs, name):
return "Dune::PDELab::interpolate({}, {}, {});".format(func,
gfs,
name,
*gfs,
*name,
)
......@@ -53,14 +53,14 @@ def name_boundary_function(element, func):
return name
@preamble(section="vector")
@preamble(section="vector", kernel="driver_block")
def define_compositegfs_parameterfunction(name, children):
return "Dune::PDELab::CompositeGridFunction<{}> {}({});".format(', '.join('decltype({})'.format(c) for c in children),
name,
', '.join(children))
@preamble(section="vector")
@preamble(section="vector", kernel="driver_block")
def define_boundary_function(name, dirichlet):
gv = name_leafview()
lambdaname = name_boundary_lambda(dirichlet)
......@@ -87,7 +87,7 @@ def name_boundary_lambda(boundary):
return name
@preamble(section="vector")
@preamble(section="vector", kernel="driver_block")
def define_boundary_lambda(name, boundary):
if boundary is None:
boundary = 0.0
......
from dune.codegen.generation import (include_file,
from dune.codegen.generation import (class_basename,
class_member,
constructor_parameter,
include_file,
initializer_list,
preamble,
template_parameter
)
from dune.codegen.options import (get_form_option,
get_option,
......@@ -8,24 +13,31 @@ from dune.codegen.pdelab.driver import (get_form_ident,
is_linear,
name_initree,
)
from dune.codegen.pdelab.driver.driverblock import (name_driver_block,
type_driver_block,
)
from dune.codegen.pdelab.driver.gridfunctionspace import (name_trial_gfs,
name_leafview,
type_domainfield,
type_leafview,
type_trial_gfs,
)
from dune.codegen.pdelab.driver.constraints import (type_constraintscontainer,
name_assembled_constraints,
)
from dune.codegen.pdelab.driver.gridoperator import (name_gridoperator,
from dune.codegen.pdelab.driver.gridoperator import (main_name_gridoperator,
main_type_gridoperator,
name_gridoperator,
type_gridoperator,
)
from dune.codegen.pdelab.driver.interpolate import interpolate_dirichlet_data
from dune.codegen.pdelab.geometry import world_dimension
@preamble(section="solver")
@preamble(section="solver", kernel="main")
def dune_solve():
form_ident = get_form_ident()
# Test if form is linear in ansatzfunction
linear = is_linear()
......@@ -38,8 +50,8 @@ def dune_solve():
include_file("dune/codegen/matrixfree.hh", filetag="driver")
solve = "solveMatrixFree({},{});".format(go, x)
elif linear and not matrix_free:
slp = name_stationarylinearproblemsolver()
solve = "{}.apply();".format(slp)
slp = main_name_stationarylinearproblemsolver()
solve = "{}->apply();".format(slp)
elif not linear and matrix_free:
# TODO copy of linear case and obviously broken, used to generate something ;)
go = name_gridoperator(form_ident)
......@@ -63,14 +75,7 @@ def dune_solve():
return solve
def name_vector(form_ident):
name = "x_{}".format(form_ident)
define_vector(name, form_ident)
interpolate_dirichlet_data(name)
return name
@preamble(section="vector")
@class_member(classtag="driver_block")
def typedef_vector(name, form_ident):
gfs = type_trial_gfs()
df = type_domainfield()
......@@ -83,14 +88,65 @@ def type_vector(form_ident):
return name
@preamble(section="vector")
@class_member(classtag="driver_block")
def declare_vector(name, form_ident):
vtype = type_vector(form_ident)
return "std::shared_ptr<{}> {};".format(vtype, name)
@preamble(section="vector", kernel="driver_block")
def define_vector(name, form_ident):
declare_vector(name, form_ident)
vtype = type_vector(form_ident)
gfs = name_trial_gfs()
return ["{} {}({});".format(vtype, name, gfs), "{} = 0.0;".format(name)]
return ["{} = std::make_shared<{}>(*{});".format(name, vtype, gfs), "*{} = 0.0;".format(name)]
def name_vector(form_ident):
name = "x_{}".format(form_ident)
define_vector(name, form_ident)
interpolate_dirichlet_data(name)
return name
@preamble(section="postprocessing", kernel="main")
def main_typedef_vector(name, form_ident):
driver_block_type = type_driver_block()
db_vector_type = type_vector(form_ident)
vector_type = "using {} = {}::{};".format(name, driver_block_type, db_vector_type)
return vector_type
def main_type_vector(form_ident):
name = "V_{}".format(form_ident.upper())
main_typedef_vector(name, form_ident)
return name
@class_member(classtag="driver_block")
def driver_block_get_coefficient(form_ident):
vector_type = type_vector(form_ident)
name = name_vector(form_ident)
return ["std::shared_ptr<{}> get_coefficient(){{".format(vector_type),
" return {};".format(name),
"}"]
@preamble(section="postprocessing", kernel="main")
def main_define_vector(name, form_ident):
driver_block_name = name_driver_block()
driver_block_get_coefficient(form_ident)
return "auto {} = {}.get_coefficient();".format(name, driver_block_name)
@preamble(section="solver")
def main_name_vector(form_ident):
name = "x_{}".format(form_ident)
main_define_vector(name, form_ident)
interpolate_dirichlet_data(name)
return name
@class_member(classtag="driver_block")
def typedef_linearsolver(name):
include_file("dune/pdelab/backend/istl.hh", filetag="driver")
if get_option('overlapping'):
......@@ -107,15 +163,22 @@ def type_linearsolver():
return name
@preamble(section="solver")
@class_member(classtag="driver_block")
def declare_linearsolver(name):
lstype = type_linearsolver()
return "std::shared_ptr<{}> {};".format(lstype, name)
@preamble(section="solver", kernel="driver_block")
def define_linearsolver(name):
declare_linearsolver(name)
lstype = type_linearsolver()
if get_option('overlapping'):
gfs = name_trial_gfs()
cc = name_assembled_constraints()
return "{} {}({}, {});".format(lstype, name, gfs, cc)
return "{} = std::make_shared<{}>({}, {});".format(name, lstype, gfs, cc)
else:
return "{} {}(false);".format(lstype, name)
return "{} = std::make_shared<{}>(false);".format(name, lstype)
def name_linearsolver():
......@@ -124,7 +187,7 @@ def name_linearsolver():
return name
@preamble(section="solver")
@preamble(section="solver", kernel="driver_block")
def define_reduction(name):
ini = name_initree()
return "double {} = {}.get<double>(\"reduction\", 1e-12);".format(name, ini)
......@@ -136,7 +199,7 @@ def name_reduction():
return name
@preamble(section="solver")
@class_member(classtag="driver_block")
def typedef_stationarylinearproblemsolver(name):
include_file("dune/pdelab/stationary/linearproblem.hh", filetag="driver")
gotype = type_gridoperator(get_form_ident())
......@@ -150,22 +213,52 @@ def type_stationarylinearproblemsolver():
return "SLP"
@preamble(section="solver")
@class_member(classtag="driver_block")
def declare_stationarylinearproblemsolver(name):
slptype = type_stationarylinearproblemsolver()
return "std::shared_ptr<{}> {};".format(slptype, name)
@preamble(section="solver", kernel="driver_block")
def define_stationarylinearproblemsolver(name):
declare_stationarylinearproblemsolver(name)
slptype = type_stationarylinearproblemsolver()
go = name_gridoperator(get_form_ident())
ls = name_linearsolver()
x = name_vector(get_form_ident())
red = name_reduction()
return "{} {}({}, {}, {}, {});".format(slptype, name, go, ls, x, red)
return "{} = std::make_shared<{}>(*{}, *{}, *{}, {});".format(name, slptype, go, ls, x, red)
def name_stationarylinearproblemsolver():
define_stationarylinearproblemsolver("slp")
return "slp"
name = "slp"
define_stationarylinearproblemsolver(name)
return name
@class_member(classtag="driver_block")
def driver_block_get_solver():
solver_type = type_stationarylinearproblemsolver()
name = name_stationarylinearproblemsolver()
return ["std::shared_ptr<{}> get_solver(){{".format(solver_type),
" return {};".format(name),
"}"]
@preamble(section="solver", kernel="main")
def main_define_stationarylinearproblemsolver(name):
driver_block_name = name_driver_block()
driver_block_get_solver()
return "auto {} = {}.get_solver();".format(name, driver_block_name)
@preamble(section="solver")
def main_name_stationarylinearproblemsolver():
name = "slp"
main_define_stationarylinearproblemsolver(name)
return name
@class_member(classtag="driver_block")
def typedef_stationarynonlinearproblemsolver(name, go_type):
include_file("dune/pdelab/newton/newton.hh", filetag="driver")
ls_type = type_linearsolver()
......@@ -179,7 +272,7 @@ def type_stationarynonlinearproblemssolver(go_type):
return name
@preamble(section="solver")
@preamble(section="solver", kernel="driver_block")
def define_stationarynonlinearproblemsolver(name, go_type, go):
snptype = type_stationarynonlinearproblemssolver(go_type)
x = name_vector(get_form_ident())
......@@ -194,11 +287,12 @@ def name_stationarynonlinearproblemsolver(go_type, go):
def random_input(v):
include_file("random", system=True, filetag="driver")
return [" // Setup random input",
" std::size_t seed = 0;",
" auto rng = std::mt19937_64(seed);",
" auto dist = std::uniform_real_distribution<>(-1., 1.);",
" for (auto& v : {})".format(v),
" for (auto& v : *{})".format(v),
" v = dist(rng);"]
......@@ -215,17 +309,16 @@ def interpolate_input(v):
" return std::exp({});".format(expr),
" };",
" auto interpolate = Dune::PDELab::makeGridFunctionFromCallable({}, interpolate_lambda);".format(gv),
" Dune::PDELab::interpolate(interpolate,{},{});".format(gfs, v),
" Dune::PDELab::interpolate(interpolate,{},*{});".format(gfs, v),
]
@preamble(section="printing")
@preamble(section="printing", kernel="main")
def print_residual():
ini = name_initree()
n_go = name_gridoperator(get_form_ident())
v = name_vector(get_form_ident())
t_v = type_vector(get_form_ident())
include_file("random", system=True, filetag="driver")
n_go = main_name_gridoperator(get_form_ident())
v = main_name_vector(get_form_ident())
t_v = main_type_vector(get_form_ident())
if get_option("debug_interpolate_input"):
input = interpolate_input(v)
......@@ -234,20 +327,20 @@ def print_residual():
return ["if ({}.get<bool>(\"printresidual\", false)) {{".format(ini),
" using Dune::PDELab::Backend::native;",
" {} r({});".format(t_v, v)] + input + \
" {} r(*{});".format(t_v, v)] + input + \
[" r=0.0;",
" {}.residual({}, r);".format(n_go, v),
" {}->residual(*{}, r);".format(n_go, v),
" Dune::printvector(std::cout, native(r), \"residual vector\", \"row\");",
"}"]
@preamble(section="printing")
@preamble(section="printing", kernel="main")
def print_matrix():
ini = name_initree()
t_go = type_gridoperator(get_form_ident())
n_go = name_gridoperator(get_form_ident())
v = name_vector(get_form_ident())
t_v = type_vector(get_form_ident())
t_go = main_type_gridoperator(get_form_ident())
n_go = main_name_gridoperator(get_form_ident())
v = main_name_vector(get_form_ident())
t_v = main_type_vector(get_form_ident())
if get_option("debug_interpolate_input"):
input = interpolate_input(v)
......@@ -256,10 +349,10 @@ def print_matrix():
return ["if ({}.get<bool>(\"printmatrix\", false)) {{".format(ini),
" using Dune::PDELab::Backend::native;",
" {} r({});".format(t_v, v)] + input + \
" {} r(*{});".format(t_v, v)] + input + \
[" using M = typename {}::Traits::Jacobian;".format(t_go),
" M m({});".format(n_go),
" {}.jacobian({},m);".format(n_go, v),
" M m(*{});".format(n_go),
" {}->jacobian(*{},m);".format(n_go, v),
" using Dune::PDELab::Backend::native;",
" Dune::printmatrix(std::cout, native(m),\"global stiffness matrix\",\"row\",9,1);",
"}"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment