Skip to content
Snippets Groups Projects
Commit 2b043378 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Move function_name generation onto symbolic representation

parent 61adf2f4
No related branches found
No related tags found
No related merge requests found
...@@ -29,6 +29,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy, ...@@ -29,6 +29,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
permute_backward, permute_backward,
permute_forward, permute_forward,
) )
from dune.perftool.sumfact.quadrature import quadrature_points_per_direction
from dune.perftool.sumfact.symbolic import (get_input_output_tuple, from dune.perftool.sumfact.symbolic import (get_input_output_tuple,
SumfactKernel, SumfactKernel,
VectorizedSumfactKernel, VectorizedSumfactKernel,
...@@ -47,39 +48,9 @@ import numpy as np ...@@ -47,39 +48,9 @@ import numpy as np
import pymbolic.primitives as prim import pymbolic.primitives as prim
necessary_kernel_implementations = generator_factory(item_tags=("kernelimpl",), no_deco=True) # Have a generator function store the necessary sum factorization kernel implementations
# This way then can easily be extracted at the end of the form visiting process
necessary_kernel_implementations = generator_factory(item_tags=("kernelimpl",), cache_key_generator=lambda a: a[0].function_name, no_deco=True)
@generator_factory(cache_key_generator=lambda s, qp: (s.function_key, qp))
def _name_kernel_implementation_function(sf, qp):
name = "sfimpl_{}".format("_".join(str(m) for m in sf.matrix_sequence))
if get_form_option("fastdg"):
if sf.stage == 1:
if isinstance(sf, SumfactKernel):
fastdg = "{}comp{}".format(FEM_name_mangling(sf.input.element), sf.input.element_index)
if isinstance(sf, VectorizedSumfactKernel):
fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(sf.input.inputs))
if sf.stage == 3:
if isinstance(sf, SumfactKernel):
fastdg = "{}comp{}".format(FEM_name_mangling(sf.output.test_element), sf.output.test_element_index)
if sf.within_inames:
fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(sf.output.trial_element), sf.output.trial_element_index)
if isinstance(sf, VectorizedSumfactKernel):
fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(sf.output.outputs))
if sf.within_inames:
fastdg = "{}x{}".format(fastdg,
"_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(sf.output.outputs))
)
name = "{}_fastdg{}_{}".format(name, sf.stage, fastdg)
necessary_kernel_implementations((sf, qp))
return name
def name_kernel_implementation_function(sf):
from dune.perftool.sumfact.quadrature import quadrature_points_per_direction
qp = quadrature_points_per_direction()
return _name_kernel_implementation_function(sf, qp)
def realize_sum_factorization_kernel(sf, **kwargs): def realize_sum_factorization_kernel(sf, **kwargs):
...@@ -125,7 +96,6 @@ def _realize_sum_factorization_kernel(sf): ...@@ -125,7 +96,6 @@ def _realize_sum_factorization_kernel(sf):
insn_dep = insn_dep.union(timer_dep) insn_dep = insn_dep.union(timer_dep)
# Get all the necessary pieces for a function call # Get all the necessary pieces for a function call
funcname = name_kernel_implementation_function(sf)
buffers = tuple(name_buffer_storage(sf.buffer, i) for i in range(2)) buffers = tuple(name_buffer_storage(sf.buffer, i) for i in range(2))
# Make sure that the storage is allocated and has a certain minimum size # Make sure that the storage is allocated and has a certain minimum size
...@@ -153,8 +123,12 @@ def _realize_sum_factorization_kernel(sf): ...@@ -153,8 +123,12 @@ def _realize_sum_factorization_kernel(sf):
if sf.stage == 3: if sf.stage == 3:
fastdg_args = sf.output.fastdg_args fastdg_args = sf.output.fastdg_args
# Trigger generation of the sum factorization kernel function
qp = quadrature_points_per_direction()
necessary_kernel_implementations((sf, qp))
# Call the function # Call the function
code = "{}({});".format(funcname, ", ".join(buffers + fastdg_args)) code = "{}({});".format(sf.function_name, ", ".join(buffers + fastdg_args))
tag = "sumfact_stage{}".format(sf.stage) tag = "sumfact_stage{}".format(sf.stage)
insn_dep = frozenset({instruction(code=code, insn_dep = frozenset({instruction(code=code,
depends_on=insn_dep, depends_on=insn_dep,
...@@ -334,7 +308,6 @@ def realize_sumfact_kernel_function(sf): ...@@ -334,7 +308,6 @@ def realize_sumfact_kernel_function(sf):
}) })
# Construct a loopy kernel object # Construct a loopy kernel object
name = name_kernel_implementation_function(sf)
from dune.perftool.pdelab.localoperator import extract_kernel_from_cache from dune.perftool.pdelab.localoperator import extract_kernel_from_cache
args = ["const char* buffer0", "const char* buffer1"] args = ["const char* buffer0", "const char* buffer1"]
if get_form_option('fastdg'): if get_form_option('fastdg'):
...@@ -344,7 +317,7 @@ def realize_sumfact_kernel_function(sf): ...@@ -344,7 +317,7 @@ def realize_sumfact_kernel_function(sf):
if sf.within_inames: if sf.within_inames:
args.append("unsigned int jacobian_offset{}".format(i)) args.append("unsigned int jacobian_offset{}".format(i))
signature = "void {}({}) const".format(name, ", ".join(args)) signature = "void {}({}) const".format(sf.function_name, ", ".join(args))
kernel = extract_kernel_from_cache("kernel_default", name, [signature], add_timings=False) kernel = extract_kernel_from_cache("kernel_default", sf.function_name, [signature], add_timings=False)
delete_cache_items("kernel_default") delete_cache_items("kernel_default")
return kernel return kernel
...@@ -5,6 +5,7 @@ from dune.perftool.generation import (get_counted_variable, ...@@ -5,6 +5,7 @@ from dune.perftool.generation import (get_counted_variable,
subst_rule, subst_rule,
transform, transform,
) )
from dune.perftool.pdelab.driver import FEM_name_mangling
from dune.perftool.pdelab.geometry import local_dimension, world_dimension from dune.perftool.pdelab.geometry import local_dimension, world_dimension
from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
...@@ -288,15 +289,18 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -288,15 +289,18 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
# Watch out for the documentation to see which key is used unter what circumstances # Watch out for the documentation to see which key is used unter what circumstances
# #
@property @property
def function_key(self): def function_name(self):
""" Kernels sharing this key may use the same kernel implementation function """ """ The name of the function that implements this kernel """
fastdg = () name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence))
if get_form_option("fastdg"): if get_form_option("fastdg"):
if self.stage == 1: if self.stage == 1:
fastdg = (self.input.element, self.input.element_index) fastdg = "{}comp{}".format(FEM_name_mangling(self.input.element), self.input.element_index)
if self.stage == 3: if self.stage == 3:
fastdg = (self.output.test_element, self.output.test_element_index, self.output.trial_element, self.output.trial_element_index) fastdg = "{}comp{}".format(FEM_name_mangling(self.output.test_element), self.output.test_element_index)
return tuple(str(m) for m in self.matrix_sequence) + fastdg if self.within_inames:
fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(self.output.trial_element), self.output.trial_element_index)
name = "{}_fastdg{}_{}".format(name, self.stage, fastdg)
return name
@property @property
def parallel_key(self): def parallel_key(self):
...@@ -556,14 +560,19 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -556,14 +560,19 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
# Watch out for the documentation to see which key is used unter what circumstances # Watch out for the documentation to see which key is used unter what circumstances
# #
@property @property
def function_key(self): def function_name(self):
fastdg = () name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence))
if get_form_option("fastdg"): if get_form_option("fastdg"):
if self.stage == 1: if self.stage == 1:
fastdg = sum(((i.element, i.element_index) for i in remove_duplicates(self.input.inputs)), ()) fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(self.input.inputs))
if self.stage == 3: if self.stage == 3:
fastdg = sum(((o.test_element, o.test_element_index, o.trial_element, o.trial_element_index) for o in remove_duplicates(self.output.outputs)), ()) fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(self.output.outputs))
return tuple(str(m) for m in self.matrix_sequence) + fastdg if self.within_inames:
fastdg = "{}x{}".format(fastdg,
"_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(self.output.outputs))
)
name = "{}_fastdg{}_{}".format(name, self.stage, fastdg)
return name
@property @property
def cache_key(self): def cache_key(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment