From 2b043378f7c675a66d5f6666608b7daf61bb811a Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Mon, 26 Mar 2018 14:26:12 +0200 Subject: [PATCH] Move function_name generation onto symbolic representation --- python/dune/perftool/sumfact/realization.py | 49 +++++---------------- python/dune/perftool/sumfact/symbolic.py | 31 ++++++++----- 2 files changed, 31 insertions(+), 49 deletions(-) diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 4ea2e8a5..703f9a6b 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -29,6 +29,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy, permute_backward, permute_forward, ) +from dune.perftool.sumfact.quadrature import quadrature_points_per_direction from dune.perftool.sumfact.symbolic import (get_input_output_tuple, SumfactKernel, VectorizedSumfactKernel, @@ -47,39 +48,9 @@ import numpy as np import pymbolic.primitives as prim -necessary_kernel_implementations = generator_factory(item_tags=("kernelimpl",), no_deco=True) - - -@generator_factory(cache_key_generator=lambda s, qp: (s.function_key, qp)) -def _name_kernel_implementation_function(sf, qp): - name = "sfimpl_{}".format("_".join(str(m) for m in sf.matrix_sequence)) - if get_form_option("fastdg"): - if sf.stage == 1: - if isinstance(sf, SumfactKernel): - fastdg = "{}comp{}".format(FEM_name_mangling(sf.input.element), sf.input.element_index) - if isinstance(sf, VectorizedSumfactKernel): - fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(sf.input.inputs)) - if sf.stage == 3: - if isinstance(sf, SumfactKernel): - fastdg = "{}comp{}".format(FEM_name_mangling(sf.output.test_element), sf.output.test_element_index) - if sf.within_inames: - fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(sf.output.trial_element), sf.output.trial_element_index) - if isinstance(sf, VectorizedSumfactKernel): - fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(sf.output.outputs)) - if sf.within_inames: - fastdg = "{}x{}".format(fastdg, - "_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(sf.output.outputs)) - ) - - name = "{}_fastdg{}_{}".format(name, sf.stage, fastdg) - necessary_kernel_implementations((sf, qp)) - return name - - -def name_kernel_implementation_function(sf): - from dune.perftool.sumfact.quadrature import quadrature_points_per_direction - qp = quadrature_points_per_direction() - return _name_kernel_implementation_function(sf, qp) +# Have a generator function store the necessary sum factorization kernel implementations +# This way then can easily be extracted at the end of the form visiting process +necessary_kernel_implementations = generator_factory(item_tags=("kernelimpl",), cache_key_generator=lambda a: a[0].function_name, no_deco=True) def realize_sum_factorization_kernel(sf, **kwargs): @@ -125,7 +96,6 @@ def _realize_sum_factorization_kernel(sf): insn_dep = insn_dep.union(timer_dep) # Get all the necessary pieces for a function call - funcname = name_kernel_implementation_function(sf) buffers = tuple(name_buffer_storage(sf.buffer, i) for i in range(2)) # Make sure that the storage is allocated and has a certain minimum size @@ -153,8 +123,12 @@ def _realize_sum_factorization_kernel(sf): if sf.stage == 3: fastdg_args = sf.output.fastdg_args + # Trigger generation of the sum factorization kernel function + qp = quadrature_points_per_direction() + necessary_kernel_implementations((sf, qp)) + # Call the function - code = "{}({});".format(funcname, ", ".join(buffers + fastdg_args)) + code = "{}({});".format(sf.function_name, ", ".join(buffers + fastdg_args)) tag = "sumfact_stage{}".format(sf.stage) insn_dep = frozenset({instruction(code=code, depends_on=insn_dep, @@ -334,7 +308,6 @@ def realize_sumfact_kernel_function(sf): }) # Construct a loopy kernel object - name = name_kernel_implementation_function(sf) from dune.perftool.pdelab.localoperator import extract_kernel_from_cache args = ["const char* buffer0", "const char* buffer1"] if get_form_option('fastdg'): @@ -344,7 +317,7 @@ def realize_sumfact_kernel_function(sf): if sf.within_inames: args.append("unsigned int jacobian_offset{}".format(i)) - signature = "void {}({}) const".format(name, ", ".join(args)) - kernel = extract_kernel_from_cache("kernel_default", name, [signature], add_timings=False) + signature = "void {}({}) const".format(sf.function_name, ", ".join(args)) + kernel = extract_kernel_from_cache("kernel_default", sf.function_name, [signature], add_timings=False) delete_cache_items("kernel_default") return kernel diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index c3be20dd..5ef99849 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -5,6 +5,7 @@ from dune.perftool.generation import (get_counted_variable, subst_rule, transform, ) +from dune.perftool.pdelab.driver import FEM_name_mangling from dune.perftool.pdelab.geometry import local_dimension, world_dimension from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray @@ -288,15 +289,18 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): # Watch out for the documentation to see which key is used unter what circumstances # @property - def function_key(self): - """ Kernels sharing this key may use the same kernel implementation function """ - fastdg = () + def function_name(self): + """ The name of the function that implements this kernel """ + name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence)) if get_form_option("fastdg"): if self.stage == 1: - fastdg = (self.input.element, self.input.element_index) + fastdg = "{}comp{}".format(FEM_name_mangling(self.input.element), self.input.element_index) if self.stage == 3: - fastdg = (self.output.test_element, self.output.test_element_index, self.output.trial_element, self.output.trial_element_index) - return tuple(str(m) for m in self.matrix_sequence) + fastdg + fastdg = "{}comp{}".format(FEM_name_mangling(self.output.test_element), self.output.test_element_index) + if self.within_inames: + fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(self.output.trial_element), self.output.trial_element_index) + name = "{}_fastdg{}_{}".format(name, self.stage, fastdg) + return name @property def parallel_key(self): @@ -556,14 +560,19 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) # Watch out for the documentation to see which key is used unter what circumstances # @property - def function_key(self): - fastdg = () + def function_name(self): + name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence)) if get_form_option("fastdg"): if self.stage == 1: - fastdg = sum(((i.element, i.element_index) for i in remove_duplicates(self.input.inputs)), ()) + fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(self.input.inputs)) if self.stage == 3: - fastdg = sum(((o.test_element, o.test_element_index, o.trial_element, o.trial_element_index) for o in remove_duplicates(self.output.outputs)), ()) - return tuple(str(m) for m in self.matrix_sequence) + fastdg + fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(self.output.outputs)) + if self.within_inames: + fastdg = "{}x{}".format(fastdg, + "_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(self.output.outputs)) + ) + name = "{}_fastdg{}_{}".format(name, self.stage, fastdg) + return name @property def cache_key(self): -- GitLab