Skip to content
Snippets Groups Projects
Commit 5a524300 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Update the flip flop buffer implementation

parent 656e022b
No related branches found
No related tags found
No related merge requests found
...@@ -81,9 +81,13 @@ def default_declaration(name, shape=(), shape_impl=()): ...@@ -81,9 +81,13 @@ def default_declaration(name, shape=(), shape_impl=()):
return '{} {}(0.0);'.format(t, name) return '{} {}(0.0);'.format(t, name)
def declaration_with_base_storage(base_storage): def declaration_with_base_storage(base_storage, storage_shape):
# We currently consider the base storage to be linear
assert(len(storage_shape) == 1)
@preamble @preamble
def base_storage_decl(name, shape=(), shape_impl=()): def base_storage_decl(name, shape=(), shape_impl=()):
default_declaration(base_storage, shape=storage_shape, shape_impl=('arr',))
return 'double* {} = {};'.format(name, base_storage) return 'double* {} = {};'.format(name, base_storage)
return base_storage_decl return base_storage_decl
...@@ -110,13 +114,14 @@ def temporary_variable(name, managed=False, **kwargs): ...@@ -110,13 +114,14 @@ def temporary_variable(name, managed=False, **kwargs):
# Check if a specified base_storage has already been initialized # Check if a specified base_storage has already been initialized
if kwargs.get('base_storage', None): if kwargs.get('base_storage', None):
assert(kwargs['base_storage'] in [tv.name for tv in temporary_variable._memoize_cache.values()])
assert(kwargs.get('decl_method', None) is None) assert(kwargs.get('decl_method', None) is None)
decl_method = declaration_with_base_storage(kwargs['base_storage']) assert('storage_shape' in kwargs)
decl_method = declaration_with_base_storage(kwargs['base_storage'], kwargs['storage_shape'])
else:
decl_method = kwargs.pop('decl_method', default_declaration)
shape = kwargs.get('shape', ()) shape = kwargs.get('shape', ())
shape_impl = kwargs.pop('shape_impl', ('arr',) * len(shape)) shape_impl = kwargs.pop('shape_impl', ('arr',) * len(shape))
decl_method = kwargs.pop('decl_method', default_declaration)
if decl_method is not None: if decl_method is not None:
decl_method(name, shape, shape_impl) decl_method(name, shape, shape_impl)
......
...@@ -6,6 +6,7 @@ from dune.perftool.generation import (generator_factory, ...@@ -6,6 +6,7 @@ from dune.perftool.generation import (generator_factory,
class FlipFlopBuffer(object): class FlipFlopBuffer(object):
def __init__(self, identifier, base_storage_size=None, num=2): def __init__(self, identifier, base_storage_size=None, num=2):
self.identifier = identifier self.identifier = identifier
self.base_storage_size = base_storage_size
self.num = num self.num = num
# Initialize the counter that switches between the base storages! # Initialize the counter that switches between the base storages!
...@@ -14,26 +15,31 @@ class FlipFlopBuffer(object): ...@@ -14,26 +15,31 @@ class FlipFlopBuffer(object):
# Initialize a total counter for the issued temporaries # Initialize a total counter for the issued temporaries
self._counter = 0 self._counter = 0
# Get the base storage temporaries that actually hold the data! # Generate the base storage names
# TODO: Use heap-allocated ones instead (easy with DuneTarget)
self.base_storage = tuple("{}_base_{}".format(self.identifier, i) for i in range(self.num)) self.base_storage = tuple("{}_base_{}".format(self.identifier, i) for i in range(self.num))
for bs in self.base_storage: def switch_base_storage(self):
temporary_variable(bs, shape=(base_storage_size,)) self._current = (self._current + 1) % self.num
def get_temporary(self, **kwargs): def get_temporary(self, **kwargs):
assert("base_storage" not in kwargs) assert("base_storage" not in kwargs)
assert("storage_shape" not in kwargs)
# Select the base storage and increase counter # Select the base storage and increase counter
base = self.base_storage[self._current] base = self.base_storage[self._current]
self._current = (self._current + 1) % self.num
# Construct a temporary name # Construct a temporary name
name = "{}_{}".format(self.identifier, self._counter) name = "{}_{}".format(self.identifier, self._counter)
self._counter = self._counter + 1 self._counter = self._counter + 1
# Construct the temporary and return it # Construct the temporary and return it
temporary_variable(name, base_storage=base, **kwargs) temporary_variable(name,
base_storage=base,
storage_shape=(self.base_storage_size,),
managed=True,
**kwargs
)
return name return name
...@@ -49,3 +55,7 @@ def initialize_buffer(identifier, base_storage_size=None, num=2): ...@@ -49,3 +55,7 @@ def initialize_buffer(identifier, base_storage_size=None, num=2):
def get_buffer_temporary(identifier, **kwargs): def get_buffer_temporary(identifier, **kwargs):
return initialize_buffer(identifier).get_temporary(**kwargs) return initialize_buffer(identifier).get_temporary(**kwargs)
def switch_base_storage(identifier):
initialize_buffer(identifier).switch_base_storage()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment