Skip to content
Snippets Groups Projects
Commit 8bf01d3c authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Basics for writing C++ Classes

in cgen and from the cache.
parent 02c5e983
No related branches found
No related tags found
No related merge requests found
from cgen import Generable
from cgen import Generable, Block
class AccessModifier:
PRIVATE = 1
PUBLIC = 2
PROTECTED = 3
def access_modifier_string(am):
if am == AccessModifier.PRIVATE:
return "private"
if am == AccessModifier.PUBLIC:
return "public"
if am == AccessModifier.PROTECTED:
return "protected"
raise ValueError("Unknown access modifier in class generation")
class BaseClass(Generable):
def __init__(self, name, inheritance=AccessModifier.PUBLIC, construction=[]):
self.name = name
self.inheritance = inheritance
self.construction = construction
assert isinstance(name, str)
for param in construction:
assert isinstance(param, str)
def generate(self):
yield self.name
class ClassMember(Generable):
def __init__(self, member, access=AccessModifier.PUBLIC):
self.member = member
self.access = access
assert isinstance(member, Generable)
def generate(self):
yield "{}:\n".format(access_modifier_string(self.access))
for line in self.member.generate():
yield line
class Constructor(Generable):
def __init__(self, block=Block([]), arg_decls=[], clsname=None, access=AccessModifier.PUBLIC):
self.clsname = clsname
self.arg_decls = arg_decls
self.access = access
self.block = block
def generate(self):
assert self.clsname
yield "{}:\n".format(access_modifier_string(self.access))
yield self.classname + "("
for ad in self.arg_decls:
for content in ad.generate():
yield content
yield ")"
# TODO Add initializer lists here as soon as they are needed
for line in self.block.generate():
yield line
class Class(Generable):
""" Generator for a templated class """
def __init__(self, name, public_methods = [], tparams=[], constructors=[]):
def __init__(self, name, base_classes=[], members=[], tparam_decls=[], constructors=[]):
self.name = name
self.public_methods = public_methods
self.tparams = tparams
self.base_classes = base_classes
self.members = members
self.tparam_decls = tparam_decls
self.constructors = constructors
assert isinstance(name, str)
from cgen import FunctionBody
for m in self.methods:
assert isinstance(n, FunctionBody)
for bc in base_classes:
assert isinstance(bc, BaseClass)
for mem in members:
assert isinstance(mem, ClassMember)
from cgen import Declarator
for tp in tparam_decls:
assert isinstance(tp, Declarator)
for con in constructors:
assert isinstance(con, Constructor)
def access_modifier(self, am):
if am == AccessModifier.PRIVATE:
return "private"
if am == AccessModifier.PUBLIC:
return "public"
if am == AccessModifier.PROTECTED:
return "protected"
raise ValueError("Unknown access modifier in class generation")
def generate(self):
yield "class {}".format(self.name)
\ No newline at end of file
# define the class header
from cgen import Value
decl = Value('class', self.name)
if self.tparam_decls:
from cgen import Template
decl = Template(self.tparam_decls, decl)
# Yield the definition
for line in decl.generate(with_semicolon=False):
yield line
yield '\n'
# add base class inheritance
yield ",\n".join(" : {} {}\n".format(access_modifier_string(bc.inheritance), bc.name) for bc in self.base_classes)
# Set the
# Now yield the entire block
block = Block(contents=self.constructors + self.members)
# Yield the block
for line in block.generate():
yield line
yield ";\n"
......@@ -9,8 +9,9 @@ from cgen import Include
from pytools import memoize
# Define the generators used in-here
operator_include = generator_factory(item_tags=("pdelab", "include", "operator"), on_store=lambda i: Include(i), no_deco=True)
base_class = generator_factory(item_tags=("pdelab", "baseclass", "operator"), counted=True, no_deco=True)
operator_include = generator_factory(item_tags=("include", "operator"), on_store=lambda i: Include(i), no_deco=True)
from dune.perftool.cgen.clazz import BaseClass
public_base_class = generator_factory(item_tags=("baseclass", "operator"), on_store=lambda n: BaseClass(n), counted=True, no_deco=True)
initializer_list = generator_factory(item_tags=("pdelab", "initializer", "operator"), counted=True, no_deco=True)
# TODO definition
......@@ -36,7 +37,7 @@ def measure_specific_details(measure):
# Add a base class
from dune.perftool.pdelab.driver import type_localoperator
loptype = type_localoperator()
base_class("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype))
public_base_class("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype))
# Add the initializer list for that base class
ini = name_initree_constructor()
......@@ -44,7 +45,7 @@ def measure_specific_details(measure):
if measure == "cell":
base_class('Dune::PDELab::FullVolumePattern')
public_base_class('Dune::PDELab::FullVolumePattern')
numerical_jacobian("Volume")
ret["residual_signature"] = ['template<typename EG, typename LFSV0, typename X, typename LFSV1, typename R>',
......@@ -53,7 +54,7 @@ def measure_specific_details(measure):
'void jacobian_volume(const EG& eg, const LFSV0& lfsv0, const X& x, const LFSV1& lfsv1, J& jac) const']
if measure == "exterior_facet":
base_class('Dune::PDELab::FullBoundaryPattern')
public_base_class('Dune::PDELab::FullBoundaryPattern')
numerical_jacobian("Boundary")
ret["residual_signature"] = ['template<typename IG, typename LFSV0, typename X, typename LFSV1, typename R>',
......@@ -62,7 +63,7 @@ def measure_specific_details(measure):
'void jacobian_boundary(const IG& ig, const LFSV0& lfsv0, const X& x, const LFSV1& lfsv1, J& jac) const']
if measure == "interior_facet":
base_class('Dune::PDELab::FullSkeletonPattern')
public_base_class('Dune::PDELab::FullSkeletonPattern')
numerical_jacobian("Skeleton")
ret["residual_signature"] = ['template<typename IG, typename LFSV0_S, typename X, typename LFSV1_S, typename LFSV0_N, typename R, typename LFSV1_N>',
......@@ -113,6 +114,15 @@ def generate_term(integrand=None, measure=None):
return str(generate_code(kernel)[0])
def cgen_class_from_cache(name, tag):
from dune.perftool.generation import retrieve_cache_items
base_classes = [bc for bc in retrieve_cache_items(tags=(tag, "baseclass"), union=False)]
from dune.perftool.cgen import Class
return Class(name, base_classes=base_classes)
def generate_localoperator(form, operatorfile):
# For the moment, I do assume that there is but one integral of each type. This might differ
# if you use different quadrature orders for different terms.
......@@ -154,4 +164,11 @@ def generate_localoperator(form, operatorfile):
operator_include('dune/pdelab/localoperator/pattern.hh')
operator_include('dune/geometry/quadraturerules.hh')
base_class('Dune::PDELab::LocalOperatorDefaultFlags')
public_base_class('Dune::PDELab::LocalOperatorDefaultFlags')
# Write the file!
from dune.perftool.file import generate_file
# TODO take the name of this thing from the UFL file
lop = cgen_class_from_cache("LocalOperator", "operator")
generate_file(get_option("operator_file"), "operator", [lop])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment