From 9c12ecc40ee16e774e8caf9f0bda2057dc2492a1 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 26 Apr 2016 10:33:48 +0200
Subject: [PATCH] Remodel the generic_context

---
 python/dune/perftool/generation/__init__.py  |  4 +--
 python/dune/perftool/generation/context.py   | 33 +++++++++++++-------
 python/dune/perftool/loopy/transformer.py    | 11 ++++---
 python/dune/perftool/pdelab/localoperator.py | 14 +++------
 4 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index a51143e6..5fecdc4c 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -27,6 +27,6 @@ from dune.perftool.generation.loopy import (domain,
                                             )
 
 from dune.perftool.generation.context import (cache_context,
-                                              generic_context,
-                                              get_generic_context_value,
+                                              global_context,
+                                              get_global_context_value,
                                               )
diff --git a/python/dune/perftool/generation/context.py b/python/dune/perftool/generation/context.py
index 13099e01..5e945d8a 100644
--- a/python/dune/perftool/generation/context.py
+++ b/python/dune/perftool/generation/context.py
@@ -31,25 +31,34 @@ def get_context_tags():
     return result
 
 
-_generic_context_cache = {}
+_global_context_cache = {}
 
 
-class _GenericContext(object):
-    def __init__(self, key, value):
-        self.key = key
-        self.value = value
-        assert key not in _generic_context_cache
+class _GlobalContext(object):
+    def __init__(self, **kwargs):
+        self.kw = kwargs
 
     def __enter__(self):
-        _generic_context_cache[self.key] = self.value
+        self.old_kw = {}
+        for k, v in self.kw.items():
+            # First store existing values of the same keys
+            if k in _global_context_cache:
+                self.old_kw[k] = v
+            # Now replace the value with the new one
+            _global_context_cache[k] = v
 
     def __exit__(self, exc_type, exc_value, traceback):
-        del _generic_context_cache[self.key]
+        # Delete all the entries from this context
+        for k in self.kw.keys():
+            del _global_context_cache[k]
+        # and restore previously overwritten values
+        for k, v in self.old_kw.items():
+            _global_context_cache[k] = v
 
 
-def generic_context(key, value):
-    return _GenericContext(key, value)
+def global_context(**kwargs):
+    return _GlobalContext(**kwargs)
 
 
-def get_generic_context_value(key):
-    return _generic_context_cache[key]
+def get_global_context_value(key):
+    return _global_context_cache[key]
diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py
index 909ed752..6776df96 100644
--- a/python/dune/perftool/loopy/transformer.py
+++ b/python/dune/perftool/loopy/transformer.py
@@ -44,8 +44,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper):
         assert isinstance(o, Expression)
 
         # Determine the name of the parameter function
-        from dune.perftool.generation import get_generic_context_value
-        name = get_generic_context_value("namedata")[id(o)]
+        from dune.perftool.generation import get_global_context_value
+        name = get_global_context_value("namedata")[id(o)]
 
         # Trigger the generation of code for this thing in the parameter class
         from dune.perftool.pdelab.parameter import parameter_function
@@ -197,8 +197,8 @@ def transform_accumulation_term(term, measure, subdomain_id):
         # modified measure per integral type.
 
         # Get the original form and inspect the present measures
-        from dune.perftool.generation import get_generic_context_value
-        original_form = get_generic_context_value("formdata").original_form
+        from dune.perftool.generation import get_global_context_value
+        original_form = get_global_context_value("formdata").original_form
         sd = original_form.subdomain_data()
         assert len(sd) == 1
         subdomains, = list(sd.values())
@@ -208,10 +208,11 @@ def transform_accumulation_term(term, measure, subdomain_id):
                 del subdomains[k]
 
         # Finally extract the original subdomain_data (which needs to be present!)
+        assert measure in subdomains
         subdomain_data = subdomains[measure]
 
         # Determine the name of the parameter function
-        name = get_generic_context_value("namedata")[id(subdomain_data)]
+        name = get_global_context_value("namedata")[id(subdomain_data)]
 
         # Trigger the generation of code for this thing in the parameter class
         from dune.perftool.pdelab.parameter import parameter_function
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 126c4089..28019f58 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -243,16 +243,11 @@ def generate_localoperator_kernels(formdata, namedata):
     # Have a data structure collect the generated kernels
     operator_kernels = {}
 
-    import functools
-    from dune.perftool.generation import generic_context
-    namedata_context = functools.partial(generic_context, "namedata")
-    formdata_context = functools.partial(generic_context, "formdata")
-
-    with formdata_context(formdata):
+    from dune.perftool.generation import global_context
+    with global_context(formdata=formdata, namedata=namedata):
         # Generate the necessary residual methods
         for integral in form.integrals():
-            with namedata_context(namedata):
-                kernel = generate_kernel(integral)
+            kernel = generate_kernel(integral)
             operator_kernels[(integral.integral_type(), 'residual')] = kernel
 
         # Generate the necessary jacobian methods
@@ -265,8 +260,7 @@ def generate_localoperator_kernels(formdata, namedata):
             jacform = expand_derivatives(derivative(form, form.coefficients()[0]))
 
             for integral in jacform.integrals():
-                with namedata_context(namedata):
-                    kernel = generate_kernel(integral)
+                kernel = generate_kernel(integral)
                 operator_kernels[(integral.integral_type(), 'jacobian')] = kernel
 
         # TODO: JacobianApply for matrix-free computations.
-- 
GitLab