Skip to content
Snippets Groups Projects
Commit f6b4eae7 authored by Dominic Kempf
Browse files

Use kernel tag in localoperator code generation

parent 38ef3f9f
No related branches found
No related tags found
No related merge requests found
...@@ -44,7 +44,7 @@ class FlipFlopBuffer(object): ...@@ -44,7 +44,7 @@ class FlipFlopBuffer(object):
return name return name
@generator_factory(item_tags=("kernel", "buffer"), cache_key_generator=lambda i, **kw: i) @generator_factory(item_tags=("buffer"), cache_key_generator=lambda i, **kw: i, context_tags=("kernel",))
def initialize_buffer(identifier, base_storage_size=None, num=2): def initialize_buffer(identifier, base_storage_size=None, num=2):
if base_storage_size is None: if base_storage_size is None:
raise PerftoolLoopyError("The buffer for identifier {} has not been initialized.".format(identifier)) raise PerftoolLoopyError("The buffer for identifier {} has not been initialized.".format(identifier))
......
...@@ -10,6 +10,7 @@ from dune.perftool.generation import (backend, ...@@ -10,6 +10,7 @@ from dune.perftool.generation import (backend,
domain, domain,
dump_accumulate_timer, dump_accumulate_timer,
get_backend, get_backend,
get_global_context_value,
global_context, global_context,
iname, iname,
include_file, include_file,
...@@ -474,16 +475,21 @@ def generate_kernel(integrals): ...@@ -474,16 +475,21 @@ def generate_kernel(integrals):
visitor = UFL2LoopyVisitor(interface, measure, indexmap) visitor = UFL2LoopyVisitor(interface, measure, indexmap)
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
tag = get_global_context_value("kernel")
return extract_kernel_from_cache(tag)
def extract_kernel_from_cache(tag):
# Extract the information, which is needed to create a loopy kernel. # Extract the information, which is needed to create a loopy kernel.
# First extracting it, might be useful to alter it before kernel generation. # First extracting it, might be useful to alter it before kernel generation.
from dune.perftool.generation import retrieve_cache_functions, retrieve_cache_items from dune.perftool.generation import retrieve_cache_functions, retrieve_cache_items
from dune.perftool.loopy.target import DuneTarget from dune.perftool.loopy.target import DuneTarget
domains = [i for i in retrieve_cache_items("domain")] domains = [i for i in retrieve_cache_items("{} and domain".format(tag))]
instructions = [i for i in retrieve_cache_items("instruction")] instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))]
temporaries = {i.name: i for i in retrieve_cache_items("temporary")} temporaries = {i.name: i for i in retrieve_cache_items("{} and temporary".format(tag))}
arguments = [i for i in retrieve_cache_items("argument")] arguments = [i for i in retrieve_cache_items("{} and argument".format(tag))]
silenced = [l for l in retrieve_cache_items("silenced_warning")] silenced = [l for l in retrieve_cache_items("{} and silenced_warning".format(tag))]
transformations = [t for t in retrieve_cache_items("transformation")] transformations = [t for t in retrieve_cache_items("{} and transformation".format(tag))]
# Construct an options object # Construct an options object
from loopy import Options from loopy import Options
...@@ -524,7 +530,7 @@ def generate_kernel(integrals): ...@@ -524,7 +530,7 @@ def generate_kernel(integrals):
raise NotImplementedError("Only vectorizing sumfactoized code right now!") raise NotImplementedError("Only vectorizing sumfactoized code right now!")
# Now add the preambles to the kernel # Now add the preambles to the kernel
preambles = [(i, p) for i, p in enumerate(retrieve_cache_items("preamble"))] preambles = [(i, p) for i, p in enumerate(retrieve_cache_items("{} and preamble".format(tag)))]
kernel = kernel.copy(preambles=preambles) kernel = kernel.copy(preambles=preambles)
# Do the loopy preprocessing! # Do the loopy preprocessing!
...@@ -532,7 +538,7 @@ def generate_kernel(integrals): ...@@ -532,7 +538,7 @@ def generate_kernel(integrals):
# All items with the kernel tags can be destroyed once a kernel has been generated # All items with the kernel tags can be destroyed once a kernel has been generated
from dune.perftool.generation import delete_cache_items from dune.perftool.generation import delete_cache_items
delete_cache_items("(not file) and (not clazz)") delete_cache_items(tag)
return kernel return kernel
...@@ -682,7 +688,8 @@ def generate_localoperator_kernels(formdata, data): ...@@ -682,7 +688,8 @@ def generate_localoperator_kernels(formdata, data):
enum_pattern() enum_pattern()
pattern_baseclass() pattern_baseclass()
enum_alpha() enum_alpha()
kernel = generate_kernel(form.integrals_by_type(measure)) with global_context(kernel=assembler_routine_name()):
kernel = generate_kernel(form.integrals_by_type(measure))
# Maybe add numerical differentiation # Maybe add numerical differentiation
if get_option("numerical_jacobian"): if get_option("numerical_jacobian"):
...@@ -737,7 +744,8 @@ def generate_localoperator_kernels(formdata, data): ...@@ -737,7 +744,8 @@ def generate_localoperator_kernels(formdata, data):
with global_context(form_type="jacobian"): with global_context(form_type="jacobian"):
for measure in set(i.integral_type() for i in jacform.integrals()): for measure in set(i.integral_type() for i in jacform.integrals()):
with global_context(integral_type=measure): with global_context(integral_type=measure):
kernel = generate_kernel(jacform.integrals_by_type(measure)) with global_context(kernel=assembler_routine_name()):
kernel = generate_kernel(jacform.integrals_by_type(measure))
operator_kernels[(measure, 'jacobian')] = kernel operator_kernels[(measure, 'jacobian')] = kernel
# Generate dummy functions for those kernels, that vanished in the differentiation process # Generate dummy functions for those kernels, that vanished in the differentiation process
...@@ -762,7 +770,8 @@ def generate_localoperator_kernels(formdata, data): ...@@ -762,7 +770,8 @@ def generate_localoperator_kernels(formdata, data):
with global_context(form_type="jacobian_apply"): with global_context(form_type="jacobian_apply"):
for measure in set(i.integral_type() for i in jac_apply_form.integrals()): for measure in set(i.integral_type() for i in jac_apply_form.integrals()):
with global_context(integral_type=measure): with global_context(integral_type=measure):
kernel = generate_kernel(jac_apply_form.integrals_by_type(measure)) with global_context(kernel=assembler_routine_name()):
kernel = generate_kernel(jac_apply_form.integrals_by_type(measure))
operator_kernels[(measure, 'jacobian_apply')] = kernel operator_kernels[(measure, 'jacobian_apply')] = kernel
# Generate dummy functions for those kernels, that vanished in the differentiation process # Generate dummy functions for those kernels, that vanished in the differentiation process
......
...@@ -99,7 +99,7 @@ def name_leaf_lfs(leaf_element, restriction, val=None): ...@@ -99,7 +99,7 @@ def name_leaf_lfs(leaf_element, restriction, val=None):
return val return val
@generator_factory(cache_key_generator=lambda e, r, c, **kw: (e, r, c)) @generator_factory(cache_key_generator=lambda e, r, c, **kw: (e, r, c), context_tags=("kernel",))
def name_lfs(element, restriction, component, prefix=None): def name_lfs(element, restriction, component, prefix=None):
# Omitting the prefix is only valid upon a second call, which will # Omitting the prefix is only valid upon a second call, which will
# result in a cache hit. # result in a cache hit.
...@@ -178,7 +178,7 @@ def traverse_lfs_tree(arg): ...@@ -178,7 +178,7 @@ def traverse_lfs_tree(arg):
type_gfs(arg.argexpr.ufl_element(), basetype=gfs_basename, index_stack=()) type_gfs(arg.argexpr.ufl_element(), basetype=gfs_basename, index_stack=())
@generator_factory(item_tags=("iname",), cache_key_generator=lambda e, r, c: (e, c)) @generator_factory(item_tags=("iname",), cache_key_generator=lambda e, r, c: (e, c), context_tags=("kernel",))
def _lfs_iname(element, restriction, context): def _lfs_iname(element, restriction, context):
lfs = name_leaf_lfs(element, restriction) lfs = name_leaf_lfs(element, restriction)
bound = name_lfs_bound(lfs) bound = name_lfs_bound(lfs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment