From 9b39feaafa4e6daad3ce25145b9132748d40205f Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Mon, 12 Feb 2018 11:33:32 +0100 Subject: [PATCH] Refactor direct input mechanism for further extension --- dune/perftool/sumfact/transposereg.hh | 3 ++ python/dune/perftool/loopy/vcl.py | 4 +- python/dune/perftool/sumfact/basis.py | 43 ++++++++++----- python/dune/perftool/sumfact/realization.py | 34 ++---------- python/dune/perftool/sumfact/symbolic.py | 60 +++++++++++++++++---- 5 files changed, 91 insertions(+), 53 deletions(-) diff --git a/dune/perftool/sumfact/transposereg.hh b/dune/perftool/sumfact/transposereg.hh index 9ed19740..4df76df4 100644 --- a/dune/perftool/sumfact/transposereg.hh +++ b/dune/perftool/sumfact/transposereg.hh @@ -66,6 +66,9 @@ void transpose_reg(Vec8d& a0, Vec8d& a1, Vec8d& a2, Vec8d& a3) a3 = blend8d<4,5,6,7,12,13,14,15>(b1, b3); } +/** TODO: Is this transpose using blend8d superior to the swap_halves + * version below using get_low/get_high? + */ void transpose_reg (Vec8d& a0, Vec8d& a1) { Vec8d b0, b1; diff --git a/python/dune/perftool/loopy/vcl.py b/python/dune/perftool/loopy/vcl.py index 143d9b6c..ead55725 100644 --- a/python/dune/perftool/loopy/vcl.py +++ b/python/dune/perftool/loopy/vcl.py @@ -62,8 +62,10 @@ def get_vcl_typename(nptype, register_size=None, vector_width=None): class ExplicitVCLCast(lp.symbolic.FunctionIdentifier): - def __init__(self, nptype, vector_width): + def __init__(self, nptype, vector_width=None): self.nptype = nptype + if vector_width is None: + vector_width = get_vcl_type_size(nptype) self.vector_width = vector_width def __getinitargs__(self): diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index a941b8e0..425266f4 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -11,6 +11,7 @@ from dune.perftool.generation import (backend, get_counted_variable, get_counter, get_global_context_value, + globalarg, iname, instruction, kernel_cached, @@ -70,7 +71,11 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): def __str__(self): return repr(self) - def realize(self, sf, index, insn_dep): + @property + def direct_input_is_possible(self): + return get_form_option("fastdg") + + def realize(self, sf, insn_dep, index=0): lfs = name_lfs(self.element, self.restriction, self.element_index) basisiname = sumfact_iname(name_lfs_bound(lfs), "basis") container = self.coeff_func(self.restriction) @@ -85,18 +90,30 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): assignee = prim.Subscript(prim.Variable(name), (prim.Variable(basisiname),) + (index,)) - instruction(assignee=assignee, - expression=coeff, - depends_on=sf.insn_dep.union(insn_dep), - tags=frozenset({"sumfact_stage{}".format(sf.stage)}), - ) - - @property - def direct_input(self): - if get_form_option("fastdg"): - return self.coeff_func(self.restriction) - else: - return None + insn = instruction(assignee=assignee, + expression=coeff, + depends_on=sf.insn_dep.union(insn_dep), + tags=frozenset({"sumfact_stage{}".format(sf.stage)}), + ) + + return insn_dep.union(frozenset({insn})) + + def realize_direct(self, shape, inames): + arg = "{}_access{}".format(self.coeff_func(self.restriction), + "_comp{}".format(self.element_index) if self.element_index else "" + ) + + from dune.perftool.sumfact.realization import _dof_offset, alias_data_array + globalarg(arg, + shape=shape, + dim_tags=",".join("f" * len(shape)), + offset=_dof_offset(self.element, self.element_index), + ) + + func = self.coeff_func(self.restriction) + alias_data_array(arg, func) + + return prim.Subscript(prim.Variable(arg), inames) def _basis_functions_per_direction(element): diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index b0e85d43..cc6704af 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -75,22 +75,8 @@ def _realize_sum_factorization_kernel(sf): ), })) - direct_input = sf.input.direct_input - - # Set up the input for stage 1 - if direct_input is None: - if sf.vectorized: - for i, inputsf in enumerate(sf.kernels): - inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep)) - else: - sf.input.realize(sf, 0, insn_dep) - - insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))})) - else: - if sf.input.element_index is None: - direct_input_arg = "{}_access".format(direct_input) - else: - direct_input_arg = "{}_access_comp{}".format(direct_input, sf.input.element_index) + if not sf.input.direct_input_is_possible: + insn_dep = insn_dep.union(sf.input.realize(sf, insn_dep)) # Prepare some dim_tags/shapes for later use ftags = ",".join(["f"] * sf.length) @@ -143,24 +129,12 @@ def _realize_sum_factorization_kernel(sf): # * a global data structure (if FastDGGridOperator is in use) # * a value from a global data structure, broadcasted to a vector type (vectorized + FastDGGridOperator) input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:]) - if l == 0 and direct_input is not None: + if l == 0 and sf.input.direct_input_is_possible: # See comment below input_inames = permute_backward(input_inames, perm) inp_shape = permute_backward(inp_shape, perm) - globalarg(direct_input_arg, - shape=inp_shape, - dim_tags=novec_ftags, - offset=_dof_offset(sf.input.element, sf.input.element_index), - ) - alias_data_array(direct_input_arg, direct_input) - if matrix.vectorized: - input_summand = prim.Call(ExplicitVCLCast(dtype_floatingpoint(), vector_width=sf.vector_width), - (prim.Subscript(prim.Variable(direct_input_arg), - input_inames),)) - else: - input_summand = prim.Subscript(prim.Variable(direct_input_arg), - input_inames + vec_iname) + input_summand = sf.input.realize_direct(inp_shape, input_inames) else: # If we did permute the order of a matrices above we also # permuted the order of out_inames. Unfortunately the diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index babb432b..627c8fc0 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -5,6 +5,8 @@ from dune.perftool.generation import get_counted_variable from dune.perftool.pdelab.geometry import local_dimension, world_dimension from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray +from dune.perftool.loopy.target import dtype_floatingpoint +from dune.perftool.loopy.vcl import ExplicitVCLCast from pytools import ImmutableRecord, product @@ -18,11 +20,52 @@ import inspect class SumfactKernelInputBase(object): @property - def direct_input(self): - return None + def direct_input_is_possible(self): + return False - def realize(self, sf, i, dep): - pass + def realize(self, sf, dep, index=0): + return frozenset() + + def realize_direct(self, inames): + raise NotImplementedError + + +class VectorSumfactKernelInput(SumfactKernelInputBase): + def __init__(self, inputs): + assert(isinstance(inputs, tuple)) + self.inputs = inputs + + @property + def direct_input_is_possible(self): + return all(i.direct_input_is_possible for i in self.inputs) + + def realize(self, sf, dep): + for i, inp in enumerate(self.inputs): + dep = dep.union(inp.realize(sf, dep, index=i)) + return dep + + def realize_direct(self, shape, inames): + # Check whether the input exhibits a favorable structure + # (whether we can broadcast scalar values into SIMD registers) + total = set(self.inputs) + lower = set(self.inputs[:len(self.inputs) // 2]) + upper = set(self.inputs[len(self.inputs) // 2:]) + + if len(total) == 1: + # All input coefficients use the exact same input coefficient. + # We implement this by broadcasting it into a SIMD register + return prim.Call(ExplicitVCLCast(dtype_floatingpoint()), + (self.inputs[0].realize_direct(shape, inames),) + ) + elif len(total) == 2 and len(lower) == 1 and len(upper) == 1: + # The lower and the upper part of the SIMD register use + # the same input coefficient, we combine the SIMD register + # from two shorter SIMD types + raise NotImplementedError("Lower/Upper half SIMD loads not implemented!") + else: + # The input does not exhibit a broadcastable structure, we + # need to load scalars into the SIMD vector. + raise NotImplementedError("SIMD loads from scalars not implemented!") class SumfactKernelBase(object): @@ -457,11 +500,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def predicates(self): return self.kernels[0].predicates - @property - def input(self): - assert len(set(k.input for k in self.kernels)) == 1 - return self.kernels[0].input - @property def accumvar(self): assert len(set(k.accumvar for k in self.kernels)) == 1 @@ -486,6 +524,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) # Define the same properties the normal SumfactKernel defines # + @property + def input(self): + return VectorSumfactKernelInput(tuple(k.input for k in self.kernels)) + @property def cache_key(self): return (tuple(k.cache_key for k in self.kernels), self.buffer) -- GitLab