diff --git a/force_bdss/io/tests/test_workflow_writer.py b/force_bdss/io/tests/test_workflow_writer.py index b3641a61be3b2c6a12bd636175f472b27000b53b..6dac2cbb470bbd5ac8064764517f9876e22d5ff6 100644 --- a/force_bdss/io/tests/test_workflow_writer.py +++ b/force_bdss/io/tests/test_workflow_writer.py @@ -14,7 +14,7 @@ from force_bdss.tests.dummy_classes.factory_registry_plugin import \ DummyFactoryRegistryPlugin from force_bdss.io.workflow_writer import WorkflowWriter, traits_to_dict,\ - pop_traits_version + pop_recursive from force_bdss.core.workflow import Workflow @@ -89,13 +89,17 @@ class TestWorkflowWriter(unittest.TestCase): self.assertEqual(traits_to_dict(mock_traits), {"foo": "bar"}) - def test_pop_traits_version(self): + def test_pop_recursive(self): - test_dictionary = {'Entry1': {'Entry1-1': 4, '__traits_version__': 67}, - 'Entry2': [3, 'a', {'Entry2-1': 5, - '__traits_version__': 9001}], - '__traits_version__': 13} - result_dictionary = {'Entry1': {'Entry1-1': 4, }, - 'Entry2': [3, 'a', {'Entry2-1': 5, }], } - traitless_dictionary = pop_traits_version(test_dictionary) - self.assertEqual(traitless_dictionary, result_dictionary) + test_dictionary = {'K1': {'K1': 'V1', 'K2': 'V2', 'K3': 'V3'}, + 'K2': ['V1', 'V2', {'K1': 'V1', 'K2': 'V2', + 'K3': 'V3'}], + 'K3': 'V3', + 'K4': ('V1', {'K3': 'V3'},)} + + result_dictionary = {'K1': {'K1': 'V1', 'K2': 'V2', }, + 'K2': ['V1', 'V2', {'K1': 'V1', 'K2': 'V2', }], + 'K4': ('V1', {},)} + + test_result_dictionary = pop_recursive(test_dictionary, ) + self.assertEqual(test_result_dictionary, result_dictionary) diff --git a/force_bdss/io/workflow_writer.py b/force_bdss/io/workflow_writer.py index 404fcb29d7e6a8fef7d7a71895cf743b6ff5b29b..98c5c045f28bc265a9cee90051f52e90092e6469 100644 --- a/force_bdss/io/workflow_writer.py +++ b/force_bdss/io/workflow_writer.py @@ -95,28 +95,28 @@ def traits_to_dict(traits_obj): state = traits_obj.__getstate__() - state = pop_traits_version(state) + state = pop_recursive(state,'__traits_version__') return state -def pop_traits_version(dictionary): - """Recursively remove the __traits_version__ attribute - from dictionary.""" +def pop_recursive(dictionary,remove_key): + """Recursively remove a named key from dictionary and any contained + dictionaries.""" try: - dictionary.pop("__traits_version__") + dictionary.pop(remove_key) except KeyError: pass for key in dictionary: - # If we have a dict, remove the traits version + # If remove_key is in the dict, remove it if isinstance(dictionary[key], dict): - pop_traits_version(dictionary[key]) - # If we have a non-dict which contains a dict, remove traits from - # that as well + pop_recursive(dictionary[key], remove_key) + # If we have a non-dict iterable which contains a dict, + # call pop.(remove_key) from that as well elif isinstance(dictionary[key], Iterable): for element in dictionary[key]: if isinstance(element, dict): - pop_traits_version(element) + pop_recursive(element, remove_key) return dictionary