Skip to content
Snippets Groups Projects
Unverified Commit dd052776 authored by jjenthought's avatar jjenthought Committed by GitHub
Browse files

Merge pull request #152 from force-h2020/163-save-load-behaviour-bdss

Merged #152 
parents dc2dbce4 e9053ef7
No related branches found
No related tags found
No related merge requests found
...@@ -13,8 +13,10 @@ from force_bdss.io.workflow_reader import WorkflowReader ...@@ -13,8 +13,10 @@ from force_bdss.io.workflow_reader import WorkflowReader
from force_bdss.tests.dummy_classes.factory_registry_plugin import \ from force_bdss.tests.dummy_classes.factory_registry_plugin import \
DummyFactoryRegistryPlugin DummyFactoryRegistryPlugin
from force_bdss.io.workflow_writer import WorkflowWriter, traits_to_dict from force_bdss.io.workflow_writer import WorkflowWriter, traits_to_dict,\
pop_recursive
from force_bdss.core.workflow import Workflow from force_bdss.core.workflow import Workflow
from force_bdss.core.input_slot_info import InputSlotInfo
class TestWorkflowWriter(unittest.TestCase): class TestWorkflowWriter(unittest.TestCase):
...@@ -87,3 +89,29 @@ class TestWorkflowWriter(unittest.TestCase): ...@@ -87,3 +89,29 @@ class TestWorkflowWriter(unittest.TestCase):
mock_traits.__getstate__ = mock.Mock(return_value={"foo": "bar"}) mock_traits.__getstate__ = mock.Mock(return_value={"foo": "bar"})
self.assertEqual(traits_to_dict(mock_traits), {"foo": "bar"}) self.assertEqual(traits_to_dict(mock_traits), {"foo": "bar"})
def test_traits_to_dict(self):
wfwriter = WorkflowWriter()
wf = self._create_workflow()
exec_layer = wf.execution_layers[0]
exec_layer.data_sources[0].input_slot_info = [InputSlotInfo()]
slotdata = exec_layer.data_sources[0].input_slot_info[0].__getstate__()
self.assertTrue("__traits_version__" in slotdata)
# Calls traits_to_dict for each data source
datastore_list = wfwriter._execution_layer_data(exec_layer)
new_slotdata = datastore_list[0]['model_data']['input_slot_info']
self.assertTrue("__traits_version__" not in new_slotdata)
test_dictionary = {'K1': {'K1': 'V1', 'K2': 'V2', 'K3': 'V3'},
'K2': ['V1', 'V2', {'K1': 'V1', 'K2': 'V2',
'K3': 'V3'}],
'K3': 'V3',
'K4': ('V1', {'K3': 'V3'},)}
result_dictionary = {'K1': {'K1': 'V1', 'K2': 'V2', },
'K2': ['V1', 'V2', {'K1': 'V1', 'K2': 'V2', }],
'K4': ('V1', {},)}
test_result_dictionary = pop_recursive(test_dictionary, 'K3')
self.assertEqual(test_result_dictionary, result_dictionary)
...@@ -91,10 +91,31 @@ class WorkflowWriter(HasStrictTraits): ...@@ -91,10 +91,31 @@ class WorkflowWriter(HasStrictTraits):
def traits_to_dict(traits_obj): def traits_to_dict(traits_obj):
"""Converts a traits class into a dict, removing the pesky """Converts a traits class into a dict, removing the pesky
traits version.""" traits version."""
state = traits_obj.__getstate__() state = traits_obj.__getstate__()
state = pop_recursive(state, '__traits_version__')
return state
def pop_recursive(dictionary, remove_key):
"""Recursively remove a named key from dictionary and any contained
dictionaries."""
try: try:
state.pop("__traits_version__") dictionary.pop(remove_key)
except KeyError: except KeyError:
pass pass
return state for key, value in dictionary.items():
# If remove_key is in the dict, remove it
if isinstance(value, dict):
pop_recursive(value, remove_key)
# If we have a non-dict iterable which contains a dict,
# call pop.(remove_key) from that as well
elif isinstance(value, (tuple, list)):
for element in value:
if isinstance(element, dict):
pop_recursive(element, remove_key)
return dictionary
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment