Skip to content
Snippets Groups Projects
Commit e2f7995f authored by Dominic Kempf's avatar Dominic Kempf
Browse files

WIP

parent 9b170c79
No related branches found
No related tags found
No related merge requests found
...@@ -43,6 +43,7 @@ from dune.perftool.generation.loopy import (barrier, ...@@ -43,6 +43,7 @@ from dune.perftool.generation.loopy import (barrier,
kernel_cached, kernel_cached,
noop_instruction, noop_instruction,
silenced_warning, silenced_warning,
subst_rule,
temporary_variable, temporary_variable,
transform, transform,
valuearg, valuearg,
......
...@@ -172,8 +172,8 @@ def noop_instruction(**kwargs): ...@@ -172,8 +172,8 @@ def noop_instruction(**kwargs):
context_tags="kernel", context_tags="kernel",
cache_key_generator=no_caching, cache_key_generator=no_caching,
) )
def transform(trafo, *args): def transform(trafo, *args, **kwargs):
return (trafo, args) return (trafo, args, kwargs)
@generator_factory(item_tags=("instruction", "barrier"), @generator_factory(item_tags=("instruction", "barrier"),
...@@ -216,3 +216,8 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar ...@@ -216,3 +216,8 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar
globalarg(name, **kwargs) globalarg(name, **kwargs)
return name return name
@generator_factory(item_tags=("substrule",), context_tags="kernel")
def subst_rule(name, args, expr):
    """Register a loopy substitution rule in the current kernel context.

    The rule is cached under the ``"substrule"`` item tag by the
    ``generator_factory`` decorator so that kernel extraction can later
    attach it to the generated loopy kernel (and, e.g., precompute it
    via a ``transform(lp.precompute, ...)`` registration).

    :param name: Name of the substitution rule.
    :param args: Tuple of argument names of the rule (may be empty).
    :param expr: The pymbolic expression the rule expands to.
    :returns: The constructed ``lp.SubstitutionRule`` instance.

    NOTE(review): ``lp`` is assumed to be the ``loopy`` module imported
    at file level — the import is not visible in this hunk; confirm.
    """
    return lp.SubstitutionRule(name, args, expr)
...@@ -539,19 +539,7 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True): ...@@ -539,19 +539,7 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
# Apply the transformations that were gathered during tree traversals # Apply the transformations that were gathered during tree traversals
for trafo in transformations: for trafo in transformations:
kernel = trafo[0](kernel, *trafo[1]) kernel = trafo[0](kernel, *trafo[1], **trafo[2])
# Precompute all the substrules
for sr in kernel.substitutions:
tmpname = "precompute_{}".format(sr)
kernel = lp.precompute(kernel,
sr,
temporary_name=tmpname,
)
# Vectorization strategies are actually very likely to eliminate the
# precomputation temporary. To avoid the temporary elimination warning
# we need to explicitly disable it.
kernel = kernel.copy(silenced_warnings=kernel.silenced_warnings + ["temp_to_write({})".format(tmpname)])
from dune.perftool.loopy import heuristic_duplication from dune.perftool.loopy import heuristic_duplication
kernel = heuristic_duplication(kernel) kernel = heuristic_duplication(kernel)
......
...@@ -194,7 +194,7 @@ def _realize_sum_factorization_kernel(sf): ...@@ -194,7 +194,7 @@ def _realize_sum_factorization_kernel(sf):
tag = "{}_{}".format(tag, "_".join(sf.within_inames)) tag = "{}_{}".format(tag, "_".join(sf.within_inames))
# Collect the key word arguments for the loopy instruction # Collect the key word arguments for the loopy instruction
insn_args = {"forced_iname_deps": frozenset([iname for iname in out_inames]).union(frozenset(sf.within_inames)), insn_args = {"forced_iname_deps": frozenset([i for i in out_inames]).union(frozenset(sf.within_inames)),
"forced_iname_deps_is_final": True, "forced_iname_deps_is_final": True,
"depends_on": insn_dep, "depends_on": insn_dep,
"tags": frozenset({tag}), "tags": frozenset({tag}),
...@@ -205,6 +205,7 @@ def _realize_sum_factorization_kernel(sf): ...@@ -205,6 +205,7 @@ def _realize_sum_factorization_kernel(sf):
# In case of direct output we directly accumulate the result # In case of direct output we directly accumulate the result
# of the Sumfactorization into some global data structure. # of the Sumfactorization into some global data structure.
if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3: if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3:
insn_args["forced_iname_deps"] = insn_args["forced_iname_deps"].union(frozenset({vec_iname[0].name}))
insn_dep = sf.output.realize_direct(matprod, output_inames, out_shape, insn_args) insn_dep = sf.output.realize_direct(matprod, output_inames, out_shape, insn_args)
else: else:
# Issue the reduction instruction that implements the multiplication # Issue the reduction instruction that implements the multiplication
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from dune.perftool.options import get_option from dune.perftool.options import get_option
from dune.perftool.generation import (get_counted_variable, from dune.perftool.generation import (get_counted_variable,
subst_rule,
transform, transform,
) )
from dune.perftool.pdelab.geometry import local_dimension, world_dimension from dune.perftool.pdelab.geometry import local_dimension, world_dimension
...@@ -128,6 +129,14 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase): ...@@ -128,6 +129,14 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase):
def realize_direct(self, result, inames, shape, args): def realize_direct(self, result, inames, shape, args):
outputs = set(self.outputs) outputs = set(self.outputs)
# If multiple horizontal_add's are to be performed with 'result'
# we need to precompute the result!
if len(outputs) > 1:
substname = "haddsubst_{}".format("_".join([i.name for i in inames]))
subst_rule(substname, (), result)
result = prim.Call(prim.Variable(substname), ())
transform(lp.precompute, substname, precompute_outer_inames=args["forced_iname_deps"])
deps = frozenset() deps = frozenset()
for o in outputs: for o in outputs:
hadd_result = self._add_hadd(o, result) hadd_result = self._add_hadd(o, result)
......
0% Loading — please wait, or reload the page.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment