From 5a524300ec639de0d9e213a74aebf166acf8722b Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 21 Oct 2016 17:24:45 +0200
Subject: [PATCH] Update the flip flop buffer implementation

---
 python/dune/perftool/generation/loopy.py | 13 +++++++++----
 python/dune/perftool/loopy/buffer.py     | 22 ++++++++++++++++------
 2 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index 37d4607d..2e0cd5fb 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -81,9 +81,13 @@ def default_declaration(name, shape=(), shape_impl=()):
         return '{} {}(0.0);'.format(t, name)
 
 
-def declaration_with_base_storage(base_storage):
+def declaration_with_base_storage(base_storage, storage_shape):
+    # We currently consider the base storage to be linear
+    assert(len(storage_shape) == 1)
+
     @preamble
     def base_storage_decl(name, shape=(), shape_impl=()):
+        default_declaration(base_storage, shape=storage_shape, shape_impl=('arr',))
         return 'double* {} = {};'.format(name, base_storage)
 
     return base_storage_decl
@@ -110,13 +114,14 @@ def temporary_variable(name, managed=False, **kwargs):
 
     # Check if a specified base_storage has already been initialized
     if kwargs.get('base_storage', None):
-        assert(kwargs['base_storage'] in [tv.name for tv in temporary_variable._memoize_cache.values()])
         assert(kwargs.get('decl_method', None) is None)
-        decl_method = declaration_with_base_storage(kwargs['base_storage'])
+        assert('storage_shape' in kwargs)
+        decl_method = declaration_with_base_storage(kwargs['base_storage'], kwargs['storage_shape'])
+    else:
+        decl_method = kwargs.pop('decl_method', default_declaration)
 
     shape = kwargs.get('shape', ())
     shape_impl = kwargs.pop('shape_impl', ('arr',) * len(shape))
-    decl_method = kwargs.pop('decl_method', default_declaration)
 
     if decl_method is not None:
         decl_method(name, shape, shape_impl)
diff --git a/python/dune/perftool/loopy/buffer.py b/python/dune/perftool/loopy/buffer.py
index 265cc051..4238c492 100644
--- a/python/dune/perftool/loopy/buffer.py
+++ b/python/dune/perftool/loopy/buffer.py
@@ -6,6 +6,7 @@ from dune.perftool.generation import (generator_factory,
 class FlipFlopBuffer(object):
     def __init__(self, identifier, base_storage_size=None, num=2):
         self.identifier = identifier
+        self.base_storage_size = base_storage_size
         self.num = num
 
         # Initialize the counter that switches between the base storages!
@@ -14,26 +15,31 @@ class FlipFlopBuffer(object):
         # Initialize a total counter for the issued temporaries
         self._counter = 0
 
-        # Get the base storage temporaries that actually hold the data!
-        # TODO: Use heap-allocated ones instead (easy with DuneTarget)
+        # Generate the base storage names
         self.base_storage = tuple("{}_base_{}".format(self.identifier, i) for i in range(self.num))
 
-        for bs in self.base_storage:
-            temporary_variable(bs, shape=(base_storage_size,))
+    def switch_base_storage(self):
+        self._current = (self._current + 1) % self.num
 
     def get_temporary(self, **kwargs):
         assert("base_storage" not in kwargs)
+        assert("storage_shape" not in kwargs)
 
         # Select the base storage and increase counter
         base = self.base_storage[self._current]
-        self._current = (self._current + 1) % self.num
 
         # Construct a temporary name
         name = "{}_{}".format(self.identifier, self._counter)
         self._counter = self._counter + 1
 
         # Construct the temporary and return it
-        temporary_variable(name, base_storage=base, **kwargs)
+        temporary_variable(name,
+                           base_storage=base,
+                           storage_shape=(self.base_storage_size,),
+                           managed=True,
+                           **kwargs
+                           )
+
         return name
 
 
@@ -49,3 +55,7 @@ def initialize_buffer(identifier, base_storage_size=None, num=2):
 
 def get_buffer_temporary(identifier, **kwargs):
     return initialize_buffer(identifier).get_temporary(**kwargs)
+
+
+def switch_base_storage(identifier):
+    initialize_buffer(identifier).switch_base_storage()
-- 
GitLab