Skip to content
Snippets Groups Projects
Commit a517e118 authored by René Heß's avatar René Heß
Browse files

[skip ci] Towards driver block for Stokes

Working:
- Driver block
- Solving and visualization

Not working:
- Error calculation
parent 056ac183
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@ from dune.codegen.pdelab.driver import (FEM_name_mangling,
from dune.codegen.pdelab.driver.gridfunctionspace import (name_gfs,
name_leafview,
name_trial_gfs,
type_leafview,
type_range,
type_trial_gfs,
preprocess_leaf_data,
......@@ -42,16 +43,106 @@ def assemble_constraints(name):
gfs = name_trial_gfs()
is_dirichlet = preprocess_leaf_data(element, "is_dirichlet")
if has_dirichlet_constraints(is_dirichlet):
bctype_function = name_bctype_function(element, is_dirichlet)
return "Dune::PDELab::constraints({}, *{}, *{});".format(bctype_function,
gfs,
name,)
bctype_function = name_bctype_grid_function(element, is_dirichlet)
return "Dune::PDELab::constraints(*{}, *{}, *{});".format(bctype_function,
gfs,
name,)
else:
return "Dune::PDELab::constraints(*{}, *{});".format(gfs,
name,)
#
# Bctype lambda
#
def bctype_lambda(func, local=False):
from ufl.classes import Expr
if func is None:
func = 0.0
if isinstance(func, (int, float)):
return "[&](const auto& is, const auto& xl){{ return {}; }}".format(float(func))
elif isinstance(func, Expr):
from dune.codegen.pdelab.driver.visitor import ufl_to_code
return "[&](const auto& is, const auto& xl){{ {}; }}".format(ufl_to_code(func))
raise ValueError("Expression not understood")
#
# Bctype function
#
# std::function that will be used to create the bctype grid function
@class_member(classtag="driver_block")
def typedef_bctype_function(name):
leafview_type = type_leafview()
entity = "typename {}::Intersection".format(leafview_type)
coordinate = "typename {}::Intersection::LocalCoordinate".format(leafview_type)
return "using {} = std::function<bool({}, {})>;".format(name, entity, coordinate)
def type_bctype_function():
name = "BctypeFunction"
typedef_bctype_function(name)
return name
@class_member(classtag="driver_block")
def declare_bctype_function(name, element, is_dirichlet):
bctype_function_type = type_bctype_function()
return "std::shared_ptr<{}> {};".format(bctype_function_type, name)
@preamble(section="constriants", kernel="driver_block")
def define_bctype_function(name, element, is_dirichlet):
declare_bctype_function(name, element, is_dirichlet)
bctype_function_type = type_bctype_function()
bct_lambda = bctype_lambda(is_dirichlet)
return "{} = std::make_shared<{}> ({});".format(name, bctype_function_type, bct_lambda)
def name_bctype_function(element, is_dirichlet):
name = get_counted_variable("bctype_function")
define_bctype_function(name, element, is_dirichlet)
return name
#
# Bctype Grid Function
#
# GridFunction that returns 1 if the boundary is constraint (else 0)
@class_member(classtag="driver_block")
def typedef_bctype_grid_function(name):
bctype_function_type = type_bctype_function()
return "using {} = Dune::PDELab::LocalCallableToBoundaryConditionAdapter<{}>;".format(name,
bctype_function_type)
def type_bctype_grid_function():
name = "BctypeGridFunction"
typedef_bctype_grid_function(name)
return name
@class_member(classtag="driver_block")
def declare_bctype_grid_function(name):
bctype_type = type_bctype_grid_function()
return "std::shared_ptr<{}> {};".format(bctype_type, name)
@preamble(section="constraints", kernel="driver_block")
def define_bctype_grid_function(name, element, is_dirichlet):
declare_bctype_grid_function(name)
bctype_type = type_bctype_grid_function()
bctype_function = name_bctype_function(element, is_dirichlet)
return "{} = std::make_shared<{}> (*{});".format(name, bctype_type, bctype_function)
def name_bctype_grid_function(element, is_dirichlet):
# Note: Previously, there was a separate code branch for VectorElement here,
# which was implemented through PDELabs Power constraints concept.
# However this completely fails if you have different constraints for
......@@ -60,56 +151,64 @@ def name_bctype_function(element, is_dirichlet):
k = 0
childs = []
for subel in element.sub_elements():
childs.append(name_bctype_function(subel, is_dirichlet[k:k + subel.value_size()]))
childs.append(name_bctype_grid_function(subel, is_dirichlet[k:k + subel.value_size()]))
k = k + subel.value_size()
name = "{}_bctype".format("_".join(childs))
define_composite_bctype_function(element, is_dirichlet, name, tuple(childs))
name = "{}".format("_".join(childs))
define_composite_constraints_parameters(name, element, tuple(childs))
return name
else:
assert isinstance(element, (FiniteElement, TensorProductElement))
name = get_counted_variable("bctype")
define_bctype_function(element, is_dirichlet[0], name)
name = get_counted_variable("bctype_grid_function")
define_bctype_grid_function(name, element, is_dirichlet[0])
return name
@preamble(section="constraints", kernel="driver_block")
def define_bctype_function(element, is_dirichlet, name):
gv = name_leafview()
bctype_lambda = name_bctype_lambda(name, is_dirichlet)
include_file('dune/pdelab/function/callableadapter.hh', filetag='driver')
return "auto {} = Dune::PDELab::makeBoundaryConditionFromCallable({}, {});".format(name,
gv,
bctype_lambda,
)
#
# Composite constraints parameter
#
@preamble(section="constraints", kernel="driver_block")
def define_composite_bctype_function(element, is_dirichlet, name, subgfs):
include_file('dune/pdelab/constraints/common/constraintsparameters.hh', filetag='driver')
return "Dune::PDELab::CompositeConstraintsParameters<{}> {}({});".format(', '.join('decltype({})'.format(c) for c in subgfs),
name,
', '.join(c for c in subgfs)
)
@class_member(classtag="driver_block")
def typedef_composite_constraints_parameters(name, element, gfs_tuple):
assert isinstance(element, MixedElement)
assert len(element.sub_elements()) == len(gfs_tuple)
types = []
for subel, subgfs in zip(element.sub_elements(), gfs_tuple):
if isinstance(subel, MixedElement):
types.append(type_composite_constriants_parameters(subel, subgfs))
else:
assert isinstance(subel, (FiniteElement, TensorProductElement))
types.append(type_bctype_grid_function())
return "using {} = Dune::PDELab::CompositeConstraintsParameters<{}>;".format(name,
", ".join(t for t in types))
def type_composite_constriants_parameters(element, gfs_tuple):
if isinstance(gfs_tuple, str):
gfs_tuple = (gfs_tuple,)
name = "CompositeConstraintsParameters_{}".format('_'.join(c for c in gfs_tuple))
if len(element.sub_elements()) == len(gfs_tuple):
typedef_composite_constraints_parameters(name, element, gfs_tuple)
return name
def name_bctype_lambda(name, func):
name = name + "_lambda"
define_intersection_lambda(name, func)
return name
@class_member(classtag="driver_block")
def declare_composite_constraints_parameter(name, element, gfs_tuple):
ccp_type = type_composite_constriants_parameters(element, gfs_tuple)
return "std::shared_ptr<{}> {};".format(ccp_type, name)
@preamble(section="constraints", kernel="driver_block")
def define_intersection_lambda(name, func):
from ufl.classes import Expr
if func is None:
func = 0.
if isinstance(func, (int, float)):
return "auto {} = [&](const auto& x){{ return {}; }};".format(name, float(func))
elif isinstance(func, Expr):
from dune.codegen.pdelab.driver.visitor import ufl_to_code
return "auto {} = [&](const auto& is, const auto& xl){{ {} }};".format(name, ufl_to_code(func))
def define_composite_constraints_parameters(name, element, gfs_tuple):
include_file('dune/pdelab/constraints/common/constraintsparameters.hh', filetag='driver')
declare_composite_constraints_parameter(name, element, gfs_tuple)
ccp_type = type_composite_constriants_parameters(element, gfs_tuple)
return "{} = std::make_shared<{}>({});".format(name, ccp_type, ', '.join('*{}'.format(c) for c in gfs_tuple))
raise ValueError("Expression not understood")
#
# Constraint container
#
@class_member(classtag="driver_block")
......
......@@ -11,8 +11,9 @@ from dune.codegen.pdelab.driver import (get_form_ident,
)
from dune.codegen.pdelab.driver.gridfunctionspace import (main_type_trial_gfs,
name_leafview,
name_trial_subgfs,
main_name_trial_subgfs,
main_type_range,
main_type_subgfs,
)
from dune.codegen.pdelab.driver.interpolate import (interpolate_vector,
main_name_boundary_grid_function,
......@@ -40,31 +41,40 @@ def name_test_fail_variable():
def name_exact_solution_gridfunction(treepath):
element = get_trial_element()
func = preprocess_leaf_data(element, "exact_solution")
if isinstance(element, MixedElement):
index = treepath_to_index(element, treepath)
func = (func[index],)
element = element.extract_component(index)[1]
# if isinstance(element, MixedElement):
# index = treepath_to_index(element, treepath)
# func = (func[index],)
# element = element.extract_component(index)[1]
return main_name_boundary_grid_function(element, func)
def type_discrete_grid_function(gfs):
return "{}_DGF".format(gfs.upper())
def type_discrete_grid_function(treepath):
name = "DiscreteGridFunction_{}".format("_".join(str(t) for t in treepath))
return name
@preamble(section="error", kernel="main")
def define_discrete_grid_function(gfs, vector_name, dgf_name):
dgf_type = type_discrete_grid_function(gfs)
gfs_type = main_type_trial_gfs()
def define_discrete_grid_function(gfs, vector_name, dgf_name, treepath):
dgf_type = type_discrete_grid_function(treepath)
if len(treepath) == 0:
gfs_type = main_type_trial_gfs()
else:
gfs_type = main_type_subgfs(treepath)
form_ident = get_form_ident()
vector_type = main_type_vector(form_ident)
# If this is the root we get the gfs from the driver block as a
# pointer. This means we need to dereference it
if len(treepath) == 0:
gfs = '*' + gfs
return ["using {} = Dune::PDELab::DiscreteGridFunction<{}, {}>;".format(dgf_type, gfs_type, vector_type),
"{} {}(*{},*{});".format(dgf_type, dgf_name, gfs, vector_name)]
"{} {}({},*{});".format(dgf_type, dgf_name, gfs, vector_name)]
def name_discrete_grid_function(gfs, vector_name):
dgf_name = "{}_dgf".format(gfs)
define_discrete_grid_function(gfs, vector_name, dgf_name)
return dgf_name
def name_discrete_grid_function(gfs, vector_name, treepath):
name = "discreteGridFunction_{}".format("_".join(str(t) for t in treepath))
define_discrete_grid_function(gfs, vector_name, name, treepath)
return name
@preamble(section="error", kernel="main")
......@@ -75,11 +85,11 @@ def typedef_difference_squared_adapter(name, treepath):
index = treepath_to_index(element, treepath)
func = (func[index],)
element = element.extract_component(index)[1]
bgf_type = main_type_boundary_grid_function(func)
bgf_type = main_type_subgfs(treepath)
vector = main_name_vector(get_form_ident())
gfs = name_trial_subgfs(treepath)
dgf = name_discrete_grid_function(gfs, vector)
dgf_type = type_discrete_grid_function(gfs)
gfs = main_name_trial_subgfs(treepath)
dgf = name_discrete_grid_function(gfs, vector, treepath)
dgf_type = type_discrete_grid_function(treepath)
return 'using {} = Dune::PDELab::DifferenceSquaredAdapter<{}, {}>;'.format(name, bgf_type, dgf_type)
......@@ -94,8 +104,8 @@ def define_difference_squared_adapter(name, treepath):
t = type_difference_squared_adapter(treepath)
sol = name_exact_solution_gridfunction(treepath)
vector = main_name_vector(get_form_ident())
gfs = name_trial_subgfs(treepath)
dgf = name_discrete_grid_function(gfs, vector)
gfs = main_name_trial_subgfs(treepath)
dgf = name_discrete_grid_function(gfs, vector, treepath)
return '{} {}(*{}, {});'.format(t, name, sol, dgf)
......
......@@ -268,9 +268,11 @@ def name_trial_gfs():
def main_name_trial_gfs():
name = "gridFunctionSpace"
element = get_trial_element()
is_dirichlet = preprocess_leaf_data(element, "is_dirichlet")
return name_gfs(element, is_dirichlet, main=True)
main_define_gfs(name, element, is_dirichlet)
return name
def name_test_gfs():
......@@ -279,7 +281,7 @@ def name_test_gfs():
return name_gfs(element, is_dirichlet)
def name_gfs(element, is_dirichlet, treepath=(), root=True, main=False):
def name_gfs(element, is_dirichlet, treepath=(), root=True):
"""Generate name of grid function space
This function will call itself recursively to build the grid function space
......@@ -294,51 +296,37 @@ def name_gfs(element, is_dirichlet, treepath=(), root=True, main=False):
Treepath for the grid function space tree
root : bool
Called for the root of the tree?
main : bool
Can be called for driver block generation or for generation from the
main of the program. Instead of copying the code to a new method
main_name_gfs we use switches to avoid code duplication.
"""
if isinstance(element, (VectorElement, TensorElement)):
subel = element.sub_elements()[0]
subgfs = name_gfs(subel, is_dirichlet[:subel.value_size()], treepath=treepath + (0,), root=False, main=main)
subgfs = name_gfs(subel, is_dirichlet[:subel.value_size()], treepath=treepath + (0,), root=False)
name = "{}_pow{}gfs_{}".format(subgfs,
element.num_sub_elements(),
"_".join(str(t) for t in treepath))
if main:
# TODO
assert False
else:
define_power_gfs(element, is_dirichlet, name, subgfs, root)
return name
define_power_gfs(element, is_dirichlet, name, subgfs, root)
elif isinstance(element, MixedElement):
k = 0
subgfs = []
for i, subel in enumerate(element.sub_elements()):
subgfs.append(name_gfs(subel, is_dirichlet[k:k + subel.value_size()], treepath=treepath + (i,), root=False, main=main))
subgfs.append(name_gfs(subel, is_dirichlet[k:k + subel.value_size()], treepath=treepath + (i,), root=False))
k = k + subel.value_size()
name = "_".join(subgfs)
if len(subgfs) == 1:
name = "{}_dummy".format(name)
name = "{}_{}".format(name, "_".join(str(t) for t in treepath))
if main:
# TODO
assert False
else:
define_composite_gfs(element, is_dirichlet, name, tuple(subgfs), root)
return name
define_composite_gfs(element, is_dirichlet, name, tuple(subgfs), root)
else:
assert isinstance(element, (FiniteElement, TensorProductElement))
name = "{}{}_gfs_{}".format(FEM_name_mangling(element).lower(),
"_dirichlet" if is_dirichlet[0] else "",
"_".join(str(t) for t in treepath))
if main:
main_define_gfs(element, is_dirichlet, name, root)
else:
define_gfs(element, is_dirichlet, name, root)
driver_block_get_gridfunctionsspace(element, is_dirichlet, root, name=name)
return name
define_gfs(element, is_dirichlet, name, root)
if root:
driver_block_get_gridfunctionsspace(element, is_dirichlet, name=name)
return name
def type_test_gfs():
......@@ -398,17 +386,19 @@ def define_gfs(element, is_dirichlet, name, root):
@preamble(section="gfs", kernel="driver_block")
def define_power_gfs(element, is_dirichlet, name, subgfs, root):
declare_gfs(element, is_dirichlet, name, root)
gfstype = type_gfs(element, is_dirichlet, root=root)
names = ["using namespace Dune::Indices;"]
names = names + ["{0}.child(_{1}).name(\"{0}_{1}\");".format(name, i) for i in range(element.num_sub_elements())]
return ["{} {}({});".format(gfstype, name, subgfs)] + names
names = names + ["{0}->child(_{1}).name(\"{0}_{1}\");".format(name, i) for i in range(element.num_sub_elements())]
return ["{} = std::make_shared<{}>(*{});".format(name, gfstype, subgfs)] + names
@preamble(section="gfs", kernel="driver_block")
def define_composite_gfs(element, is_dirichlet, name, subgfs, root):
declare_gfs(element, is_dirichlet, name, root)
gfstype = type_gfs(element, is_dirichlet, root=root)
return ["{} {}({});".format(gfstype, name, ", ".join(subgfs)),
"{}.update();".format(name)]
return ["{} = std::make_shared<{}>({});".format(name, gfstype, ", ".join("*{}".format(c) for c in subgfs)),
"{}->update();".format(name)]
@class_member(classtag="driver_block")
......@@ -510,48 +500,39 @@ def type_constraintsassembler(is_dirichlet):
return name
def type_subgfs(tree_path):
def main_type_subgfs(treepath):
include_file('dune/pdelab/gridfunctionspace/subspace.hh', filetag='driver')
gfs = type_trial_gfs()
return "Dune::PDELab::GridFunctionSubSpace<{}, Dune::TypeTree::TreePath<{}> >".format(gfs, ', '.join(str(t) for t in tree_path))
gfs = main_type_trial_gfs()
return "Dune::PDELab::GridFunctionSubSpace<{}, Dune::TypeTree::TreePath<{}> >".format(gfs, ', '.join(str(t) for t in treepath))
@preamble(section="vtk", kernel="driver_block")
def define_subgfs(name, treepath):
t = type_subgfs(treepath)
gfs = name_trial_gfs()
return "{} {}({});".format(t, name, gfs)
@preamble(section="postprocessing", kernel="main")
def main_define_subgfs(name, treepath):
t = main_type_subgfs(treepath)
gfs = main_name_trial_gfs()
return "{} {}(*{});".format(t, name, gfs)
def name_subgfs(treepath):
gfs = name_trial_gfs()
def main_name_subgfs(treepath):
gfs = main_name_trial_gfs()
name = "{}_{}".format(gfs, "_".join(str(t) for t in treepath))
define_subgfs(name, treepath)
main_define_subgfs(name, treepath)
return name
def name_trial_subgfs(treepath):
if len(treepath) == 0:
return name_trial_gfs()
else:
return name_subgfs(treepath)
def main_name_trial_subgfs(treepath):
if len(treepath) == 0:
return name_trial_gfs()
return main_name_trial_gfs()
else:
# TODO
assert False
# return name_subgfs(treepath)
return main_name_subgfs(treepath)
@class_member(classtag="driver_block")
def driver_block_get_gridfunctionsspace(element, is_dirichlet, root, name=None):
gfs_type = type_gfs(element, is_dirichlet, root=root)
def driver_block_get_gridfunctionsspace(element, is_dirichlet, name=None):
gfs_type = type_gfs(element, is_dirichlet)
if not name:
name = name_gfs(element, is_dirichlet, root=root)
return ["std::shared_ptr<{}> get_gridfunctionsspace(){{".format(gfs_type),
name = name_gfs(element, is_dirichlet)
return ["std::shared_ptr<{}> getGridFunctionsSpace(){{".format(gfs_type),
" return {};".format(name),
"}"]
......@@ -564,15 +545,15 @@ def main_typedef_trial_gfs(name, element, is_dirichlet):
def main_type_trial_gfs():
name = "GridFunctionSpace"
element = get_trial_element()
is_dirichlet = preprocess_leaf_data(element, "is_dirichlet")
gfs_type = type_gfs(element, is_dirichlet)
main_typedef_trial_gfs(gfs_type, element, is_dirichlet)
return gfs_type
main_typedef_trial_gfs(name, element, is_dirichlet)
return name
@preamble(section="postprocessing", kernel="main")
def main_define_gfs(element, is_dirichlet, name, root):
def main_define_gfs(name, element, is_dirichlet):
driver_block_name = name_driver_block()
driver_block_get_gridfunctionsspace(element, is_dirichlet, root)
return "auto {} = {}.get_gridfunctionsspace();".format(name, driver_block_name)
driver_block_get_gridfunctionsspace(element, is_dirichlet)
return "auto {} = {}.getGridFunctionsSpace();".format(name, driver_block_name)
......@@ -91,7 +91,7 @@ def main_typedef_gridoperator(name, form_ident):
def main_type_gridoperator(form_ident):
name = "GO_{}".format(form_ident)
name = "GridOperator"
main_typedef_gridoperator(name, form_ident)
return name
......@@ -101,7 +101,7 @@ def driver_block_get_gridoperator(form_ident, name=None):
gridoperator_type = type_gridoperator(form_ident)
if not name:
name = name_gridoperator(form_ident)
return ["std::shared_ptr<{}> get_gridoperator(){{".format(gridoperator_type),
return ["std::shared_ptr<{}> getGridOperator(){{".format(gridoperator_type),
" return {};".format(name),
"}"]
......@@ -110,11 +110,11 @@ def driver_block_get_gridoperator(form_ident, name=None):
def main_define_gridoperator(name, form_ident):
driver_block_name = name_driver_block()
driver_block_get_gridoperator(form_ident)
return "auto {} = {}.get_gridoperator();".format(name, driver_block_name)
return "auto {} = {}.getGridOperator();".format(name, driver_block_name)
def main_name_gridoperator(form_ident):
name = "go_{}".format(form_ident)
name = "gridOperator"
main_define_gridoperator(name, form_ident)
return name
......
......@@ -15,7 +15,7 @@ from dune.codegen.pdelab.driver.gridoperator import (name_gridoperator,
name_localoperator,
)
from dune.codegen.pdelab.driver.constraints import (has_dirichlet_constraints,
name_bctype_function,
name_bctype_grid_function,
name_constraintscontainer,
)
from dune.codegen.pdelab.driver.interpolate import (interpolate_dirichlet_data,
......@@ -63,7 +63,7 @@ def time_loop():
is_dirichlet = preprocess_leaf_data(element, "is_dirichlet")
assemble_new_constraints = ""
if has_dirichlet_constraints(is_dirichlet):
bctype = name_bctype_function(element, is_dirichlet)
bctype = name_bctype_grid_function(element, is_dirichlet)
cc = name_constraintscontainer()
assemble_new_constraints = (" // Assemble constraints for new time step\n"
" {}.setTime({}+dt);\n"
......
......@@ -38,18 +38,54 @@ def interpolate_vector(func, gfs, name):
)
@class_member(classtag="driver_block")
def typedef_composite_boundary_grid_function(name, children):
templates = ','.join('std::decay_t<decltype(*{})>'.format(c) for c in children)
return "using {} = Dune::PDELab::CompositeGridFunction<{}>;".format(name, templates)
def type_composite_boundary_grid_function(children, root):
if root:
name = "BoundaryGridFunction"
else:
name = "CompositeGridFunction_{}".format('_'.join(c for c in children))
typedef_composite_boundary_grid_function(name, children)
return name
@class_member(classtag="driver_block")
def declare_composite_boundary_grid_function(name, children, root):
composite_gfs_type = type_composite_boundary_grid_function(children, root)
return "std::shared_ptr<{}> {};".format(composite_gfs_type, name)
@preamble(section="vector", kernel="driver_block")
def define_compositegfs_parameterfunction(name, children):
return "Dune::PDELab::CompositeGridFunction<{}> {}({});".format(', '.join('decltype({})'.format(c) for c in children),
name,
', '.join(children))
def define_composite_boundary_grid_function(name, children, root=False):
declare_composite_boundary_grid_function(name, children, root)
composite_gfs_type = type_composite_boundary_grid_function(children, root)
return "{} = std::make_shared<{}>({});".format(name, composite_gfs_type, ', '.join('*{}'.format(c) for c in children))
def _is_local(func):
# palpo TODO
return True
# return False
# assert isinstance(func, tuple)
# if len(func) == 2:
# return True
# else:
# try:
# assert len(func) == 1
# except:
# from pudb import set_trace; set_trace()
# return False
@class_member(classtag="driver_block")
def typedef_boundary_grid_function(name, func):
def typedef_boundary_grid_function(name, local):
leafview_type = type_leafview()
range_type = type_range()
boundary_function_type = type_boundary_function(func)
boundary_function_type = type_boundary_function(local)
# palpo TODO: 1 in the format below!
return "using {} = Dune::PDELab::LocalCallableToGridFunctionAdapter<{}, {}, {}, {}>;".format(name,
leafview_type,
......@@ -58,24 +94,33 @@ def typedef_boundary_grid_function(name, func):
boundary_function_type)
def type_boundary_grid_function(func):
name = "BoundaryGridFunction"
typedef_boundary_grid_function(name, func)
def type_boundary_grid_function(local, root):
# TODO: remove local stuff
# if local:
# name = "BoundaryGridFunctionLocal"
# else:
# name = "BoundaryGridFunctionGlobal"
if root:
name = "BoundaryGridFunction"
else:
name = "BoundaryGridFunctionLeaf"
typedef_boundary_grid_function(name, local)
return name
@class_member(classtag="driver_block")
def declare_boundary_grid_function(name, func):
bgf_type = type_boundary_grid_function(func)
def declare_boundary_grid_function(name, local, root):
bgf_type = type_boundary_grid_function(local, root)
return "std::shared_ptr<{}> {};".format(bgf_type, name)
@preamble(section="vector", kernel="driver_block")
def define_boundary_grid_function(name, func):
declare_boundary_grid_function(name, func)
def define_boundary_grid_function(name, func, root=False):
local = _is_local(func)
declare_boundary_grid_function(name, local, root)
gv = name_leafview()
boundary_function = name_boundary_function(func)
bgf_type = type_boundary_grid_function(func)
boundary_function = name_boundary_function(func, local)
bgf_type = type_boundary_grid_function(local, root)
include_file('dune/pdelab/function/callableadapter.hh', filetag='driver')
if is_stationary():
return "{} = std::make_shared<{}>({}, *{});".format(name, bgf_type, gv, boundary_function)
......@@ -92,93 +137,90 @@ def define_boundary_grid_function(name, func):
@cached
def name_boundary_grid_function(element, func):
def name_boundary_grid_function(element, func, root=True):
assert isinstance(func, tuple)
if isinstance(element, MixedElement):
# palpo TODO
assert False
k = 0
childs = []
for subel in element.sub_elements():
childs.append(name_boundary_grid_function(subel, func[k:k + subel.value_size()]))
childs.append(name_boundary_grid_function(subel, func[k:k + subel.value_size()], root=False))
k = k + subel.value_size()
name = "_".join(childs)
if len(childs) == 1:
name = "{}_dummy".format(name)
define_compositegfs_parameterfunction(name, tuple(childs))
return name
if root:
name = "boundary_grid_function"
define_composite_boundary_grid_function(name, tuple(childs), root=root)
else:
assert isinstance(element, (FiniteElement, TensorProductElement))
name = "boundary_grid_function"
define_boundary_grid_function(name, func)
name = get_counted_variable("boundary_grid_function")
if root:
name = "boundary_grid_function"
define_boundary_grid_function(name, func, root=root)
if root:
print("palpo 1 element: {}".format(element))
driver_block_get_boundarygridfunction(element, func, name=name)
return name
def boundary_lambda(func):
def boundary_lambda(func, local):
# palpo TODO
assert isinstance(func, tuple)
assert len(func) == 1
func = func[0]
if func is None:
func = 0.0
from ufl.classes import Expr
if isinstance(func, (int, float)):
return "[&](const auto& x){{ return {}; }}".format(float(func))
elif isinstance(func, Expr):
return "[&](const auto& is, const auto& x){{ return {}; }}".format(float(func))
else:
from ufl.classes import Expr
assert isinstance(func, Expr)
from dune.codegen.pdelab.driver.visitor import ufl_to_code
return "[&](const auto& is, const auto& xl){{ {}; }}".format(ufl_to_code(func))
else:
raise NotImplementedError("What is this?")
@class_member(classtag="driver_block")
def typedef_boundary_function(name, func):
assert isinstance(func, tuple)
assert len(func) == 1
func = func[0]
def typedef_boundary_function(name, local):
range_type = type_range()
leafview_type = type_leafview()
if func is None:
func = 0.0
from ufl.classes import Expr
if isinstance(func, (int, float)):
if not local:
coordinate = "typename {}::template Codim<0>::Entity::Geometry::GlobalCoordinate".format(leafview_type)
return "using {} = std::function<{}({})>;".format(name, range_type, coordinate)
elif isinstance(func, Expr):
else:
entity = "typename {}::template Codim<0>::Entity".format(leafview_type)
coordinate = "typename {}::template Codim<0>::Entity::Geometry::LocalCoordinate".format(leafview_type)
return "using {} = std::function<{}({}, {})>;".format(name, range_type, entity, coordinate)
else:
raise NotImplementedError("What is this?")
def type_boundary_function(func):
def type_boundary_function(local):
name = "BoundaryFunction"
typedef_boundary_function(name, func)
if not local:
name = name + "Global"
else:
name = name + "Local"
typedef_boundary_function(name, local)
return name
@class_member(classtag="driver_block")
def declare_boundary_function(name, func):
bf_type = type_boundary_function(func)
def declare_boundary_function(name, local):
bf_type = type_boundary_function(local)
return "std::shared_ptr<{}> {};".format(bf_type, name)
@preamble(section="vector", kernel="driver_block")
def define_boundary_function(name, func):
declare_boundary_function(name, func)
bf_type = type_boundary_function(func)
bf_lambda = boundary_lambda(func)
def define_boundary_function(name, func, local):
declare_boundary_function(name, local)
bf_type = type_boundary_function(local)
bf_lambda = boundary_lambda(func, local)
return "{} = std::make_shared<{}> ({});".format(name, bf_type, bf_lambda)
@cached
def name_boundary_function(func):
name = get_counted_variable("function")
define_boundary_function(name, func)
def name_boundary_function(func, local):
name = get_counted_variable("boundary_function")
define_boundary_function(name, func, local)
return name
......@@ -186,8 +228,8 @@ def name_boundary_function(func):
def driver_block_get_boundarygridfunction(element, func, name=None):
if not name:
name = name_boundary_grid_function(element, func)
bgf_type = type_boundary_grid_function(func)
return ["std::shared_ptr<{}> get_boundarygridfunction(){{".format(bgf_type),
bgf_type = "BoundaryGridFunction"
return ["std::shared_ptr<{}> getBoundaryGridFunction(){{".format(bgf_type),
" return {};".format(name),
"}"]
......@@ -195,7 +237,8 @@ def driver_block_get_boundarygridfunction(element, func, name=None):
@preamble(section="postprocessing", kernel="main")
def main_typedef_boundary_grid_function(name, func):
driver_block_type = type_driver_block()
bgf_type = type_boundary_grid_function(func)
local = _is_local(func)
bgf_type = type_boundary_grid_function(local, True)
return "using {} = {}::{};".format(name, driver_block_type, bgf_type)
......@@ -208,13 +251,14 @@ def main_type_boundary_grid_function(func):
@preamble(section="postprocessing", kernel="main")
def main_define_boundary_grid_function(name, element, func):
driver_block_name = name_driver_block()
driver_block_get_boundarygridfunction(element, func)
return "auto {} = {}.get_boundarygridfunction();".format(name, driver_block_name)
print("palpo 2 element: {}".format(element))
# driver_block_get_boundarygridfunction(element, func)
return "auto {} = {}.getBoundaryGridFunction();".format(name, driver_block_name)
@cached
def main_name_boundary_grid_function(element, func):
assert isinstance(func, tuple)
name = "boundary_grid_function"
name = "boundaryGridFunction"
main_define_boundary_grid_function(name, element, func)
return name
......@@ -123,7 +123,7 @@ def main_typedef_vector(name, form_ident):
def main_type_vector(form_ident):
name = "V_{}".format(form_ident.upper())
name = "Coefficient"
main_typedef_vector(name, form_ident)
return name
......@@ -133,7 +133,7 @@ def driver_block_get_coefficient(form_ident, name=None):
vector_type = type_vector(form_ident)
if not name:
name = name_vector(form_ident)
return ["std::shared_ptr<{}> get_coefficient(){{".format(vector_type),
return ["std::shared_ptr<{}> getCoefficient(){{".format(vector_type),
" return {};".format(name),
"}"]
......@@ -142,13 +142,12 @@ def driver_block_get_coefficient(form_ident, name=None):
def main_define_vector(name, form_ident):
driver_block_name = name_driver_block()
driver_block_get_coefficient(form_ident)
return "auto {} = {}.get_coefficient();".format(name, driver_block_name)
return "auto {} = {}.getCoefficient();".format(name, driver_block_name)
def main_name_vector(form_ident):
name = "x_{}".format(form_ident)
name = "coefficient"
main_define_vector(name, form_ident)
interpolate_dirichlet_data(name)
return name
......@@ -251,7 +250,7 @@ def driver_block_get_solver(name=None):
solver_type = type_stationarylinearproblemsolver()
if not name:
name = name_stationarylinearproblemsolver()
return ["std::shared_ptr<{}> get_solver(){{".format(solver_type),
return ["std::shared_ptr<{}> getSolver(){{".format(solver_type),
" return {};".format(name),
"}"]
......@@ -260,11 +259,11 @@ def driver_block_get_solver(name=None):
def main_define_stationarylinearproblemsolver(name):
driver_block_name = name_driver_block()
driver_block_get_solver()
return "auto {} = {}.get_solver();".format(name, driver_block_name)
return "auto {} = {}.getSolver();".format(name, driver_block_name)
def main_name_stationarylinearproblemsolver():
name = "slp"
name = "solver"
main_define_stationarylinearproblemsolver(name)
return name
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment