Skip to content
Snippets Groups Projects
Commit 0f0d69d0 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Correctly treat restriction in the jacobian_skeleton case

parent c13b3424
No related branches found
No related tags found
No related merge requests found
...@@ -22,6 +22,9 @@ class SumfactKernel(prim.Variable): ...@@ -22,6 +22,9 @@ class SumfactKernel(prim.Variable):
preferred_position, preferred_position,
restriction, restriction,
): ):
if not isinstance(restriction, tuple):
restriction = (restriction, 0)
self.a_matrices = a_matrices self.a_matrices = a_matrices
self.buffer = buffer self.buffer = buffer
self.stage = stage self.stage = stage
......
...@@ -128,7 +128,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -128,7 +128,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Get the vectorization info. If this happens during the dry run, we get dummies # Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info from dune.perftool.sumfact.vectorization import get_vectorization_info
a_matrices, buffer, input, index = get_vectorization_info(a_matrices, accterm.argument.restriction) a_matrices, buffer, input, index = get_vectorization_info(a_matrices, (accterm.argument.restriction, restriction))
# Initialize a base storage for this buffer and get a temporay pointing to it # Initialize a base storage for this buffer and get a temporay pointing to it
shape = tuple(mat.cols for mat in a_matrices if mat.cols != 1) shape = tuple(mat.cols for mat in a_matrices if mat.cols != 1)
...@@ -185,7 +185,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -185,7 +185,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
insn_dep=insn_dep, insn_dep=insn_dep,
additional_inames=frozenset(visitor.inames), additional_inames=frozenset(visitor.inames),
preferred_position=pref_pos, preferred_position=pref_pos,
restriction=accterm.argument.restriction, restriction=(accterm.argument.restriction, restriction),
) )
inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices) inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices)
......
...@@ -18,14 +18,18 @@ def vectorization_info(a_matrices, restriction, new_a_matrices, buffer, input, i ...@@ -18,14 +18,18 @@ def vectorization_info(a_matrices, restriction, new_a_matrices, buffer, input, i
def get_vectorization_info(a_matrices, restriction): def get_vectorization_info(a_matrices, restriction):
if not isinstance(restriction, tuple):
restriction = (restriction, 0)
from dune.perftool.generation import get_global_context_value from dune.perftool.generation import get_global_context_value
if get_global_context_value("dry_run"): if get_global_context_value("dry_run"):
# Return dummy data # Return dummy data
return (a_matrices, get_counted_variable("buffer"), get_counted_variable("input"), None) return (a_matrices, get_counted_variable("buffer"), get_counted_variable("input"), None)
try:
return vectorization_info(a_matrices, restriction, None, None, None, None) # Try getting the vectorization info collected after dry run
except TypeError: vec = vectorization_info(a_matrices, restriction, None, None, None, None)
if vec[0] is None:
raise PerftoolError("Sumfact Vectorization data should have been collected in dry run, but wasnt") raise PerftoolError("Sumfact Vectorization data should have been collected in dry run, but wasnt")
return vec
def no_vectorization(sumfacts): def no_vectorization(sumfacts):
...@@ -107,7 +111,9 @@ def decide_vectorization_strategy(): ...@@ -107,7 +111,9 @@ def decide_vectorization_strategy():
no_vectorization(sumfacts) no_vectorization(sumfacts)
else: else:
for stage in (1, 3): for stage in (1, 3):
for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE): res = (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE)
import itertools as it
for restriction in it.product(res, res):
decide_stage_vectorization_strategy(sumfacts, stage, restriction) decide_stage_vectorization_strategy(sumfacts, stage, restriction)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment