Skip to content
Snippets Groups Projects
Commit f3545072 authored by Stefano Borini's avatar Stefano Borini
Browse files

Added tests and modified probe classes for testing the exception in create_model

parent bfe8d5a3
No related branches found
No related tags found
1 merge request!137Communicative workflow reader
......@@ -4,23 +4,69 @@ from six import StringIO
import testfixtures
from force_bdss.core.workflow import Workflow
from force_bdss.io.workflow_reader import (
WorkflowReader,
InvalidVersionException, InvalidFileException)
InvalidVersionException, InvalidFileException, MissingPluginException,
ModelInstantiationFailedException)
from force_bdss.tests.dummy_classes.factory_registry_plugin import \
DummyFactoryRegistryPlugin
from force_bdss.tests.probe_classes.factory_registry_plugin import \
ProbeFactoryRegistryPlugin
class TestWorkflowReader(unittest.TestCase):
def setUp(self):
self.registry = DummyFactoryRegistryPlugin()
self.wfreader = WorkflowReader(self.registry)
self.working_data = {
"version": "1",
"workflow": {
"mco": {
"id": "force.bdss.enthought.plugin.test.v0"
".factory.dummy_mco",
"model_data": {
"parameters": [
{
"id": "force.bdss.enthought.plugin.test.v0"
".factory.dummy_mco.parameter"
".dummy_mco_parameter",
"model_data": {}
}
]
},
},
"execution_layers": [
[{
"id": "force.bdss.enthought.plugin.test.v0"
".factory.dummy_data_source",
"model_data": {
"input_slot_info": [],
"output_slot_info": [],
}
}],
],
"notification_listeners": [
{
"id": "force.bdss.enthought.plugin.test.v0"
".factory.dummy_notification_listener",
"model_data": {}
},
]
}
}
def test_initialization(self):
self.assertEqual(self.wfreader.factory_registry,
self.registry)
workflow = self.wfreader.read(
_as_json_stringio(self.working_data)
)
self.assertIsInstance(workflow, Workflow)
def test_invalid_version(self):
data = {
"version": "2",
......@@ -29,7 +75,7 @@ class TestWorkflowReader(unittest.TestCase):
with testfixtures.LogCapture():
with self.assertRaises(InvalidVersionException):
self.wfreader.read(self._as_json_stringio(data))
self.wfreader.read(_as_json_stringio(data))
def test_absent_version(self):
data = {
......@@ -37,7 +83,7 @@ class TestWorkflowReader(unittest.TestCase):
with testfixtures.LogCapture():
with self.assertRaises(InvalidFileException):
self.wfreader.read(self._as_json_stringio(data))
self.wfreader.read(_as_json_stringio(data))
def test_missing_key(self):
data = {
......@@ -47,11 +93,117 @@ class TestWorkflowReader(unittest.TestCase):
with testfixtures.LogCapture():
with self.assertRaises(InvalidFileException):
self.wfreader.read(self._as_json_stringio(data))
self.wfreader.read(_as_json_stringio(data))
def test_missing_plugin_mco(self):
data = self.working_data
data["workflow"]["mco"]["id"] = "missing_mco"
with self.assertRaises(MissingPluginException):
self.wfreader.read(_as_json_stringio(data))
def test_missing_plugin_mco_parameter(self):
data = self.working_data
data["workflow"]["mco"]["model_data"]["parameters"][0]["id"] = \
"missing_parameter"
with self.assertRaises(MissingPluginException):
self.wfreader.read(_as_json_stringio(data))
def test_missing_plugin_notification_listener(self):
data = self.working_data
data["workflow"]["notification_listeners"][0]["id"] = \
"missing_nl"
with self.assertRaises(MissingPluginException):
self.wfreader.read(_as_json_stringio(data))
def test_missing_plugin_data_source(self):
data = self.working_data
data["workflow"]["execution_layers"][0][0]["id"] = \
"missing_ds"
with self.assertRaises(MissingPluginException):
self.wfreader.read(_as_json_stringio(data))
class TestModelCreationFailure(unittest.TestCase):
def setUp(self):
self.registry = ProbeFactoryRegistryPlugin()
self.wfreader = WorkflowReader(self.registry)
self.working_data = {
"version": "1",
"workflow": {
"mco": {
"id": "force.bdss.enthought.plugin.test.v0"
".factory.probe_mco",
"model_data": {
"parameters": [
{
"id": "force.bdss.enthought.plugin.test.v0"
".factory.probe_mco.parameter"
".probe_mco_parameter",
"model_data": {}
}
]
},
},
"execution_layers": [
[{
"id": "force.bdss.enthought.plugin.test.v0"
".factory.probe_data_source",
"model_data": {
"input_slot_info": [],
"output_slot_info": [],
}
}],
],
"notification_listeners": [
{
"id": "force.bdss.enthought.plugin.test.v0"
".factory.probe_notification_listener",
"model_data": {}
},
]
}
}
def test_basic_probe_loading(self):
self.wfreader.read(
_as_json_stringio(self.working_data)
)
def test_data_source_model_throws(self):
self.registry.data_source_factories[0].raises_on_create_model = True
with testfixtures.LogCapture():
with self.assertRaises(ModelInstantiationFailedException):
self.wfreader.read(
_as_json_stringio(self.working_data)
)
def test_mco_model_throws(self):
self.registry.mco_factories[0].raises_on_create_model = True
with testfixtures.LogCapture():
with self.assertRaises(ModelInstantiationFailedException):
self.wfreader.read(
_as_json_stringio(self.working_data)
)
def test_notification_listener_throws(self):
factory = self.registry.notification_listener_factories[0]
factory.raises_on_create_model = True
with testfixtures.LogCapture():
with self.assertRaises(ModelInstantiationFailedException):
self.wfreader.read(
_as_json_stringio(self.working_data)
)
def _as_json_stringio(data):
fp = StringIO()
json.dump(data, fp)
fp.seek(0)
def _as_json_stringio(self, data):
fp = StringIO()
json.dump(data, fp)
fp.seek(0)
return fp
return fp
......@@ -54,8 +54,11 @@ class ProbeDataSourceFactory(BaseDataSourceFactory):
input_slots_size = Int(0)
output_slots_size = Int(0)
raises_on_create_model = Bool(False)
raises_on_create_data_source = Bool(False)
def get_identifier(self):
return "test_ds"
return "probe_data_source"
def get_name(self):
return "test_data_source"
......@@ -67,6 +70,9 @@ class ProbeDataSourceFactory(BaseDataSourceFactory):
return ProbeDataSource
def create_model(self, model_data=None):
if self.raises_on_create_model:
raise Exception("ProbeDataSourceFactory.create_model")
if model_data is None:
model_data = {}
return self.model_class(
......@@ -79,6 +85,9 @@ class ProbeDataSourceFactory(BaseDataSourceFactory):
)
def create_data_source(self):
if self.raises_on_create_data_source:
raise Exception("ProbeDataSourceFactory.create_data_source")
return self.data_source_class(
factory=self,
run_function=self.run_function,
......
......@@ -39,7 +39,7 @@ class ProbeParameterFactory(BaseMCOParameterFactory):
return "Probe parameter"
def get_identifier(self):
return "probe_parameter"
return "probe_mco_parameter"
def get_description(self):
return "Probe parameter"
......@@ -67,8 +67,12 @@ class ProbeMCOCommunicator(BaseMCOCommunicator):
class ProbeMCOFactory(BaseMCOFactory):
nb_output_data_values = Int(0)
raises_on_create_model = Bool(False)
raises_on_create_optimizer = Bool(False)
raises_on_create_communicator = Bool(False)
def get_identifier(self):
return "test_mco"
return "probe_mco"
def get_model_class(self):
return ProbeMCOModel
......@@ -83,9 +87,27 @@ class ProbeMCOFactory(BaseMCOFactory):
return "testmco"
def create_communicator(self):
if self.raises_on_create_communicator:
raise Exception("ProbeMCOFactory.create_communicator")
return self.communicator_class(
self,
nb_output_data_values=self.nb_output_data_values)
def create_model(self, model_data):
if self.raises_on_create_model:
raise Exception("ProbeMCOFactory.create_model")
if model_data is None:
model_data = {}
return self.model_class(self, **model_data)
def create_optimizer(self):
if self.raises_on_create_optimizer:
raise Exception("ProbeMCOFactory.create_optimizer")
return self.optimizer_class(self)
def parameter_factories(self):
return [ProbeParameterFactory(mco_factory=self)]
......@@ -46,11 +46,14 @@ class ProbeNotificationListenerFactory(BaseNotificationListenerFactory):
deliver_function = Function(default_value=pass_function)
finalize_function = Function(default_value=pass_function)
raises_on_create_model = Bool(False)
raises_on_create_listener = Bool(False)
def get_name(self):
return "test_notification_listener"
def get_identifier(self):
return "test_nl"
return "probe_notification_listener"
def get_listener_class(self):
return ProbeNotificationListener
......@@ -58,7 +61,19 @@ class ProbeNotificationListenerFactory(BaseNotificationListenerFactory):
def get_model_class(self):
return ProbeNotificationListenerModel
def create_model(self, model_data=None):
if self.raises_on_create_model:
raise Exception("ProbeNotificationListenerFactory.create_model")
if model_data is None:
model_data = {}
return self.model_class(self, **model_data)
def create_listener(self):
if self.raises_on_create_listener:
raise Exception("ProbeNotificationListenerFactory.create_listener")
return self.listener_class(
self,
initialize_function=self.initialize_function,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment