Skip to content
Snippets Groups Projects
Commit 71391900 authored by René Heß's avatar René Heß
Browse files

Create sections in driver block

parent c923bcd2
No related branches found
No related tags found
No related merge requests found
......@@ -312,11 +312,6 @@ def generate_driver():
contents = []
add_section("grid", "Setup grid (view)...")
add_section("fem", "Set up finite element maps...")
add_section("gfs", "Set up grid function spaces...")
add_section("constraints", "Set up constraints container...")
add_section("gridoperator", "Set up grid grid operators...")
add_section("vector", "Set up solution vectors...")
add_section("driverblock", "Set up driver block...")
add_section("timings", "Maybe take performance measurements...")
add_section("solver", "Set up (non)linear solvers...")
......@@ -352,8 +347,17 @@ def generate_driver():
driver = FunctionBody(driver_signature, driver_body)
# Generate driver block
constructor_preamble_order = ["grid",
"fem",
"gridfunction",
"gfs",
"constraints",
"gridoperator",
"vector",
"solver",
"instat"]
from dune.codegen.pdelab.localoperator import cgen_class_from_cache
driver_block = cgen_class_from_cache("driver_block")
driver_block = cgen_class_from_cache("driver_block", constructor_preamble_order=constructor_preamble_order)
filename = get_option("driver_file")
......
......@@ -101,7 +101,7 @@ def declare_composite_grid_function(identifier, name, children, root):
return "std::shared_ptr<{}> {};".format(composite_gfs_type, name)
@preamble(section="vector", kernel="driver_block")
@preamble(section="gridfunction", kernel="driver_block")
def define_composite_grid_function(identifier, name, children, root=True):
declare_composite_grid_function(identifier, name, children, root)
composite_gfs_type = type_composite_grid_function(identifier, children, root)
......@@ -159,7 +159,7 @@ def declare_function(name, boolean):
return "std::shared_ptr<{}> {};".format(function_type, name)
@preamble(section="vector", kernel="driver_block")
@preamble(section="gridfunction", kernel="driver_block")
def define_function(name, func, boolean):
declare_function(name, boolean)
function_type = type_function(boolean)
......@@ -224,7 +224,7 @@ def declare_grid_function(identifier, name, root):
return "std::shared_ptr<{}> {};".format(grid_function_type, name)
@preamble(section="vector", kernel="driver_block")
@preamble(section="gridfunction", kernel="driver_block")
def define_grid_function(identifier, name, func, root=True):
declare_grid_function(identifier, name, root)
gv = name_leafview()
......
......@@ -135,7 +135,7 @@ def local_operator_likwid():
return "{}->register_likwid_timers();".format(lop_name)
@preamble(section="timings")
@preamble(section="timings", kernel="main")
def local_operator_ssc_marks():
lop_name = main_name_localoperator(get_form_ident())
return "{}->dump_ssc_marks();".format(lop_name)
......
......@@ -4,12 +4,12 @@ from dune.codegen.ufl.visitor import UFL2LoopyVisitor
import pymbolic.primitives as prim
@preamble(section="gridoperator")
@preamble(section="gridoperator", kernel="driver_block")
def set_lop_to_starting_time():
from dune.codegen.pdelab.driver import get_form_ident
from dune.codegen.pdelab.driver.gridoperator import name_localoperator
lop = name_localoperator(get_form_ident())
return "{}.setTime(0.0);".format(lop)
return "{}->setTime(0.0);".format(lop)
class DriverUFL2PymbolicVisitor(UFL2LoopyVisitor):
......
......@@ -576,7 +576,7 @@ def generate_kernels_per_integral(integrals):
yield generate_kernel(integrals)
def extract_kernel_from_cache(tag, name, signature, wrap_in_cgen=True, add_timings=True):
def extract_kernel_from_cache(tag, name, signature, wrap_in_cgen=True, add_timings=True, constructor_preamble_order=None):
# Now extract regular loopy kernel components
from dune.codegen.loopy.target import DuneTarget
domains = [i for i in retrieve_cache_items("{} and domain".format(tag))]
......@@ -634,7 +634,22 @@ def extract_kernel_from_cache(tag, name, signature, wrap_in_cgen=True, add_timin
kernel = vectorize_micro_elements(kernel)
# Now add the preambles to the kernel
preambles = [(i, p) for i, p in enumerate(retrieve_cache_items("{} and preamble".format(tag)))]
if constructor_preamble_order:
def add_section(section_tag):
content = []
tagcontents = [i for i in retrieve_cache_items("preamble and {} and {}".format(tag, section_tag))]
if tagcontents:
content.append("// {}".format(section_tag.capitalize()))
content.extend(tagcontents)
content.append("")
return content
preambles = []
for section in constructor_preamble_order:
preambles = preambles + add_section(section)
preambles = [(i, p) for i, p in enumerate(preambles)]
else:
preambles = [(i, p) for i, p in enumerate(retrieve_cache_items("{} and preamble".format(tag)))]
kernel = kernel.copy(preambles=preambles)
# Remove inames that have become obsolete
......@@ -810,7 +825,7 @@ class LoopyKernelMethod(ClassMember):
ClassMember.__init__(self, content, name=kernel.name if kernel is not None else "")
def cgen_class_from_cache(tag, members=[]):
def cgen_class_from_cache(tag, members=[], constructor_preamble_order=None):
from dune.codegen.generation import retrieve_cache_items
# Sort the given member functions by their name to help debugging by fixing
......@@ -827,7 +842,12 @@ def cgen_class_from_cache(tag, members=[]):
tparams = [i for i in retrieve_cache_items('{} and template_param'.format(tag))]
# Construct the constructor
constructor_knl = extract_kernel_from_cache(tag, "constructor_kernel", None, wrap_in_cgen=False, add_timings=False)
constructor_knl = extract_kernel_from_cache(tag,
"constructor_kernel",
None,
wrap_in_cgen=False,
add_timings=False,
constructor_preamble_order=constructor_preamble_order)
from dune.codegen.loopy.target import DuneTarget
constructor_knl = constructor_knl.copy(target=DuneTarget(declare_temporaries=False))
signature = "{}({})".format(basename, ", ".join(next(iter(p.generate(with_semicolon=False))) for p in constructor_params))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment