From f3545072466b6610a30d05d9b6c679ea65a9f2b6 Mon Sep 17 00:00:00 2001 From: Stefano Borini <sborini@enthought.com> Date: Mon, 21 May 2018 15:05:56 +0100 Subject: [PATCH] Added tests and modified probe classes for testing the exception in create_model --- force_bdss/io/tests/test_workflow_reader.py | 172 +++++++++++++++++- force_bdss/tests/probe_classes/data_source.py | 11 +- force_bdss/tests/probe_classes/mco.py | 26 ++- .../probe_classes/notification_listener.py | 17 +- 4 files changed, 212 insertions(+), 14 deletions(-) diff --git a/force_bdss/io/tests/test_workflow_reader.py b/force_bdss/io/tests/test_workflow_reader.py index 7170e1e..32173b5 100644 --- a/force_bdss/io/tests/test_workflow_reader.py +++ b/force_bdss/io/tests/test_workflow_reader.py @@ -4,23 +4,69 @@ from six import StringIO import testfixtures +from force_bdss.core.workflow import Workflow from force_bdss.io.workflow_reader import ( WorkflowReader, - InvalidVersionException, InvalidFileException) + InvalidVersionException, InvalidFileException, MissingPluginException, + ModelInstantiationFailedException) from force_bdss.tests.dummy_classes.factory_registry_plugin import \ DummyFactoryRegistryPlugin +from force_bdss.tests.probe_classes.factory_registry_plugin import \ + ProbeFactoryRegistryPlugin class TestWorkflowReader(unittest.TestCase): def setUp(self): self.registry = DummyFactoryRegistryPlugin() - self.wfreader = WorkflowReader(self.registry) + self.working_data = { + "version": "1", + "workflow": { + "mco": { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.dummy_mco", + "model_data": { + "parameters": [ + { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.dummy_mco.parameter" + ".dummy_mco_parameter", + "model_data": {} + } + ] + }, + }, + "execution_layers": [ + [{ + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.dummy_data_source", + "model_data": { + "input_slot_info": [], + "output_slot_info": [], + } + }], + ], + "notification_listeners": [ + { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.dummy_notification_listener", + "model_data": {} + }, + ] + } + } + def test_initialization(self): self.assertEqual(self.wfreader.factory_registry, self.registry) + workflow = self.wfreader.read( + _as_json_stringio(self.working_data) + ) + + self.assertIsInstance(workflow, Workflow) + def test_invalid_version(self): data = { "version": "2", @@ -29,7 +75,7 @@ class TestWorkflowReader(unittest.TestCase): with testfixtures.LogCapture(): with self.assertRaises(InvalidVersionException): - self.wfreader.read(self._as_json_stringio(data)) + self.wfreader.read(_as_json_stringio(data)) def test_absent_version(self): data = { @@ -37,7 +83,7 @@ class TestWorkflowReader(unittest.TestCase): with testfixtures.LogCapture(): with self.assertRaises(InvalidFileException): - self.wfreader.read(self._as_json_stringio(data)) + self.wfreader.read(_as_json_stringio(data)) def test_missing_key(self): data = { @@ -47,11 +93,117 @@ class TestWorkflowReader(unittest.TestCase): with testfixtures.LogCapture(): with self.assertRaises(InvalidFileException): - self.wfreader.read(self._as_json_stringio(data)) + self.wfreader.read(_as_json_stringio(data)) + + def test_missing_plugin_mco(self): + data = self.working_data + data["workflow"]["mco"]["id"] = "missing_mco" + + with self.assertRaises(MissingPluginException): + self.wfreader.read(_as_json_stringio(data)) + + def test_missing_plugin_mco_parameter(self): + data = self.working_data + data["workflow"]["mco"]["model_data"]["parameters"][0]["id"] = \ + "missing_parameter" + + with self.assertRaises(MissingPluginException): + self.wfreader.read(_as_json_stringio(data)) + + def test_missing_plugin_notification_listener(self): + data = self.working_data + data["workflow"]["notification_listeners"][0]["id"] = \ + "missing_nl" + + with self.assertRaises(MissingPluginException): + self.wfreader.read(_as_json_stringio(data)) + + def test_missing_plugin_data_source(self): + data = self.working_data + data["workflow"]["execution_layers"][0][0]["id"] = \ + "missing_ds" + + with self.assertRaises(MissingPluginException): + self.wfreader.read(_as_json_stringio(data)) + + +class TestModelCreationFailure(unittest.TestCase): + def setUp(self): + self.registry = ProbeFactoryRegistryPlugin() + self.wfreader = WorkflowReader(self.registry) + + self.working_data = { + "version": "1", + "workflow": { + "mco": { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.probe_mco", + "model_data": { + "parameters": [ + { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.probe_mco.parameter" + ".probe_mco_parameter", + "model_data": {} + } + ] + }, + }, + "execution_layers": [ + [{ + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.probe_data_source", + "model_data": { + "input_slot_info": [], + "output_slot_info": [], + } + }], + ], + "notification_listeners": [ + { + "id": "force.bdss.enthought.plugin.test.v0" + ".factory.probe_notification_listener", + "model_data": {} + }, + ] + } + } + + def test_basic_probe_loading(self): + self.wfreader.read( + _as_json_stringio(self.working_data) + ) + + def test_data_source_model_throws(self): + self.registry.data_source_factories[0].raises_on_create_model = True + with testfixtures.LogCapture(): + with self.assertRaises(ModelInstantiationFailedException): + self.wfreader.read( + _as_json_stringio(self.working_data) + ) + + def test_mco_model_throws(self): + self.registry.mco_factories[0].raises_on_create_model = True + with testfixtures.LogCapture(): + with self.assertRaises(ModelInstantiationFailedException): + self.wfreader.read( + _as_json_stringio(self.working_data) + ) + + def test_notification_listener_throws(self): + factory = self.registry.notification_listener_factories[0] + factory.raises_on_create_model = True + + with testfixtures.LogCapture(): + with self.assertRaises(ModelInstantiationFailedException): + self.wfreader.read( + _as_json_stringio(self.working_data) + ) + +def _as_json_stringio(data): + fp = StringIO() + json.dump(data, fp) + fp.seek(0) - def _as_json_stringio(self, data): - fp = StringIO() - json.dump(data, fp) - fp.seek(0) + return fp - return fp diff --git a/force_bdss/tests/probe_classes/data_source.py b/force_bdss/tests/probe_classes/data_source.py index 5b6119e..65340f1 100644 --- a/force_bdss/tests/probe_classes/data_source.py +++ b/force_bdss/tests/probe_classes/data_source.py @@ -54,8 +54,11 @@ class ProbeDataSourceFactory(BaseDataSourceFactory): input_slots_size = Int(0) output_slots_size = Int(0) + raises_on_create_model = Bool(False) + raises_on_create_data_source = Bool(False) + def get_identifier(self): - return "test_ds" + return "probe_data_source" def get_name(self): return "test_data_source" @@ -67,6 +70,9 @@ class ProbeDataSourceFactory(BaseDataSourceFactory): return ProbeDataSource def create_model(self, model_data=None): + if self.raises_on_create_model: + raise Exception("ProbeDataSourceFactory.create_model") + if model_data is None: model_data = {} return self.model_class( @@ -79,6 +85,9 @@ class ProbeDataSourceFactory(BaseDataSourceFactory): ) def create_data_source(self): + if self.raises_on_create_data_source: + raise Exception("ProbeDataSourceFactory.create_data_source") + return self.data_source_class( factory=self, run_function=self.run_function, diff --git a/force_bdss/tests/probe_classes/mco.py b/force_bdss/tests/probe_classes/mco.py index 3b7bca2..0c76d20 100644 --- a/force_bdss/tests/probe_classes/mco.py +++ b/force_bdss/tests/probe_classes/mco.py @@ -39,7 +39,7 @@ class ProbeParameterFactory(BaseMCOParameterFactory): return "Probe parameter" def get_identifier(self): - return "probe_parameter" + return "probe_mco_parameter" def get_description(self): return "Probe parameter" @@ -67,8 +67,12 @@ class ProbeMCOCommunicator(BaseMCOCommunicator): class ProbeMCOFactory(BaseMCOFactory): nb_output_data_values = Int(0) + raises_on_create_model = Bool(False) + raises_on_create_optimizer = Bool(False) + raises_on_create_communicator = Bool(False) + def get_identifier(self): - return "test_mco" + return "probe_mco" def get_model_class(self): return ProbeMCOModel @@ -83,9 +87,27 @@ class ProbeMCOFactory(BaseMCOFactory): return "testmco" def create_communicator(self): + if self.raises_on_create_communicator: + raise Exception("ProbeMCOFactory.create_communicator") + return self.communicator_class( self, nb_output_data_values=self.nb_output_data_values) + def create_model(self, model_data): + if self.raises_on_create_model: + raise Exception("ProbeMCOFactory.create_model") + + if model_data is None: + model_data = {} + + return self.model_class(self, **model_data) + + def create_optimizer(self): + if self.raises_on_create_optimizer: + raise Exception("ProbeMCOFactory.create_optimizer") + + return self.optimizer_class(self) + def parameter_factories(self): return [ProbeParameterFactory(mco_factory=self)] diff --git a/force_bdss/tests/probe_classes/notification_listener.py b/force_bdss/tests/probe_classes/notification_listener.py index 5de05f5..3e3c987 100644 --- a/force_bdss/tests/probe_classes/notification_listener.py +++ b/force_bdss/tests/probe_classes/notification_listener.py @@ -46,11 +46,14 @@ class ProbeNotificationListenerFactory(BaseNotificationListenerFactory): deliver_function = Function(default_value=pass_function) finalize_function = Function(default_value=pass_function) + raises_on_create_model = Bool(False) + raises_on_create_listener = Bool(False) + def get_name(self): return "test_notification_listener" def get_identifier(self): - return "test_nl" + return "probe_notification_listener" def get_listener_class(self): return ProbeNotificationListener @@ -58,7 +61,19 @@ class ProbeNotificationListenerFactory(BaseNotificationListenerFactory): def get_model_class(self): return ProbeNotificationListenerModel + def create_model(self, model_data=None): + if self.raises_on_create_model: + raise Exception("ProbeNotificationListenerFactory.create_model") + + if model_data is None: + model_data = {} + + return self.model_class(self, **model_data) + def create_listener(self): + if self.raises_on_create_listener: + raise Exception("ProbeNotificationListenerFactory.create_listener") + return self.listener_class( self, initialize_function=self.initialize_function, -- GitLab