Skip to content
Snippets Groups Projects
Commit 6637071e authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Simplify flipflop buffer construction

parent f786d623
No related branches found
No related tags found
No related merge requests found
from dune.perftool.error import PerftoolLoopyError
from dune.perftool.generation import (generator_factory,
get_counted_variable,
get_global_context_value,
from dune.perftool.generation import (get_counted_variable,
kernel_cached,
temporary_variable,
)
class FlipFlopBuffer(object):
def __init__(self, identifier, base_storage_size=None, num=2):
def __init__(self, identifier):
self.identifier = identifier
self.base_storage_size = base_storage_size
self.num = num
# Initialize the counter that switches between the base storages!
self._current = 0
# Generate the base storage names
self.base_storage = tuple("{}_base_{}".format(self.identifier, i) for i in range(self.num))
self.base_storage = tuple("{}_base_{}".format(self.identifier, i) for i in (0, 1))
def switch_base_storage(self):
self._current = (self._current + 1) % self.num
self._current = (self._current + 1) % 2
def get_temporary(self, **kwargs):
assert("base_storage" not in kwargs)
......@@ -33,10 +30,6 @@ class FlipFlopBuffer(object):
if name is None:
name = get_counted_variable(self.identifier)
# Get geometric dimension
formdata = get_global_context_value('formdata')
dim = formdata.geometric_dimension
# Construct the temporary and return it
temporary_variable(name,
base_storage=base,
......@@ -47,14 +40,9 @@ class FlipFlopBuffer(object):
return name
@generator_factory(item_tags=("buffer"), cache_key_generator=lambda i, **kw: i, context_tags=("kernel",))
def initialize_buffer(identifier, base_storage_size=None, num=2):
if base_storage_size is None:
raise PerftoolLoopyError("The buffer for identifier {} has not been initialized.".format(identifier))
return FlipFlopBuffer(identifier,
base_storage_size=base_storage_size,
num=num,
)
@kernel_cached
def initialize_buffer(identifier):
return FlipFlopBuffer(identifier)
def get_buffer_temporary(identifier, **kwargs):
......
......@@ -100,12 +100,9 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
shape = (product(mat.cols for mat in a_matrices),)
if index is not None:
shape = shape + (4,)
inp = initialize_buffer(buf,
base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
num=2
).get_temporary(shape=shape,
name=inp,
)
inp = initialize_buffer(buf).get_temporary(shape=shape,
name=inp,
)
insn_dep = frozenset({Writes(inp)})
if get_option('fastdg'):
......@@ -197,10 +194,7 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
shape = (product(mat.cols for mat in a_matrices),)
if index is not None:
shape = shape + (4,)
initialize_buffer(buf,
base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
num=2
).get_temporary(shape=shape,
initialize_buffer(buf).get_temporary(shape=shape,
name=inp,
)
......
......@@ -185,10 +185,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
vectag = frozenset()
base_storage_size = product(max(mat.rows, mat.cols) for mat in a_matrices)
temp = initialize_buffer(buf,
base_storage_size=base_storage_size,
num=2
).get_temporary(shape=shape,
temp = initialize_buffer(buf).get_temporary(shape=shape,
dim_tags=dim_tags,
name=inp,
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment