From 968f623f16cb48ed395249de7cb5b0c4fbf0c27b Mon Sep 17 00:00:00 2001
From: James Johnson <jjohnson@enthought.com>
Date: Fri, 15 Jun 2018 15:31:34 +0100
Subject: [PATCH] Made recursive key remover more generic

---
 force_bdss/io/tests/test_workflow_writer.py | 24 ++++++++++++---------
 force_bdss/io/workflow_writer.py            | 20 ++++++++---------
 2 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/force_bdss/io/tests/test_workflow_writer.py b/force_bdss/io/tests/test_workflow_writer.py
index b3641a6..6dac2cb 100644
--- a/force_bdss/io/tests/test_workflow_writer.py
+++ b/force_bdss/io/tests/test_workflow_writer.py
@@ -14,7 +14,7 @@ from force_bdss.tests.dummy_classes.factory_registry_plugin import \
     DummyFactoryRegistryPlugin
 
 from force_bdss.io.workflow_writer import WorkflowWriter, traits_to_dict,\
-    pop_traits_version
+    pop_recursive
 from force_bdss.core.workflow import Workflow
 
 
@@ -89,13 +89,17 @@ class TestWorkflowWriter(unittest.TestCase):
 
         self.assertEqual(traits_to_dict(mock_traits), {"foo": "bar"})
 
-    def test_pop_traits_version(self):
+    def test_pop_recursive(self):
 
-        test_dictionary = {'Entry1': {'Entry1-1': 4, '__traits_version__': 67},
-                           'Entry2': [3, 'a', {'Entry2-1': 5,
-                                               '__traits_version__': 9001}],
-                           '__traits_version__': 13}
-        result_dictionary = {'Entry1': {'Entry1-1': 4, },
-                             'Entry2': [3, 'a', {'Entry2-1': 5, }], }
-        traitless_dictionary = pop_traits_version(test_dictionary)
-        self.assertEqual(traitless_dictionary, result_dictionary)
+        test_dictionary = {'K1': {'K1': 'V1', 'K2': 'V2', 'K3': 'V3'},
+                           'K2': ['V1', 'V2', {'K1': 'V1', 'K2': 'V2',
+                                               'K3': 'V3'}],
+                           'K3': 'V3',
+                           'K4': ('V1', {'K3': 'V3'},)}
+
+        result_dictionary = {'K1': {'K1': 'V1', 'K2': 'V2', },
+                             'K2': ['V1', 'V2', {'K1': 'V1', 'K2': 'V2', }],
+                             'K4': ('V1', {},)}
+
+        test_result_dictionary = pop_recursive(test_dictionary, )
+        self.assertEqual(test_result_dictionary, result_dictionary)
diff --git a/force_bdss/io/workflow_writer.py b/force_bdss/io/workflow_writer.py
index 404fcb2..98c5c04 100644
--- a/force_bdss/io/workflow_writer.py
+++ b/force_bdss/io/workflow_writer.py
@@ -95,28 +95,28 @@ def traits_to_dict(traits_obj):
 
     state = traits_obj.__getstate__()
 
-    state = pop_traits_version(state)
+    state = pop_recursive(state,'__traits_version__')
 
     return state
 
 
-def pop_traits_version(dictionary):
-    """Recursively remove the __traits_version__ attribute
-    from dictionary."""
+def pop_recursive(dictionary,remove_key):
+    """Recursively remove a named key from dictionary and any contained
+    dictionaries."""
     try:
-        dictionary.pop("__traits_version__")
+        dictionary.pop(remove_key)
     except KeyError:
         pass
 
     for key in dictionary:
-        # If we have a dict, remove the traits version
+        # If remove_key is in the dict, remove it
         if isinstance(dictionary[key], dict):
-            pop_traits_version(dictionary[key])
-        # If we have a non-dict which contains a dict, remove traits from
-        # that as well
+            pop_recursive(dictionary[key], remove_key)
+        # If we have a non-dict iterable which contains a dict,
+        # call pop.(remove_key) from that as well
         elif isinstance(dictionary[key], Iterable):
             for element in dictionary[key]:
                 if isinstance(element, dict):
-                    pop_traits_version(element)
+                    pop_recursive(element, remove_key)
 
     return dictionary
-- 
GitLab