Skip to content
Snippets Groups Projects
Commit dd07cdfb authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Add support for initializer lists

parent fe5dcf79
No related branches found
No related tags found
No related merge requests found
...@@ -41,11 +41,12 @@ class ClassMember(Generable): ...@@ -41,11 +41,12 @@ class ClassMember(Generable):
yield line + '\n' yield line + '\n'
class Constructor(Generable): class Constructor(Generable):
def __init__(self, block=Block([]), arg_decls=[], clsname=None, access=AccessModifier.PUBLIC): def __init__(self, block=Block([]), arg_decls=[], clsname=None, initializer_list=[], access=AccessModifier.PUBLIC):
self.clsname = clsname self.clsname = clsname
self.arg_decls = arg_decls self.arg_decls = arg_decls
self.access = access self.access = access
self.block = block self.block = block
self.il = initializer_list
def generate(self): def generate(self):
assert self.clsname assert self.clsname
...@@ -56,9 +57,16 @@ class Constructor(Generable): ...@@ -56,9 +57,16 @@ class Constructor(Generable):
for ad in self.arg_decls: for ad in self.arg_decls:
for content in ad.generate(with_semicolon=False): for content in ad.generate(with_semicolon=False):
yield content yield content
yield ")" yield ")\n"
# TODO Add initializer lists here as soon as they are needed # add the initializer list
if self.il:
yield " : {}".format(self.il[0])
for i in self.il[1:]:
yield ",\n"
yield " {}".format(i)
yield '\n'
for line in self.block.generate(): for line in self.block.generate():
yield line yield line
...@@ -99,10 +107,12 @@ class Class(Generable): ...@@ -99,10 +107,12 @@ class Class(Generable):
yield '\n' yield '\n'
# add base class inheritance # add base class inheritance
for i, bc in enumerate(self.base_classes): if self.base_classes:
yield " : {} {}".format(access_modifier_string(bc.inheritance), bc.name) yield " : {} {}".format(access_modifier_string(self.base_classes[0].inheritance), self.base_classes[0].name)
if i+1 != len(self.base_classes):
yield ',' for bc in self.base_classes[1:]:
yield ",\n"
yield " {} {}".format(access_modifier_string(bc.inheritance), bc.name)
yield '\n' yield '\n'
# Now yield the entire block # Now yield the entire block
......
...@@ -12,7 +12,17 @@ from pytools import memoize ...@@ -12,7 +12,17 @@ from pytools import memoize
operator_include = generator_factory(item_tags=("include", "operator"), on_store=lambda i: Include(i), no_deco=True) operator_include = generator_factory(item_tags=("include", "operator"), on_store=lambda i: Include(i), no_deco=True)
from dune.perftool.cgen.clazz import BaseClass from dune.perftool.cgen.clazz import BaseClass
public_base_class = generator_factory(item_tags=("baseclass", "operator"), on_store=lambda n: BaseClass(n), counted=True, no_deco=True) public_base_class = generator_factory(item_tags=("baseclass", "operator"), on_store=lambda n: BaseClass(n), counted=True, no_deco=True)
initializer_list = generator_factory(item_tags=("pdelab", "initializer", "operator"), counted=True, no_deco=True)
@generator_factory(item_tags=("initializer", "operator"), counted=True, cache_key_generator=lambda *a: a[0])
def initializer_list(obj, params):
return "{}({})".format(obj, ", ".join(params))
@generator_factory(item_tags=("operator", "member"), counted=True, cache_key_generator=lambda t,n : n)
def define_private_member(_type, name):
from cgen import Value
from dune.perftool.cgen.clazz import ClassMember, AccessModifier
return ClassMember(Value(_type, name), access=AccessModifier.PRIVATE)
@generator_factory(item_tags=("operator", "constructor_param"), counted=True) @generator_factory(item_tags=("operator", "constructor_param"), counted=True)
def constructor_parameter(_type, name): def constructor_parameter(_type, name):
...@@ -25,6 +35,14 @@ def name_initree_constructor(): ...@@ -25,6 +35,14 @@ def name_initree_constructor():
constructor_parameter("const Dune::ParameterTree&", "iniParams") constructor_parameter("const Dune::ParameterTree&", "iniParams")
return "iniParams" return "iniParams"
@dune_symbol
def name_initree_member():
operator_include('dune/common/parametertree.hh')
define_private_member("const Dune::ParameterTree&", "_iniParams")
in_constructor = name_initree_constructor()
initializer_list("_iniParams", [in_constructor])
return "_iniParams"
@dune_symbol @dune_symbol
def localoperator_type(): def localoperator_type():
#TODO use something from the form here to make it unique #TODO use something from the form here to make it unique
...@@ -43,9 +61,9 @@ def measure_specific_details(measure): ...@@ -43,9 +61,9 @@ def measure_specific_details(measure):
public_base_class("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype)) public_base_class("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype))
# Add the initializer list for that base class # Add the initializer list for that base class
ini = name_initree_constructor() ini = name_initree_member()
initializer_list("Dune::PDELab::NumericalJacobian{}<{}>({}.get(\"numerical_epsilon.{}\", 1e-9))".format(which, loptype, ini, which.lower())) initializer_list("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype),
["{}.get(\"numerical_epsilon.{}\", 1e-9)".format(ini, which.lower())])
if measure == "cell": if measure == "cell":
public_base_class('Dune::PDELab::FullVolumePattern') public_base_class('Dune::PDELab::FullVolumePattern')
...@@ -80,9 +98,6 @@ def measure_specific_details(measure): ...@@ -80,9 +98,6 @@ def measure_specific_details(measure):
def generate_kernel(integrand=None, measure=None): def generate_kernel(integrand=None, measure=None):
assert integrand and measure assert integrand and measure
from dune.perftool.generation import delete_cache
delete_cache()
# Get the measure specifics # Get the measure specifics
specifics = measure_specific_details(measure) specifics = measure_specific_details(measure)
...@@ -111,6 +126,10 @@ def generate_kernel(integrand=None, measure=None): ...@@ -111,6 +126,10 @@ def generate_kernel(integrand=None, measure=None):
kernel = make_kernel(domains, instructions, arguments, temporary_variables=temporaries, target=DuneTarget()) kernel = make_kernel(domains, instructions, arguments, temporary_variables=temporaries, target=DuneTarget())
kernel = preprocess_kernel(kernel) kernel = preprocess_kernel(kernel)
# All items with the kernel tags can be destroyed once a kernel has been generated
from dune.perftool.generation import delete_cache_items
delete_cache_items("kernel")
# Return the actual code (might instead return kernels...) # Return the actual code (might instead return kernels...)
return kernel return kernel
...@@ -129,12 +148,14 @@ def cgen_class_from_cache(tag, members=[]): ...@@ -129,12 +148,14 @@ def cgen_class_from_cache(tag, members=[]):
base_classes = [bc for bc in retrieve_cache_items(tags=(tag, "baseclass"), union=False)] base_classes = [bc for bc in retrieve_cache_items(tags=(tag, "baseclass"), union=False)]
constructor_params = [bc for bc in retrieve_cache_items(tags=(tag, "constructor_param"), union=False)] constructor_params = [bc for bc in retrieve_cache_items(tags=(tag, "constructor_param"), union=False)]
il = [i for i in retrieve_cache_items(tags=(tag, "initializer"), union=False)]
pm = [m for m in retrieve_cache_items(tags=(tag, "member"), union=False)]
from dune.perftool.cgen.clazz import Constructor from dune.perftool.cgen.clazz import Constructor
constructor = Constructor(arg_decls=constructor_params, clsname=localoperator_type()) constructor = Constructor(arg_decls=constructor_params, clsname=localoperator_type(), initializer_list=il)
from dune.perftool.cgen import Class from dune.perftool.cgen import Class
return Class(localoperator_type, base_classes=base_classes, members=members, constructors=[constructor]) return Class(localoperator_type, base_classes=base_classes, members=members + pm, constructors=[constructor])
def generate_localoperator(form): def generate_localoperator(form):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment