Skip to content
Snippets Groups Projects
Commit 22800cf2 authored by Dominic Kempf
Browse files

Fix symdiffs!

parent fc5d941e
No related branches found
No related tags found
No related merge requests found
...@@ -119,6 +119,9 @@ class _RegisteredFunction(object): ...@@ -119,6 +119,9 @@ class _RegisteredFunction(object):
# Return the result for immediate usage # Return the result for immediate usage
return self._get_content(cache_key) return self._get_content(cache_key)
def remove_by_value(self, val):
self._memoize_cache = {k:v for k, v in self._memoize_cache.items() if v != val}
def generator_factory(**factory_kwargs): def generator_factory(**factory_kwargs):
""" A function decorator factory """ A function decorator factory
......
...@@ -488,10 +488,8 @@ def generate_kernel(integrals): ...@@ -488,10 +488,8 @@ def generate_kernel(integrals):
def extract_kernel_from_cache(tag): def extract_kernel_from_cache(tag):
# Preprocess some instruction! # Preprocess some instruction!
from dune.perftool.sumfact.sumfact import expand_sumfact_kernels, filter_sumfact_instructions from dune.perftool.sumfact.sumfact import expand_sumfact_kernels
instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))] expand_sumfact_kernels(tag)
expand_sumfact_kernels(instructions)
filter_sumfact_instructions()
# Now extract regular loopy kernel components # Now extract regular loopy kernel components
from dune.perftool.loopy.target import DuneTarget from dune.perftool.loopy.target import DuneTarget
......
...@@ -16,6 +16,7 @@ from dune.perftool.generation import (backend, ...@@ -16,6 +16,7 @@ from dune.perftool.generation import (backend,
globalarg, globalarg,
iname, iname,
instruction, instruction,
retrieve_cache_items,
silenced_warning, silenced_warning,
temporary_variable, temporary_variable,
transform, transform,
...@@ -139,6 +140,12 @@ def default_resolution(insns): ...@@ -139,6 +140,12 @@ def default_resolution(insns):
depends_on=frozenset(*deps) depends_on=frozenset(*deps)
) )
) )
if isinstance(insn, lp.Assignment):
from dune.perftool.generation.loopy import expr_instruction_impl
expr_instruction_impl.remove_by_value(insn)
if isinstance(insn, lp.CallInstruction):
from dune.perftool.generation.loopy import call_instruction_impl
call_instruction_impl.remove_by_value(insn)
def apply_sumfact_grad_vectorization(insns, stage): def apply_sumfact_grad_vectorization(insns, stage):
...@@ -153,7 +160,7 @@ def apply_sumfact_grad_vectorization(insns, stage): ...@@ -153,7 +160,7 @@ def apply_sumfact_grad_vectorization(insns, stage):
# Now apply some heuristics when to vectorize... # Now apply some heuristics when to vectorize...
if len(set(sumfact_kernels)) < 3: if len(set(sumfact_kernels)) < 3:
default_resolution(insns) return
else: else:
# Vectorize!!! # Vectorize!!!
sumfact_kernels = sorted(sumfact_kernels, key=lambda s: s.preferred_interleaving_position) sumfact_kernels = sorted(sumfact_kernels, key=lambda s: s.preferred_interleaving_position)
...@@ -167,6 +174,7 @@ def apply_sumfact_grad_vectorization(insns, stage): ...@@ -167,6 +174,7 @@ def apply_sumfact_grad_vectorization(insns, stage):
# Maybe initialize the input buffer # Maybe initialize the input buffer
if sumfact_kernels[0].setup_method: if sumfact_kernels[0].setup_method:
assert stage == 1
shape = product(mat.cols for mat in sumfact_kernels[0].a_matrices) shape = product(mat.cols for mat in sumfact_kernels[0].a_matrices)
shape = (shape, 4) shape = (shape, 4)
initialize_buffer(buffer, base_storage_size=4 * shape[0]) initialize_buffer(buffer, base_storage_size=4 * shape[0])
...@@ -178,6 +186,7 @@ def apply_sumfact_grad_vectorization(insns, stage): ...@@ -178,6 +186,7 @@ def apply_sumfact_grad_vectorization(insns, stage):
func, args = sumf.setup_method func, args = sumf.setup_method
insn_dep = insn_dep.union({func(inp, *args, additional_indices=(i,))}) insn_dep = insn_dep.union({func(inp, *args, additional_indices=(i,))})
else: else:
assert stage == 3
# No setup method defined. We need to make sure the input is correctly setup # No setup method defined. We need to make sure the input is correctly setup
shape = tuple(mat.cols for mat in sumfact_kernels[0].a_matrices) + (4,) shape = tuple(mat.cols for mat in sumfact_kernels[0].a_matrices) + (4,)
initialize_buffer(buffer, base_storage_size=4 * shape[0]) initialize_buffer(buffer, base_storage_size=4 * shape[0])
...@@ -189,9 +198,11 @@ def apply_sumfact_grad_vectorization(insns, stage): ...@@ -189,9 +198,11 @@ def apply_sumfact_grad_vectorization(insns, stage):
for insn in insns: for insn in insns:
if isinstance(insn, lp.Assignment): if isinstance(insn, lp.Assignment):
if get_pymbolic_basename(insn.assignee) == sumf.input_temporary: if get_pymbolic_basename(insn.assignee) == sumf.input_temporary:
built_instruction(insn.copy(assignee=prim.Subscript(prim.Variable(inp), insn.assignee.index + (i,)))) built_instruction(insn.copy(assignee=prim.Subscript(prim.Variable(inp), insn.assignee.index + (i,)),
id=insn.id + "__{}".format(i)))
insn_dep = insn_dep.union(insn.depends_on)
from dune.perftool.generation.loopy import expr_instruction_impl from dune.perftool.generation.loopy import expr_instruction_impl
expr_instruction_impl._memoize_cache = {k: v for k, v in expr_instruction_impl._memoize_cache.items() if v.id != insn.id} expr_instruction_impl.remove_by_value(insn)
# Determine the joined AMatrix # Determine the joined AMatrix
large_a_matrices = [] large_a_matrices = []
...@@ -225,19 +236,19 @@ def apply_sumfact_grad_vectorization(insns, stage): ...@@ -225,19 +236,19 @@ def apply_sumfact_grad_vectorization(insns, stage):
depends_on=dep, depends_on=dep,
)) ))
def expand_sumfact_kernels(insns): if isinstance(insn, lp.Assignment):
if get_option("vectorize_grads"): from dune.perftool.generation.loopy import expr_instruction_impl
apply_sumfact_grad_vectorization(insns, 1) expr_instruction_impl.remove_by_value(insn)
apply_sumfact_grad_vectorization(insns, 3) if isinstance(insn, lp.CallInstruction):
else: from dune.perftool.generation.loopy import call_instruction_impl
default_resolution(insns) call_instruction_impl.remove_by_value(insn)
def filter_sumfact_instructions(): def expand_sumfact_kernels(tag):
""" Remove all instructions that contain a SumfactKernel node """ if get_option("vectorize_grads"):
from dune.perftool.generation.loopy import expr_instruction_impl, call_instruction_impl apply_sumfact_grad_vectorization([i for i in retrieve_cache_items("{} and instruction".format(tag))], 1)
expr_instruction_impl._memoize_cache = {k: v for k, v in expr_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)} apply_sumfact_grad_vectorization([i for i in retrieve_cache_items("{} and instruction".format(tag))], 3)
call_instruction_impl._memoize_cache = {k: v for k, v in call_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)} default_resolution(retrieve_cache_items("{} and instruction".format(tag)))
@iname @iname
...@@ -389,7 +400,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -389,7 +400,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
(Subscript(result, tuple(Variable(i) for i in inames)),) (Subscript(result, tuple(Variable(i) for i in inames)),)
) )
) )
instruction(assignees=(), instruction(assignees=(),
expression=expr, expression=expr,
forced_iname_deps=frozenset(inames + visitor.inames), forced_iname_deps=frozenset(inames + visitor.inames),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment