Newer
Older
from dune.codegen.generation import (class_member,
get_global_context_value,
get_dimension,
get_test_element,
get_trial_element,
isQuadrilateral,
name_initree,
)
from dune.codegen.pdelab.driver.constraints import (name_assembled_constraints,
from dune.codegen.pdelab.driver.driverblock import (name_driver_block,
type_driver_block,
)
from dune.codegen.pdelab.driver.gridfunctionspace import (name_test_gfs,
name_trial_gfs,
preprocess_leaf_data,
type_domainfield,
type_range,
type_test_gfs,
type_trial_gfs,
)
from dune.codegen.pdelab.localoperator import localoperator_basename
from dune.codegen.options import get_form_option
@class_member(classtag="driver_block")
def typedef_gridoperator(name, form_ident):
    """Return the C++ ``using`` alias *name* for the grid operator of a form.

    The alias names ``Dune::PDELab::FastDGGridOperator`` when the form option
    ``fastdg`` is set and ``Dune::PDELab::GridOperator`` otherwise.  The
    matching PDELab header is requested through ``include_file`` with the
    ``driver`` file tag.

    :param name: identifier of the alias to emit
    :param form_ident: identifier of the form whose local operator is used
    :raises CodegenCodegenError: if ``fastdg`` is set without ``sumfact``
    """
    trial_space = type_trial_gfs()
    test_space = type_test_gfs()
    local_operator = type_localoperator(form_ident)
    constraints = type_constraintscontainer()
    backend = type_matrixbackend()
    domain_field = type_domainfield()
    range_field = type_range()

    # Both grid operator flavours take the same template argument list; the
    # range type and the constraints container each appear twice.
    template_args = (name, trial_space, test_space, local_operator, backend,
                     domain_field, range_field, range_field,
                     constraints, constraints)

    if not get_form_option("fastdg"):
        include_file("dune/pdelab/gridoperator/gridoperator.hh", filetag="driver")
        return "using {} = Dune::PDELab::GridOperator<{}, {}, {}, {}, {}, {}, {}, {}, {}>;".format(*template_args)

    if not get_form_option("sumfact"):
        raise CodegenCodegenError("FastDGGridOperator is only implemented for sumfactorization.")
    include_file("dune/pdelab/gridoperator/fastdg.hh", filetag="driver")
    return "using {} = Dune::PDELab::FastDGGridOperator<{}, {}, {}, {}, {}, {}, {}, {}, {}>;".format(*template_args)
def type_gridoperator(form_ident):
    """Return the name of the grid operator type alias for *form_ident*.

    Calls ``typedef_gridoperator`` for the alias ``GO_<form_ident>`` and hands
    the alias name back so callers can embed it in generated code.

    :param form_ident: identifier of the form
    :returns: the alias name, ``"GO_" + form_ident``
    """
    name = "GO_{}".format(form_ident)
    typedef_gridoperator(name, form_ident)
    # Callers format the result into C++ declarations, so the name has to be
    # returned (falling off the end would yield ``None``).
    return name
@class_member(classtag="driver_block")
def declare_gridoperator(name, form_ident):
    """Return the C++ declaration of the shared pointer *name*.

    The pointee type is the grid operator alias obtained from
    ``type_gridoperator(form_ident)``.
    """
    declaration = "std::shared_ptr<{}> {};"
    return declaration.format(type_gridoperator(form_ident), name)
@preamble(section="gridoperator", kernel="driver_block")
def define_gridoperator(name, form_ident):
    """Return the C++ statements that construct the grid operator *name*.

    The first statement assigns a ``std::make_shared`` of the grid operator
    type built from the trial space, the assembled constraints, the test
    space, the constraints again, the local operator and the matrix backend.
    Two further statements print the sizes of the function space and of the
    constraints container.

    :param name: identifier of the grid operator variable
    :param form_ident: identifier of the form
    :returns: list of three C++ statements
    :raises NotImplementedError: if trial and test space names differ
    """
    declare_gridoperator(name, form_ident)
    gotype = type_gridoperator(form_ident)
    ugfs = name_trial_gfs()
    vgfs = name_test_gfs()
    if ugfs != vgfs:
        raise NotImplementedError("Non-Galerkin methods currently not supported!")
    cc = name_assembled_constraints()
    lop = name_localoperator(form_ident)
    # The constructor call below passes the matrix backend as its last
    # argument, so its variable name has to be resolved here.
    mb = name_matrixbackend()
    return ["{} = std::make_shared<{}>(*{}, *{}, *{}, *{}, *{}, *{});".format(name, gotype, ugfs, cc, vgfs, cc, lop, mb),
            "std::cout << \"gfs with \" << {}->size() << \" dofs generated \"<< std::endl;".format(ugfs),
            "std::cout << \"cc with \" << {}->size() << \" dofs generated \"<< std::endl;".format(cc)]
def name_gridoperator(form_ident):
    """Return the variable name of the grid operator for *form_ident*.

    Calls ``define_gridoperator`` for the variable ``go_<form_ident>`` and
    ``driver_block_get_gridoperator`` with that same name.

    :param form_ident: identifier of the form
    :returns: the variable name, ``"go_" + form_ident``
    """
    name = "go_{}".format(form_ident)
    define_gridoperator(name, form_ident)
    # Pass the name explicitly: without it the getter would call back into
    # this function to look the name up.
    driver_block_get_gridoperator(form_ident, name=name)
    return name
@preamble(section="postprocessing", kernel="main")
def main_typedef_gridoperator(name, form_ident):
    """Return a C++ alias *name* for the driver block's grid operator type.

    The alias refers to the nested type ``<driver block type>::<GO alias>``.
    """
    block_type = type_driver_block()
    nested_type = type_gridoperator(form_ident)
    return "using {} = {}::{};".format(name, block_type, nested_type)
def main_type_gridoperator(form_ident):
    """Return the name of the grid operator type alias used in ``main``.

    Calls ``main_typedef_gridoperator`` for the alias and returns its name.

    :param form_ident: identifier of the form
    :returns: the alias name
    """
    # NOTE(review): the alias name was never bound in this function.  It is
    # derived from the form identifier following the ``GO_`` scheme of
    # type_gridoperator -- confirm this is the intended identifier.
    name = "GO_{}".format(form_ident)
    main_typedef_gridoperator(name, form_ident)
    return name
@class_member(classtag="driver_block")
def driver_block_get_gridoperator(form_ident, name=None):
    """Return the lines of a C++ ``getGridOperator()`` member function.

    :param form_ident: identifier of the form
    :param name: name of the grid operator member to return; looked up via
        ``name_gridoperator(form_ident)`` when empty or ``None``
    :returns: list of three source lines
    """
    go_type = type_gridoperator(form_ident)
    member = name if name else name_gridoperator(form_ident)
    signature = "std::shared_ptr<{}> getGridOperator(){{".format(go_type)
    body = " return {};".format(member)
    return [signature, body, "}"]
@preamble(section="postprocessing", kernel="main")
def main_define_gridoperator(name, form_ident):
    """Return the C++ statement that fetches the grid operator in ``main``.

    The statement initialises *name* from the driver block's
    ``getGridOperator()``; ``driver_block_get_gridoperator`` is called
    beforehand for the same form.
    """
    block_instance = name_driver_block()
    driver_block_get_gridoperator(form_ident)
    statement = "auto {} = {}.getGridOperator();"
    return statement.format(name, block_instance)
def main_name_gridoperator(form_ident):
    """Return the name of the grid operator variable used in ``main``.

    Calls ``main_define_gridoperator`` for the variable and returns its name.

    :param form_ident: identifier of the form
    :returns: the variable name
    """
    # NOTE(review): the variable name was never bound in this function.  It
    # is derived from the form identifier following the ``go_`` scheme of
    # name_gridoperator -- confirm this is the intended identifier.
    name = "go_{}".format(form_ident)
    main_define_gridoperator(name, form_ident)
    return name
@class_member(classtag="driver_block")
def typedef_localoperator(name, form_ident):
    """Return the C++ ``using`` alias *name* for the local operator type.

    The alias instantiates the class template named by
    ``localoperator_basename(form_ident)`` with the trial and test grid
    function space types.

    :param name: identifier of the alias to emit
    :param form_ident: identifier of the form
    """
    ugfs = type_trial_gfs()
    vgfs = type_test_gfs()
    filename = get_form_option("filename", form_ident)
    # NOTE(review): ``filename`` was looked up but never used; presumably it
    # is the header defining the operator class and has to be included in
    # the driver like the headers of the other typedefs -- confirm.
    include_file(filename, filetag="driver")
    lopname = localoperator_basename(form_ident)
    return "using {} = {}<{}, {}>;".format(name, lopname, ugfs, vgfs)
def type_localoperator(form_ident):
    """Return the name of the local operator type alias for *form_ident*.

    Calls ``typedef_localoperator`` for the alias ``LOP_<FORM_IDENT>`` (the
    form identifier is upper-cased) and returns the alias name.

    :param form_ident: identifier of the form
    :returns: the alias name
    """
    name = "LOP_{}".format(form_ident.upper())
    typedef_localoperator(name, form_ident)
    # Callers format the result into C++ code, so the name must be returned.
    return name
@class_member(classtag="driver_block")
def declare_localoperator(name, form_ident):
    """Return the C++ declaration of the shared pointer *name*.

    The pointee type is the local operator alias obtained from
    ``type_localoperator(form_ident)``.
    """
    declaration = "std::shared_ptr<{}> {};"
    return declaration.format(type_localoperator(form_ident), name)
@preamble(section="gridoperator", kernel="driver_block")
def define_localoperator(name, form_ident):
    """Return the C++ statement that constructs the local operator *name*.

    The operator is created with ``std::make_shared`` from the dereferenced
    trial and test grid function spaces and the ini tree.

    :param name: identifier of the local operator variable
    :param form_ident: identifier of the form
    """
    declare_localoperator(name, form_ident)
    trial_gfs = name_trial_gfs()
    test_gfs = name_test_gfs()
    loptype = type_localoperator(form_ident)
    # The constructor call passes the ini tree as its last argument, so its
    # variable name has to be resolved here.
    ini = name_initree()
    return "{} = std::make_shared<{}>(*{}, *{}, {});".format(name, loptype, trial_gfs, test_gfs, ini)
def name_localoperator(form_ident):
    """Return the variable name of the local operator for *form_ident*.

    Calls ``define_localoperator`` for the variable ``lop_<form_ident>`` and
    returns that name.

    :param form_ident: identifier of the form
    :returns: the variable name, ``"lop_" + form_ident``
    """
    name = "lop_{}".format(form_ident)
    define_localoperator(name, form_ident)
    # define_gridoperator formats the result into the constructor call, so
    # the name must be returned.
    return name
@preamble(section="gridoperator", kernel="driver_block")
def define_dofestimate(name):
    """Return C++ statements defining the integer *name*.

    The generated code updates the trial grid function space, computes a
    generic estimate as ``geo_factor * maxLocalSize()`` and reads the final
    value from the ini tree key ``istl.number_of_nnz``, using the generic
    estimate as the default.

    :param name: identifier of the integer variable to define
    :returns: list of three C++ statements
    """
    # Provide a worstcase estimate for the number of entries per row based
    # on the given gridfunction space and cell geometry
    if isQuadrilateral(get_cell()):
        geo_factor = 2**get_dimension()
    elif get_dimension() < 3:
        # NOTE(review): this branch had no body.  3 * dim yields 6 in 2D,
        # which would match six triangles meeting at a vertex of a regular
        # triangulation -- confirm the intended factor.
        geo_factor = 3 * get_dimension()
    else:
        # TODO no idea how a generic estimate for 3D simplices looks like
        geo_factor = 12
    gfs = name_trial_gfs()
    ini = name_initree()
    return ["{}->update();".format(gfs),
            "int generic_dof_estimate = {} * {}->maxLocalSize();".format(geo_factor, gfs),
            "int {} = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(name, ini)]
def name_dofestimate():
    """Return the variable name ``dofestimate``.

    ``define_dofestimate`` is called for that name before it is returned.
    """
    estimate = "dofestimate"
    define_dofestimate(estimate)
    return estimate
@class_member(classtag="driver_block")
def typedef_matrixbackend(name):
    """Return the C++ alias *name* for the ISTL BCRS matrix backend.

    Requests the header ``dune/pdelab/backend/istl.hh`` with the ``driver``
    file tag.
    """
    include_file("dune/pdelab/backend/istl.hh", filetag="driver")
    alias = "using {} = Dune::PDELab::ISTL::BCRSMatrixBackend<>;"
    return alias.format(name)
def type_matrixbackend():
    """Return the type alias name ``MatrixBackend``.

    ``typedef_matrixbackend`` is called for that name before it is returned.
    """
    alias = "MatrixBackend"
    typedef_matrixbackend(alias)
    return alias
@class_member(classtag="driver_block")
def declare_matrixbackend(name):
    """Return the C++ declaration of the shared pointer *name*.

    The pointee type is the alias obtained from ``type_matrixbackend()``.
    """
    declaration = "std::shared_ptr<{}> {};"
    return declaration.format(type_matrixbackend(), name)
@preamble(section="gridoperator", kernel="driver_block")
def define_matrixbackend(name):
    """Return the C++ statement that constructs the matrix backend *name*.

    The backend is created with ``std::make_shared`` and receives the dof
    estimate variable as its only constructor argument.

    :param name: identifier of the matrix backend variable
    """
    # Same pattern as define_gridoperator / define_localoperator: the
    # statement below only assigns, so the member declaration is requested
    # first.
    declare_matrixbackend(name)
    mbtype = type_matrixbackend()
    dof = name_dofestimate()
    return "{} = std::make_shared<{}>({});".format(name, mbtype, dof)
def name_matrixbackend():
    """Return the variable name ``mb``.

    ``define_matrixbackend`` is called for that name before it is returned.
    """
    backend = "mb"
    define_matrixbackend(backend)
    return backend