diff --git a/force_bdss/io/tests/test_workflow_reader.py b/force_bdss/io/tests/test_workflow_reader.py index 7170e1e12968e2b384587d06eecf33832f53f978..32173b576be69f72007d1aec5f120ad1448056ec 100644 --- a/force_bdss/io/tests/test_workflow_reader.py +++ b/force_bdss/io/tests/test_workflow_reader.py @@ -4,23 +4,69 @@ from six import StringIO import testfixtures +from force_bdss.core.workflow import Workflow from force_bdss.io.workflow_reader import ( WorkflowReader, - InvalidVersionException, InvalidFileException) + InvalidVersionException, InvalidFileException, MissingPluginException, + ModelInstantiationFailedException) from force_bdss.tests.dummy_classes.factory_registry_plugin import \ DummyFactoryRegistryPlugin +from force_bdss.tests.probe_classes.factory_registry_plugin import \ + ProbeFactoryRegistryPlugin class TestWorkflowReader(unittest.TestCase): def setUp(self): self.registry = DummyFactoryRegistryPlugin() - self.wfreader = WorkflowReader(self.registry) + self.working_data = { + "version": "1", + "workflow": { + "mco": { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.dummy_mco", + "model_data": { + "parameters": [ + { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.dummy_mco.parameter" + ".dummy_mco_parameter", + "model_data": {} + } + ] + }, + }, + "execution_layers": [ + [{ + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.dummy_data_source", + "model_data": { + "input_slot_info": [], + "output_slot_info": [], + } + }], + ], + "notification_listeners": [ + { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.dummy_notification_listener", + "model_data": {} + }, + ] + } + } + def test_initialization(self): self.assertEqual(self.wfreader.factory_registry, self.registry) + workflow = self.wfreader.read( + _as_json_stringio(self.working_data) + ) + + self.assertIsInstance(workflow, Workflow) + def test_invalid_version(self): data = { "version": "2", @@ -29,7 +75,7 @@ class TestWorkflowReader(unittest.TestCase): with testfixtures.LogCapture(): with self.assertRaises(InvalidVersionException): - self.wfreader.read(self._as_json_stringio(data)) + self.wfreader.read(_as_json_stringio(data)) def test_absent_version(self): data = { @@ -37,7 +83,7 @@ class TestWorkflowReader(unittest.TestCase): with testfixtures.LogCapture(): with self.assertRaises(InvalidFileException): - self.wfreader.read(self._as_json_stringio(data)) + self.wfreader.read(_as_json_stringio(data)) def test_missing_key(self): data = { @@ -47,11 +93,117 @@ class TestWorkflowReader(unittest.TestCase): with testfixtures.LogCapture(): with self.assertRaises(InvalidFileException): - self.wfreader.read(self._as_json_stringio(data)) + self.wfreader.read(_as_json_stringio(data)) + + def test_missing_plugin_mco(self): + data = self.working_data + data["workflow"]["mco"]["id"] = "missing_mco" + + with self.assertRaises(MissingPluginException): + self.wfreader.read(_as_json_stringio(data)) + + def test_missing_plugin_mco_parameter(self): + data = self.working_data + data["workflow"]["mco"]["model_data"]["parameters"][0]["id"] = \ + "missing_parameter" + + with self.assertRaises(MissingPluginException): + self.wfreader.read(_as_json_stringio(data)) + + def test_missing_plugin_notification_listener(self): + data = self.working_data + data["workflow"]["notification_listeners"][0]["id"] = \ + "missing_nl" + + with self.assertRaises(MissingPluginException): + self.wfreader.read(_as_json_stringio(data)) + + def test_missing_plugin_data_source(self): + data = self.working_data + data["workflow"]["execution_layers"][0][0]["id"] = \ + "missing_ds" + + with self.assertRaises(MissingPluginException): + self.wfreader.read(_as_json_stringio(data)) + + +class TestModelCreationFailure(unittest.TestCase): + def setUp(self): + self.registry = ProbeFactoryRegistryPlugin() + self.wfreader = WorkflowReader(self.registry) + + self.working_data = { + "version": "1", + "workflow": { + "mco": { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.probe_mco", + "model_data": { + "parameters": [ + { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.probe_mco.parameter" + ".probe_mco_parameter", + "model_data": {} + } + ] + }, + }, + "execution_layers": [ + [{ + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.probe_data_source", + "model_data": { + "input_slot_info": [], + "output_slot_info": [], + } + }], + ], + "notification_listeners": [ + { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.probe_notification_listener", + "model_data": {} + }, + ] + } + } + + def test_basic_probe_loading(self): + self.wfreader.read( + _as_json_stringio(self.working_data) + ) + + def test_data_source_model_throws(self): + self.registry.data_source_factories[0].raises_on_create_model = True + with testfixtures.LogCapture(): + with self.assertRaises(ModelInstantiationFailedException): + self.wfreader.read( + _as_json_stringio(self.working_data) + ) + + def test_mco_model_throws(self): + self.registry.mco_factories[0].raises_on_create_model = True + with testfixtures.LogCapture(): + with self.assertRaises(ModelInstantiationFailedException): + self.wfreader.read( + _as_json_stringio(self.working_data) + ) + + def test_notification_listener_throws(self): + factory = self.registry.notification_listener_factories[0] + factory.raises_on_create_model = True + + with testfixtures.LogCapture(): + with self.assertRaises(ModelInstantiationFailedException): + self.wfreader.read( + _as_json_stringio(self.working_data) + ) + +def _as_json_stringio(data): + fp = StringIO() + json.dump(data, fp) + fp.seek(0) - def _as_json_stringio(self, data): - fp = StringIO() - json.dump(data, fp) - fp.seek(0) + return fp - return fp diff --git a/force_bdss/tests/probe_classes/data_source.py b/force_bdss/tests/probe_classes/data_source.py index 5b6119ead303e72991805dc302d84c920333ee27..65340f1df98a4923d1ff3fe2d8d4937423bad2a3 100644 --- a/force_bdss/tests/probe_classes/data_source.py +++ b/force_bdss/tests/probe_classes/data_source.py @@ -54,8 +54,11 @@ class ProbeDataSourceFactory(BaseDataSourceFactory): input_slots_size = Int(0) output_slots_size = Int(0) + raises_on_create_model = Bool(False) + raises_on_create_data_source = Bool(False) + def get_identifier(self): - return "test_ds" + return "probe_data_source" def get_name(self): return "test_data_source" @@ -67,6 +70,9 @@ class ProbeDataSourceFactory(BaseDataSourceFactory): return ProbeDataSource def create_model(self, model_data=None): + if self.raises_on_create_model: + raise Exception("ProbeDataSourceFactory.create_model") + if model_data is None: model_data = {} return self.model_class( @@ -79,6 +85,9 @@ class ProbeDataSourceFactory(BaseDataSourceFactory): ) def create_data_source(self): + if self.raises_on_create_data_source: + raise Exception("ProbeDataSourceFactory.create_data_source") + return self.data_source_class( factory=self, run_function=self.run_function, diff --git a/force_bdss/tests/probe_classes/mco.py b/force_bdss/tests/probe_classes/mco.py index 3b7bca244dbad8694d2f76a0bec3674387e60495..0c76d20491b0c67e207473853a87dd52e038fd1c 100644 --- a/force_bdss/tests/probe_classes/mco.py +++ b/force_bdss/tests/probe_classes/mco.py @@ -39,7 +39,7 @@ class ProbeParameterFactory(BaseMCOParameterFactory): return "Probe parameter" def get_identifier(self): - return "probe_parameter" + return "probe_mco_parameter" def get_description(self): return "Probe parameter" @@ -67,8 +67,12 @@ class ProbeMCOCommunicator(BaseMCOCommunicator): class ProbeMCOFactory(BaseMCOFactory): nb_output_data_values = Int(0) + raises_on_create_model = Bool(False) + raises_on_create_optimizer = Bool(False) + raises_on_create_communicator = Bool(False) + def get_identifier(self): - return "test_mco" + return "probe_mco" def get_model_class(self): return ProbeMCOModel @@ -83,9 +87,27 @@ class ProbeMCOFactory(BaseMCOFactory): return "testmco" def create_communicator(self): + if self.raises_on_create_communicator: + raise Exception("ProbeMCOFactory.create_communicator") + return self.communicator_class( self, nb_output_data_values=self.nb_output_data_values) + def create_model(self, model_data): + if self.raises_on_create_model: + raise Exception("ProbeMCOFactory.create_model") + + if model_data is None: + model_data = {} + + return self.model_class(self, **model_data) + + def create_optimizer(self): + if self.raises_on_create_optimizer: + raise Exception("ProbeMCOFactory.create_optimizer") + + return self.optimizer_class(self) + def parameter_factories(self): return [ProbeParameterFactory(mco_factory=self)] diff --git a/force_bdss/tests/probe_classes/notification_listener.py b/force_bdss/tests/probe_classes/notification_listener.py index 5de05f50c9a4a7d16879db9cc4271f6efa345d44..3e3c987b26d5730671a4e46897cb649cabf897ce 100644 --- a/force_bdss/tests/probe_classes/notification_listener.py +++ b/force_bdss/tests/probe_classes/notification_listener.py @@ -46,11 +46,14 @@ class ProbeNotificationListenerFactory(BaseNotificationListenerFactory): deliver_function = Function(default_value=pass_function) finalize_function = Function(default_value=pass_function) + raises_on_create_model = Bool(False) + raises_on_create_listener = Bool(False) + def get_name(self): return "test_notification_listener" def get_identifier(self): - return "test_nl" + return "probe_notification_listener" def get_listener_class(self): return ProbeNotificationListener @@ -58,7 +61,19 @@ class ProbeNotificationListenerFactory(BaseNotificationListenerFactory): def get_model_class(self): return ProbeNotificationListenerModel + def create_model(self, model_data=None): + if self.raises_on_create_model: + raise Exception("ProbeNotificationListenerFactory.create_model") + + if model_data is None: + model_data = {} + + return self.model_class(self, **model_data) + def create_listener(self): + if self.raises_on_create_listener: + raise Exception("ProbeNotificationListenerFactory.create_listener") + return self.listener_class( self, initialize_function=self.initialize_function,