From 08d79cb291efe8fad477f0d0b9c614e98663442c Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 27 Mar 2018 10:25:19 +0200 Subject: [PATCH] Unify treatment of input (stage 1) and output (stage 3) I was tired of the amount of ifs that changed behaviour depending on stage 1 or 3. This is a cleaner approach. --- python/dune/perftool/sumfact/accumulation.py | 41 +++- python/dune/perftool/sumfact/basis.py | 32 ++- python/dune/perftool/sumfact/geometry.py | 5 +- python/dune/perftool/sumfact/realization.py | 32 +-- python/dune/perftool/sumfact/symbolic.py | 229 ++++++++++--------- 5 files changed, 187 insertions(+), 152 deletions(-) diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py index d576562e..255d0909 100644 --- a/python/dune/perftool/sumfact/accumulation.py +++ b/python/dune/perftool/sumfact/accumulation.py @@ -24,6 +24,7 @@ from dune.perftool.options import (get_form_option, ) from dune.perftool.loopy.flatten import flatten_index from dune.perftool.sumfact.quadrature import nest_quadrature_loops +from dune.perftool.pdelab.driver import FEM_name_mangling from dune.perftool.pdelab.localoperator import determine_accumulation_space from dune.perftool.pdelab.restriction import restricted_name from dune.perftool.pdelab.signatures import assembler_routine_name @@ -35,7 +36,7 @@ from dune.perftool.sumfact.tabulation import (basis_functions_per_direction, from dune.perftool.sumfact.switch import (get_facedir, get_facemod, ) -from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelOutputBase +from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase from dune.perftool.ufl.modified_terminals import extract_modified_arguments from dune.perftool.tools import get_pymbolic_basename, get_leaf from dune.perftool.error import PerftoolError @@ -84,7 +85,7 @@ def accum_iname(element, bound, i): return sumfact_iname(bound, "accum{}".format(suffix)) -class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): +class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): def __init__(self, accumvar=None, restriction=None, @@ -106,6 +107,14 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): def __repr__(self): return ImmutableRecord.__repr__(self) + @property + def stage(self): + return 3 + + @property + def direct_is_possible(self): + return get_form_option("fastdg") + @property def within_inames(self): if self.trial_element is None: @@ -202,7 +211,17 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): **args)}) @property - def fastdg_args(self): + def function_name_suffix(self): + if get_form_option("fastdg"): + suffix = "_fastdg1_{}comp{}".format(FEM_name_mangling(self.test_element), self.test_element_index) + if self.within_inames: + suffix = "{}x{}comp{}".format(suffix, FEM_name_mangling(self.trial_element), self.trial_element_index) + return suffix + else: + return "" + + @property + def function_args(self): if get_form_option("fastdg"): ret = ("{}.data()".format(self.accumvar),) if get_form_option("fastdg") and self.within_inames: @@ -214,6 +233,17 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): else: return () + @property + def signature_args(self): + if get_form_option('fastdg'): + ret = ("double* fastdg0",) + if self.within_inames: + ret = ret + ("unsigned int jacobian_offset0",) + return ret + else: + return () + + def _local_sizes(element): from ufl import FiniteElement, MixedElement @@ -411,9 +441,8 @@ def generate_accumulation_instruction(expr, visitor): ) sf = SumfactKernel(matrix_sequence=matrix_sequence, - stage=3, position_priority=priority, - output=output, + interface=output, predicates=predicates, ) @@ -497,7 +526,7 @@ def generate_accumulation_instruction(expr, visitor): result, insn_dep = realize_sum_factorization_kernel(vsf.copy(insn_dep=vsf.insn_dep.union(insn_dep))) if not get_form_option("fastdg"): - insn_dep = vsf.output.realize(vsf, result, insn_dep) + insn_dep = vsf.interface.realize(vsf, result, insn_dep) if get_option("instrumentation_level") >= 4: assert vsf.stage == 3 diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index ffb2eef9..3c48f419 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -32,7 +32,7 @@ from dune.perftool.pdelab.argument import name_coefficientcontainer from dune.perftool.pdelab.geometry import (local_dimension, world_dimension, ) -from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase +from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase from dune.perftool.options import get_form_option from dune.perftool.pdelab.driver import FEM_name_mangling from dune.perftool.pdelab.restriction import restricted_name @@ -50,7 +50,7 @@ from loopy.match import Writes import pymbolic.primitives as prim -class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): +class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): def __init__(self, coeff_func=None, element=None, @@ -71,7 +71,11 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): return repr(self) @property - def direct_input_is_possible(self): + def stage(self): + return 1 + + @property + def direct_is_possible(self): return get_form_option("fastdg") def realize(self, sf, insn_dep, index=0): @@ -113,13 +117,27 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): return prim.Subscript(prim.Variable(arg), inames) @property - def fastdg_args(self): - if self.direct_input_is_possible: + def function_name_suffix(self): + if get_form_option("fastdg"): + return "_fastdg1_{}comp{}".format(FEM_name_mangling(self.element), self.element_index) + else: + return "" + + @property + def function_args(self): + if get_form_option("fastdg"): func = self.coeff_func(self.restriction) return ("{}.data()".format(func),) else: return () + @property + def signature_args(self): + if get_form_option("fastdg"): + return ("const double* fastdg0",) + else: + return () + def _basis_functions_per_direction(element): """Number of basis functions per direction """ @@ -166,7 +184,7 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit # The sum factorization kernel object gathering all relevant information sf = SumfactKernel(matrix_sequence=matrix_sequence, position_priority=grad_index, - input=inp, + interface=inp, ) from dune.perftool.sumfact.vectorization import attach_vectorization_info @@ -207,7 +225,7 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor): ) sf = SumfactKernel(matrix_sequence=matrix_sequence, - input=inp, + interface=inp, position_priority=3, ) diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py index d0e96c00..7b78de41 100644 --- a/python/dune/perftool/sumfact/geometry.py +++ b/python/dune/perftool/sumfact/geometry.py @@ -17,7 +17,7 @@ from dune.perftool.pdelab.geometry import (local_dimension, name_geometry, ) from dune.perftool.sumfact.switch import get_facedir -from dune.perftool.sumfact.symbolic import SumfactKernelInputBase +from dune.perftool.sumfact.symbolic import SumfactKernelInterfaceBase from dune.perftool.sumfact.vectorization import attach_vectorization_info from dune.perftool.options import get_form_option, option_switch from dune.perftool.ufl.modified_terminals import Restriction @@ -35,7 +35,7 @@ def corner_iname(): return name -class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord): +class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): def __init__(self, dir): ImmutableRecord.__init__(self, dir=dir) @@ -45,7 +45,6 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord): temporary_variable(name, shape=(2 ** local_dimension(), sf.vector_width), custom_base_storage=name_buffer_storage(sf.buffer, 0), - decl_method=buffer_decl(storage, get_sumfact_dtype(sf)), managed=True, ) diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 703f9a6b..55e14108 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -19,7 +19,6 @@ from dune.perftool.generation import (barrier, from dune.perftool.loopy.flatten import flatten_index from dune.perftool.pdelab.argument import pymbolic_coefficient from dune.perftool.pdelab.basis import shape_as_pymbolic -from dune.perftool.pdelab.driver import FEM_name_mangling from dune.perftool.pdelab.geometry import world_dimension from dune.perftool.options import (get_form_option, get_option, @@ -30,8 +29,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy, permute_forward, ) from dune.perftool.sumfact.quadrature import quadrature_points_per_direction -from dune.perftool.sumfact.symbolic import (get_input_output_tuple, - SumfactKernel, +from dune.perftool.sumfact.symbolic import (SumfactKernel, VectorizedSumfactKernel, ) from dune.perftool.sumfact.vectorization import attach_vectorization_info @@ -113,22 +111,15 @@ def _realize_sum_factorization_kernel(sf): ) # Realize the input if it is not direct - if not sf.input.direct_input_is_possible: - insn_dep = insn_dep.union(sf.input.realize(sf, insn_dep)) - - # Collect function call arguments - fastdg_args = () - if sf.stage == 1: - fastdg_args = sf.input.fastdg_args - if sf.stage == 3: - fastdg_args = sf.output.fastdg_args + if sf.stage == 1 and not sf.interface.direct_is_possible: + insn_dep = insn_dep.union(sf.interface.realize(sf, insn_dep)) # Trigger generation of the sum factorization kernel function qp = quadrature_points_per_direction() necessary_kernel_implementations((sf, qp)) # Call the function - code = "{}({});".format(sf.function_name, ", ".join(buffers + fastdg_args)) + code = "{}({});".format(sf.function_name, ", ".join(buffers + sf.interface.function_args)) tag = "sumfact_stage{}".format(sf.stage) insn_dep = frozenset({instruction(code=code, depends_on=insn_dep, @@ -227,12 +218,12 @@ def realize_sumfact_kernel_function(sf): # * a global data structure (if FastDGGridOperator is in use) # * a value from a global data structure, broadcasted to a vector type (vectorized + FastDGGridOperator) input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:]) - if l == 0 and sf.input.direct_input_is_possible: + if l == 0 and sf.stage == 1 and sf.interface.direct_is_possible: # See comment below input_inames = permute_backward(input_inames, perm) inp_shape = permute_backward(inp_shape, perm) - input_summand = sf.input.realize_direct(inp_shape, input_inames) + input_summand = sf.interface.realize_direct(inp_shape, input_inames) else: # If we did permute the order of a matrices above we also # permuted the order of out_inames. Unfortunately the @@ -297,7 +288,7 @@ def realize_sumfact_kernel_function(sf): if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3: if sf.vectorized: insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name}) - insn_dep = sf.output.realize_direct(matprod, output_inames, out_shape, **insn_args) + insn_dep = sf.interface.realize_direct(matprod, output_inames, out_shape, **insn_args) else: # Issue the reduction instruction that implements the multiplication # at the same time store the instruction ID for the next instruction to depend on @@ -309,14 +300,7 @@ def realize_sumfact_kernel_function(sf): # Construct a loopy kernel object from dune.perftool.pdelab.localoperator import extract_kernel_from_cache - args = ["const char* buffer0", "const char* buffer1"] - if get_form_option('fastdg'): - const = "const " if sf.stage == 1 else "" - for i in range(len(get_input_output_tuple(sf))): - args.append("{}double* fastdg{}".format(const, i)) - if sf.within_inames: - args.append("unsigned int jacobian_offset{}".format(i)) - + args = ("const char* buffer0", "const char* buffer1") + sf.interface.signature_args signature = "void {}({}) const".format(sf.function_name, ", ".join(args)) kernel = extract_kernel_from_cache("kernel_default", sf.function_name, [signature], add_timings=False) delete_cache_items("kernel_default") diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 5ef99849..5f979cd7 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -5,7 +5,6 @@ from dune.perftool.generation import (get_counted_variable, subst_rule, transform, ) -from dune.perftool.pdelab.driver import FEM_name_mangling from dune.perftool.pdelab.geometry import local_dimension, world_dimension from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray @@ -23,58 +22,85 @@ import frozendict import inspect -class SumfactKernelInputBase(object): +class SumfactKernelInterfaceBase(object): + """ A base class for the input/output of a sum factorization kernel + In stage 1, this represents the input object, in stage 3 the output object. + """ + def realize(self, *a, **kw): + raise NotImplementedError + + def realize_direct(self, *a, **kw): + raise NotImplementedError + @property - def direct_input_is_possible(self): - return False + def within_inames(self): + return () - def realize(self, sf, dep, index=0): - return dep + @property + def direct_is_possible(self): + return False - def realize_direct(self, inames): + @property + def stage(self): raise NotImplementedError + @property + def function_args(self): + return () + + @property + def signature_args(self): + return () + + @property + def function_name_suffix(self): + return "" + def __repr__(self): - return "SumfactKernelInputBase()" + return "SumfactKernelInterfaceBase()" -class VectorSumfactKernelInput(SumfactKernelInputBase): - def __init__(self, inputs): - assert(isinstance(inputs, tuple)) - self.inputs = inputs +class VectorSumfactKernelInput(SumfactKernelInterfaceBase): + def __init__(self, interfaces): + assert(isinstance(interfaces, tuple)) + self.interfaces = interfaces def __repr__(self): - return "_".join(repr(i) for i in self.inputs) + return "_".join(repr(i) for i in self.interfaces) + + @property + def stage(self): + return 1 @property - def direct_input_is_possible(self): - return all(i.direct_input_is_possible for i in self.inputs) + def direct_is_possible(self): + return all(i.direct_is_possible for i in self.interfaces) def realize(self, sf, dep): - for i, inp in enumerate(self.inputs): + for i, inp in enumerate(self.interfaces): dep = dep.union(inp.realize(sf, dep, index=i)) return dep def realize_direct(self, shape, inames): # Check whether the input exhibits a favorable structure # (whether we can broadcast scalar values into SIMD registers) - total = set(self.inputs) - lower = set(self.inputs[:len(self.inputs) // 2]) - upper = set(self.inputs[len(self.inputs) // 2:]) + total = set(self.interfaces) + lower = set(self.interfaces[:len(self.interfaces) // 2]) + upper = set(self.interfaces[len(self.interfaces) // 2:]) if len(total) == 1: # All input coefficients use the exact same input coefficient. # We implement this by broadcasting it into a SIMD register return prim.Call(ExplicitVCLCast(dtype_floatingpoint()), - (self.inputs[0].realize_direct(shape, inames),) + (self.interfaces[0].realize_direct(shape, inames),) ) elif len(total) == 2 and len(lower) == 1 and len(upper) == 1: # The lower and the upper part of the SIMD register use # the same input coefficient, we combine the SIMD register # from two shorter SIMD types return prim.Call(VCLLowerUpperLoad(dtype_floatingpoint()), - (self.inputs[0].realize_direct(shape, inames), - self.inputs[len(self.inputs) // 2].realize_direct(shape, inames, which=1), + (self.interfaces[0].realize_direct(shape, inames), + self.interfaces[len(self.interfaces) // 2].realize_direct(shape, inames, which=1), ) ) else: @@ -83,36 +109,37 @@ class VectorSumfactKernelInput(SumfactKernelInputBase): raise NotImplementedError("SIMD loads from scalars not implemented!") @property - def fastdg_args(self): - return sum((i.fastdg_args for i in remove_duplicates(self.inputs)), ()) + def function_args(self): + return sum((i.function_args for i in remove_duplicates(self.interfaces)), ()) + @property + def signature_args(self): + return tuple("const double* fastdg{}".format(i)for i, _ in enumerate(remove_duplicates(self.interfaces))) -class SumfactKernelOutputBase(object): @property - def within_inames(self): - return () + def function_name_suffix(self): + return "".join(i.function_name_suffix for i in remove_duplicates(self.interfaces)) - def realize(self, sf, result, insn_dep): - return dep - def realize_direct(self, result, inames, shape, args): - raise NotImplementedError +class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): + def __init__(self, interfaces): + self.interfaces = interfaces def __repr__(self): - return "SumfactKernelOutputBase()" - + return "_".join(repr(o) for o in self.interfaces) -class VectorSumfactKernelOutput(SumfactKernelOutputBase): - def __init__(self, outputs): - self.outputs = outputs + @property + def stage(self): + return 3 - def __repr__(self): - return "_".join(repr(o) for o in self.outputs) + @property + def within_inames(self): + return self.interfaces[0].within_inames def _add_hadd(self, o, result): hadd_function = "horizontal_add" - if len(set(self.outputs)) > 1: - pos = self.outputs.index(o) + if len(set(self.interfaces)) > 1: + pos = self.interfaces.index(o) if pos == 0: hadd_function = "horizontal_add_lower" else: @@ -121,10 +148,10 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase): return prim.Call(prim.Variable(hadd_function), (result,)) def realize(self, sf, result, insn_dep): - outputs = set(self.outputs) + outputs = set(self.interfaces) - trial_element, = set(o.trial_element for o in self.outputs) - trial_element_index, = set(o.trial_element_index for o in self.outputs) + trial_element, = set(o.trial_element for o in self.interfaces) + trial_element_index, = set(o.trial_element_index for o in self.interfaces) from dune.perftool.sumfact.accumulation import accum_iname element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None inames = tuple(accum_iname(element, mat.rows, i) @@ -140,7 +167,7 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase): return deps def realize_direct(self, result, inames, shape, **args): - outputs = set(self.outputs) + outputs = set(self.interfaces) # If multiple horizontal_add's are to be performed with 'result' # we need to precompute the result! @@ -153,14 +180,33 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase): deps = frozenset() for o in outputs: hadd_result = self._add_hadd(o, result) - which = tuple(remove_duplicates(self.outputs)).index(o) + which = tuple(remove_duplicates(self.interfaces)).index(o) deps = deps.union(o.realize_direct(hadd_result, inames, shape, which=which, **args)) return deps @property - def fastdg_args(self): - return sum((i.fastdg_args for i in remove_duplicates(self.outputs)), ()) + def function_args(self): + if get_form_option("fastdg"): + return sum((i.function_args for i in remove_duplicates(self.interfaces)), ()) + else: + return() + + @property + def signature_args(self): + if get_form_option("fastdg"): + def _get_pair(i): + ret = ("double* fastdg{}".format(i),) + if self.within_inames: + ret = ret + ("unsigned int jacobian_offset{}".format(i),) + return ret + return sum((_get_pair(i) for i, _ in enumerate(remove_duplicates(self.interfaces))), ()) + else: + return () + + @property + def function_name_suffix(self): + return "".join(i.function_name_suffix for i in remove_duplicates(self.interfaces)) class SumfactKernelBase(object): @@ -171,11 +217,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def __init__(self, matrix_sequence=None, buffer=None, - stage=1, position_priority=None, insn_dep=frozenset(), - input=SumfactKernelInputBase(), - output=SumfactKernelOutputBase(), + interface=SumfactKernelInterfaceBase(), predicates=frozenset(), ): """Create a sum factorization kernel @@ -229,31 +273,18 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): for intermediate results. The memory is expected to be pre-initialized with the input or you have to provide direct_input (FastDGGridOperator). - stage: 1 or 3 position_priority: Will be used in the dry run to order kernels when doing vectorization e.g. (dx u,dy u,dz u, u). - restriction: Restriction for faces values. insn_dep: An instruction ID that the first issued instruction should depend upon. All following ones will depend on each other. - input: An SumfactKernelInputBase instance describing the input of the kernel - accumvar: The accumulation variable to accumulate into - trial_element: The leaf element of the trial function space. - Used to correctly nest stage 3 in the jacobian case. - test_element: The leaf element of the test function space - Used to compute offsets in the fastdg case. - test_element_index: the component of the test_element - trial_element_index: the component of the trial_element + interface: An SumfactKernelInterfaceBase instance describing the input + (stage 1) or output (stage 3) of the kernel """ # Assert the inputs! assert isinstance(matrix_sequence, tuple) assert all(isinstance(m, BasisTabulationMatrixBase) for m in matrix_sequence) - - assert stage in (1, 3) - - assert isinstance(input, SumfactKernelInputBase) - assert isinstance(output, SumfactKernelOutputBase) - + assert isinstance(interface, SumfactKernelInterfaceBase) assert isinstance(insn_dep, frozenset) # The following construction is a bit weird: Dict comprehensions do not have @@ -279,7 +310,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def __str__(self): # Above stringifier just calls back into this return "SF{}:[{}]->[{}]".format(self.stage, - str(self.input), + str(self.interface), ", ".join(str(m) for m in self.matrix_sequence)) mapper_method = "map_sumfact_kernel" @@ -291,16 +322,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): @property def function_name(self): """ The name of the function that implements this kernel """ - name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence)) - if get_form_option("fastdg"): - if self.stage == 1: - fastdg = "{}comp{}".format(FEM_name_mangling(self.input.element), self.input.element_index) - if self.stage == 3: - fastdg = "{}comp{}".format(FEM_name_mangling(self.output.test_element), self.output.test_element_index) - if self.within_inames: - fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(self.output.trial_element), self.output.trial_element_index) - name = "{}_fastdg{}_{}".format(name, self.stage, fastdg) - return name + return "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence), + self.interface.function_name_suffix) @property def parallel_key(self): @@ -327,7 +350,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): work on the same input coefficient (stage 1) or accumulate into the same thing (stage 3) """ - return (repr(self.input), repr(self.output)) + return repr(self.interface) # # Some convenience methods to extract information about the sum factorization kernel @@ -337,7 +360,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): if self.parallel_key != other.parallel_key: return self.parallel_key < other.parallel_key if self.inout_key != other.inout_key: - return self.input_key < other.input_key + return self.inout_key < other.inout_key if self.position_priority == other.position_priority: return repr(self) < repr(other) if self.position_priority is None: @@ -361,7 +384,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): @property def within_inames(self): - return self.output.within_inames + return self.interface.within_inames def vec_index(self, sf): """ Map an unvectorized sumfact kernel object to its position @@ -447,6 +470,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def tag(self): return "sumfac" + @property + def stage(self): + return self.interface.stage + # # Define properties for conformity with the interface of VectorizedSumfactKernel # @@ -548,7 +575,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def __str__(self): # Above stringifier just calls back into this return "VSF{}:[{}]->[{}]".format(self.stage, - ", ".join(str(k.input) for k in self.kernels), + ", ".join(str(k.interface) for k in self.kernels), ", ".join(str(mat) for mat in self.matrix_sequence)) mapper_method = "map_vectorized_sumfact_kernel" @@ -561,18 +588,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) # @property def function_name(self): - name = "sfimpl_{}".format("_".join(str(m) for m in self.matrix_sequence)) - if get_form_option("fastdg"): - if self.stage == 1: - fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(self.input.inputs)) - if self.stage == 3: - fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(self.output.outputs)) - if self.within_inames: - fastdg = "{}x{}".format(fastdg, - "_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(self.output.outputs)) - ) - name = "{}_fastdg{}_{}".format(name, self.stage, fastdg) - return name + return "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence), + self.interface.function_name_suffix) @property def cache_key(self): @@ -627,14 +644,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) # # Define the same properties the normal SumfactKernel defines # - @property - def input(self): - return VectorSumfactKernelInput(tuple(k.input for k in self.kernels)) + def stage(self): + return self.kernels[0].stage @property - def output(self): - return VectorSumfactKernelOutput(tuple(k.output for k in self.kernels)) + def interface(self): + if self.stage == 1: + return VectorSumfactKernelInput(tuple(k.interface for k in self.kernels)) + else: + return VectorSumfactKernelOutput(tuple(k.interface for k in self.kernels)) @property def cache_key(self): @@ -791,17 +810,3 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) to be carried out """ from dune.perftool.sumfact.permutation import flop_cost return flop_cost(self.matrix_sequence) - - -def get_input_output_tuple(sf): - if sf.stage == 1: - if isinstance(sf, SumfactKernel): - return (sf.input,) - if isinstance(sf, VectorizedSumfactKernel): - return tuple(remove_duplicates(sf.input.inputs)) - if sf.stage == 3: - if isinstance(sf, SumfactKernel): - return (sf.output,) - if isinstance(sf, VectorizedSumfactKernel): - return tuple(remove_duplicates(sf.output.outputs)) - assert(False) -- GitLab