Skip to content
Snippets Groups Projects
Commit f8b3442e authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Finally, parameter class generation!

parent 0ab0b7ea
No related branches found
No related tags found
No related merge requests found
......@@ -35,12 +35,12 @@ class LocalBasisCacheWithoutReferences
typedef typename LocalBasisType::Traits::JacobianType JacobianType;
public:
typedef CacheReturnProxy<RangeType> RangeReturnType;
typedef CacheReturnProxy<RangeType> FunctionReturnType;
typedef CacheReturnProxy<JacobianType> JacobianReturnType;
RangeReturnType evaluateFunction(const DomainType& position, const LocalBasisType& localbasis) const
FunctionReturnType evaluateFunction(const DomainType& position, const LocalBasisType& localbasis) const
{
return RangeReturnType(&c.evaluateFunction(position, localbasis));
return FunctionReturnType(&c.evaluateFunction(position, localbasis));
}
JacobianReturnType evaluateJacobian(const DomainType& position, const LocalBasisType& localbasis) const
......
......@@ -73,9 +73,13 @@ class Constructor(Generable):
yield '\n'
yield "{}:\n".format(access_modifier_string(self.access))
yield self.clsname + "("
for ad in self.arg_decls:
for content in ad.generate(with_semicolon=False):
if self.arg_decls:
for content in self.arg_decls[0].generate(with_semicolon=False):
yield content
for ad in self.arg_decls[1:]:
yield ", "
for content in ad.generate(with_semicolon=False):
yield content
yield ")\n"
# add the initializer list
......
......@@ -42,6 +42,7 @@ def generate_file(filename, tag, content, headerguard=True):
assert isinstance(c, Generable)
for line in c.generate():
f.write(line)
f.write('\n\n')
if headerguard:
f.write("\n\n#endif //{}\n".format(macro))
......@@ -29,5 +29,4 @@ from dune.perftool.generation.loopy import (domain,
from dune.perftool.generation.context import (cache_context,
generic_context,
get_generic_context_value,
namedata_context,
)
......@@ -53,6 +53,3 @@ def generic_context(key, value):
def get_generic_context_value(key):
return _generic_context_cache[key]
import functools
namedata_context=functools.partial(generic_context, "namedata")
......@@ -43,7 +43,7 @@ def class_member(classtag=None, access=AccessModifier.PRIVATE):
return generator_factory(item_tags=(classtag, "member"), on_store=lambda m: ClassMember(m, access=access), counted=True)
def constructor_parameter(_type, name, classtag=None, constructortag=None):
def constructor_parameter(_type, name, classtag=None, constructortag="default"):
assert classtag
assert constructortag
from cgen import Value
......
......@@ -410,21 +410,15 @@ def name_matrixbackend():
return "mb"
@preamble
def typedef_parameters(name):
return "typedef LocalOperatorParameters {};".format(name)
@symbol
def type_parameters():
typedef_parameters("Params")
return "Params"
return "LocalOperatorParams"
@preamble
def define_parameters(name):
partype = type_parameters()
return "{} {}();".format(partype, name)
return "{} {};".format(partype, name)
@symbol
......@@ -455,9 +449,8 @@ def type_localoperator():
def define_localoperator(name):
loptype = type_localoperator()
ini = name_initree()
# params = name_parameters()
# return "{} {}({}, {});".format(loptype, name, ini, params)
return "{} {}({});".format(loptype, name, ini)
params = name_parameters()
return "{} {}({}, {});".format(loptype, name, ini, params)
@symbol
......
......@@ -8,19 +8,19 @@ from dune.perftool.pdelab.quadrature import (name_quadrature_position,
@symbol
def name_elementgeometry():
def name_entitygeometry():
return 'eg'
@symbol
def name_element():
eg = name_elementgeometry()
return "{}.element()".format(eg)
def name_entity():
eg = name_entitygeometry()
return "{}.entity()".format(eg)
@preamble
def define_geometry(name):
eg = name_elementgeometry()
eg = name_entitygeometry()
return "auto {} = {}.geometry();".format(name,
eg
)
......
......@@ -37,7 +37,7 @@ def name_initree_constructor_param():
def define_initree(name):
param_name = name_initree_constructor_param()
include_file('dune/common/parametertree.hh', filetag="operatorfile")
constructor_parameter("const Dune::ParameterTree&", param_name, classtag="operator", constructortag="iniconstructor")
constructor_parameter("const Dune::ParameterTree&", param_name, classtag="operator")
initializer_list(name, [param_name], classtag="operator")
return "const Dune::ParameterTree& {};".format(name)
......@@ -189,7 +189,7 @@ def cgen_class_from_cache(tag, members=[]):
from dune.perftool.generation import retrieve_cache_items
# Generate the name by concatenating basename and template parameters
basename, fullname = class_type_from_cache("operator")
basename, fullname = class_type_from_cache(tag)
base_classes = [bc for bc in retrieve_cache_items('{} and baseclass'.format(tag))]
constructor_params = [bc for bc in retrieve_cache_items('{} and constructor_param'.format(tag))]
......@@ -231,7 +231,9 @@ def generate_localoperator_kernels(form, namedata):
# Have a data structure collect the generated kernels
operator_kernels = {}
from dune.perftool.generation import namedata_context
import functools
from dune.perftool.generation import generic_context
namedata_context = functools.partial(generic_context, "namedata")
# Generate the necessary residual methods
for integral in form.integrals():
......@@ -270,6 +272,7 @@ def generate_localoperator_file(kernels):
# Write the file!
from dune.perftool.file import generate_file
param = cgen_class_from_cache("parameterclass")
# TODO take the name of this thing from the UFL file
lop = cgen_class_from_cache("operator", members=operator_methods)
generate_file(get_option("operator_file"), "operatorfile", [lop])
generate_file(get_option("operator_file"), "operatorfile", [param, lop])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment