Skip to content
Snippets Groups Projects
Commit bca56732 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

write generated kernels into the operator file

parent 8bf01d3c
No related branches found
No related tags found
No related merge requests found
......@@ -35,9 +35,10 @@ class ClassMember(Generable):
assert isinstance(member, Generable)
def generate(self):
yield "\n\n"
yield "{}:\n".format(access_modifier_string(self.access))
for line in self.member.generate():
yield line
yield line + '\n'
class Constructor(Generable):
def __init__(self, block=Block([]), arg_decls=[], clsname=None, access=AccessModifier.PUBLIC):
......@@ -82,15 +83,6 @@ class Class(Generable):
for con in constructors:
assert isinstance(con, Constructor)
def access_modifier(self, am):
if am == AccessModifier.PRIVATE:
return "private"
if am == AccessModifier.PUBLIC:
return "public"
if am == AccessModifier.PROTECTED:
return "protected"
raise ValueError("Unknown access modifier in class generation")
def generate(self):
# define the class header
from cgen import Value
......@@ -109,8 +101,6 @@ class Class(Generable):
# add base class inheritance
yield ",\n".join(" : {} {}\n".format(access_modifier_string(bc.inheritance), bc.name) for bc in self.base_classes)
# Set the
# Now yield the entire block
block = Block(contents=self.constructors + self.members)
......
......@@ -110,17 +110,23 @@ def generate_term(integrand=None, measure=None):
kernel = preprocess_kernel(kernel)
# Return the actual code (might instead return kernels...)
from loopy import generate_code
return str(generate_code(kernel)[0])
return kernel
from dune.perftool.cgen.clazz import ClassMember
class AssemblyMethod(ClassMember):
def __init__(self, signature, kernel):
from loopy import generate_code
from cgen import LiteralLines
content = LiteralLines('\n'+'\n'.join(signature) + '\n' + generate_code(kernel)[0])
ClassMember.__init__(self, content)
def cgen_class_from_cache(name, tag):
def cgen_class_from_cache(name, tag, members=[]):
from dune.perftool.generation import retrieve_cache_items
base_classes = [bc for bc in retrieve_cache_items(tags=(tag, "baseclass"), union=False)]
from dune.perftool.cgen import Class
return Class(name, base_classes=base_classes)
return Class(name, base_classes=base_classes, members=members)
def generate_localoperator(form, operatorfile):
......@@ -137,9 +143,9 @@ def generate_localoperator(form, operatorfile):
# Generate the necessary residual methods
for integral in form.integrals():
body = generate_term(integrand=integral.integrand(), measure=integral.integral_type())
kernel = generate_term(integrand=integral.integrand(), measure=integral.integral_type())
signature = measure_specific_details(integral.integral_type())["residual_signature"]
operator_methods.append((signature, body))
operator_methods.append(AssemblyMethod(signature, kernel))
# Generate the necessary jacobian methods
from dune.perftool.options import get_option
......@@ -151,9 +157,9 @@ def generate_localoperator(form, operatorfile):
jacform = expand_derivatives(derivative(form, form.coefficients()[0]))
for integral in jacform.integrals():
body = generate_term(integrand=integral.integrand(), measure=integral.integral_type())
kernel = generate_term(integrand=integral.integrand(), measure=integral.integral_type())
signature = measure_specific_details(integral.integral_type())["jacobian_signature"]
operator_methods.append((signature, body))
operator_methods.append(AssemblyMethod(signature, kernel))
# TODO: JacobianApply for matrix-free computations.
......@@ -170,5 +176,5 @@ def generate_localoperator(form, operatorfile):
from dune.perftool.file import generate_file
# TODO take the name of this thing from the UFL file
lop = cgen_class_from_cache("LocalOperator", "operator")
lop = cgen_class_from_cache("LocalOperator", "operator", members=operator_methods)
generate_file(get_option("operator_file"), "operator", [lop])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment