diff --git a/python/dune/perftool/loopy/vcl.py b/python/dune/perftool/loopy/vcl.py index 0aef74a5aa8198b642355678a635631b10b1f634..f76f7881cf76e4360d2b6b8c0e691cfcad1488c1 100644 --- a/python/dune/perftool/loopy/vcl.py +++ b/python/dune/perftool/loopy/vcl.py @@ -2,7 +2,7 @@ Our extensions to the loopy type system """ from dune.perftool.options import get_option -from dune.perftool.generation import function_mangler +from dune.perftool.generation import function_mangler, include_file import loopy as lp import numpy as np @@ -103,7 +103,8 @@ def vcl_function_mangler(knl, func, arg_dtypes): vcl = lp.types.NumpyType(get_vcl_type(dtype)) return lp.CallMangleInfo("select", (vcl,), (vcl, vcl, vcl)) - if func == "horizontal_add": + if func in ("horizontal_add", "horizontal_add_lower", "horizontal_add_upper"): dtype = arg_dtypes[0] vcl = lp.types.NumpyType(get_vcl_type(dtype)) - return lp.CallMangleInfo("horizontal_add", (lp.types.NumpyType(dtype.dtype),), (vcl,)) + include_file("dune/perftool/sumfact/horizontaladd.hh", filetag="operatorfile") + return lp.CallMangleInfo(func, (lp.types.NumpyType(dtype.dtype),), (vcl,)) diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py index 699c35fd4b9ac94ae231e8b365313c863be21c6b..bd1934ebfb15e8411ec53026cb065908c4102ef3 100644 --- a/python/dune/perftool/sumfact/accumulation.py +++ b/python/dune/perftool/sumfact/accumulation.py @@ -157,15 +157,15 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): expr = prim.Call(PDELabAccumulationFunction(self.accumvar, rank), tuple(accum_args) ) - instruction(assignees=(), - expression=expr, - forced_iname_deps=frozenset(inames + additional_inames + self.within_inames), - forced_iname_deps_is_final=True, - depends_on=insn_dep, - predicates=sf.predicates - ) - - return frozenset() + dep = instruction(assignees=(), + expression=expr, + forced_iname_deps=frozenset(inames + additional_inames + self.within_inames), + forced_iname_deps_is_final=True, + depends_on=insn_dep, + predicates=sf.predicates + ) + + return frozenset({dep}) class SumfactAccumulationInfo(ImmutableRecord): diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 2b9c2e21ed647f2a436a61503fd82c803a6b429d..30698309aaf8a6eaf84b07facda0e17ddab9a840 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -97,24 +97,34 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase): def realize(self, sf, result, insn_dep): outputs = set(self.outputs) - assert(len(outputs) == 1) - - o, = outputs + trial_element, = set(o.trial_element for o in self.outputs) + trial_element_index, = set(o.trial_element_index for o in self.outputs) from dune.perftool.sumfact.accumulation import accum_iname - element = get_leaf(o.trial_element, o.trial_element_index) if o.trial_element is not None else None + element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None inames = tuple(accum_iname(element, mat.rows, i) for i, mat in enumerate(sf.matrix_sequence)) - - veciname = accum_iname(element, sf.vector_width, "vec") + veciname = accum_iname(element, sf.vector_width // len(outputs), "vec") transform(lp.tag_inames, [(veciname, "vec")]) - from dune.perftool.tools import maybe_wrap_subscript - result = prim.Call(prim.Variable("horizontal_add"), - (maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,))),), - ) + deps = frozenset() + for o in outputs: + hadd_function = "horizontal_add" + if len(outputs) > 1: + pos = self.outputs.index(o) + if pos == 0: + hadd_function = "horizontal_add_lower" + else: + hadd_function = "horizontal_add_upper" + + from dune.perftool.tools import maybe_wrap_subscript + hadd_result = prim.Call(prim.Variable(hadd_function), + (maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,))),), + ) + + deps = deps.union(o.realize(sf, hadd_result, insn_dep, inames=inames, additional_inames=(veciname,))) - return o.realize(sf, result, insn_dep, inames=inames, additional_inames=(veciname,)) + return deps class SumfactKernelBase(object): diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index ad01abbc8360704e1842be9c503b0117063fdaee..2ea28eec3f786c4485d3c0913f7e3279d82e666f 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -291,9 +291,6 @@ def _level2_optimal_vectorization_strategy_generator(sumfacts, width, qp, alread if parallel > len(keys): continue - if parallel == 2 and next(iter(sumfacts)).stage == 3: - continue - horizontal = 1 while horizontal <= width // parallel: combo = sum((inoutkey_sumfacts[part][:horizontal] for part in range(parallel)), ())