diff --git a/force_bdss/api.py b/force_bdss/api.py index ff4094d31bcc913cd3971369a9f76fd2db89502a..165dd9bdc1869e369d2f0e80a2669224947941d4 100644 --- a/force_bdss/api.py +++ b/force_bdss/api.py @@ -5,12 +5,25 @@ from .core.data_value import DataValue # noqa from .core.workflow import Workflow # noqa from .core.slot import Slot # noqa from .core.i_factory import IFactory # noqa +from .core.input_slot_info import InputSlotInfo # noqa +from .core.output_slot_info import OutputSlotInfo # noqa +from .core.kpi_specification import KPISpecification # noqa +from .core.execution_layer import ExecutionLayer # noqa +from .core.verifier import verify_workflow # noqa +from .core.verifier import VerifierError # noqa from .data_sources.base_data_source_model import BaseDataSourceModel # noqa from .data_sources.base_data_source import BaseDataSource # noqa from .data_sources.base_data_source_factory import BaseDataSourceFactory # noqa from .data_sources.i_data_source_factory import IDataSourceFactory # noqa +from .factory_registry_plugin import IFactoryRegistryPlugin # noqa +from .factory_registry_plugin import FactoryRegistryPlugin # noqa + +from .io.workflow_reader import WorkflowReader # noqa +from .io.workflow_reader import InvalidFileException # noqa +from .io.workflow_writer import WorkflowWriter # noqa + from .mco.base_mco_model import BaseMCOModel # noqa from .mco.base_mco_communicator import BaseMCOCommunicator # noqa from .mco.base_mco import BaseMCO # noqa diff --git a/force_bdss/io/tests/test_workflow_writer.py b/force_bdss/io/tests/test_workflow_writer.py index d2a9d297039661d38b4ca71311ed5a2d062add1f..379af5a686bc03f3988ac56bf43aa857bfa02193 100644 --- a/force_bdss/io/tests/test_workflow_writer.py +++ b/force_bdss/io/tests/test_workflow_writer.py @@ -13,8 +13,10 @@ from force_bdss.io.workflow_reader import WorkflowReader from force_bdss.tests.dummy_classes.factory_registry_plugin import \ DummyFactoryRegistryPlugin -from force_bdss.io.workflow_writer import WorkflowWriter, traits_to_dict +from force_bdss.io.workflow_writer import WorkflowWriter, traits_to_dict,\ + pop_recursive from force_bdss.core.workflow import Workflow +from force_bdss.core.input_slot_info import InputSlotInfo class TestWorkflowWriter(unittest.TestCase): @@ -87,3 +89,29 @@ class TestWorkflowWriter(unittest.TestCase): mock_traits.__getstate__ = mock.Mock(return_value={"foo": "bar"}) self.assertEqual(traits_to_dict(mock_traits), {"foo": "bar"}) + + def test_traits_to_dict(self): + + wfwriter = WorkflowWriter() + wf = self._create_workflow() + exec_layer = wf.execution_layers[0] + exec_layer.data_sources[0].input_slot_info = [InputSlotInfo()] + slotdata = exec_layer.data_sources[0].input_slot_info[0].__getstate__() + self.assertTrue("__traits_version__" in slotdata) + # Calls traits_to_dict for each data source + datastore_list = wfwriter._execution_layer_data(exec_layer) + new_slotdata = datastore_list[0]['model_data']['input_slot_info'] + self.assertTrue("__traits_version__" not in new_slotdata) + + test_dictionary = {'K1': {'K1': 'V1', 'K2': 'V2', 'K3': 'V3'}, + 'K2': ['V1', 'V2', {'K1': 'V1', 'K2': 'V2', + 'K3': 'V3'}], + 'K3': 'V3', + 'K4': ('V1', {'K3': 'V3'},)} + + result_dictionary = {'K1': {'K1': 'V1', 'K2': 'V2', }, + 'K2': ['V1', 'V2', {'K1': 'V1', 'K2': 'V2', }], + 'K4': ('V1', {},)} + + test_result_dictionary = pop_recursive(test_dictionary, 'K3') + self.assertEqual(test_result_dictionary, result_dictionary) diff --git a/force_bdss/io/workflow_writer.py b/force_bdss/io/workflow_writer.py index aadd19d2156dc98ca3ea2aff5db5fad7e5e40d0b..ddd90ef4ef5979267aa0c5033d838655cb65b2e9 100644 --- a/force_bdss/io/workflow_writer.py +++ b/force_bdss/io/workflow_writer.py @@ -91,10 +91,31 @@ class WorkflowWriter(HasStrictTraits): def traits_to_dict(traits_obj): """Converts a traits class into a dict, removing the pesky traits version.""" + state = traits_obj.__getstate__() + + state = pop_recursive(state, '__traits_version__') + + return state + + +def pop_recursive(dictionary, remove_key): + """Recursively remove a named key from dictionary and any contained + dictionaries.""" try: - state.pop("__traits_version__") + dictionary.pop(remove_key) except KeyError: pass - return state + for key, value in dictionary.items(): + # If remove_key is in the dict, remove it + if isinstance(value, dict): + pop_recursive(value, remove_key) + # If we have a non-dict iterable which contains a dict, + # call pop.(remove_key) from that as well + elif isinstance(value, (tuple, list)): + for element in value: + if isinstance(element, dict): + pop_recursive(element, remove_key) + + return dictionary