Skip to content
Snippets Groups Projects
Commit 31ccc154 authored by martinRenou's avatar martinRenou
Browse files

Start to clean the tests

parent cbaa5075
No related branches found
No related tags found
1 merge request!101Create probe classes for tests
......@@ -19,9 +19,11 @@ class BaseMCOCommunicator(ABCHasStrictTraits):
#: A reference to the factory
factory = Instance(IMCOFactory)
def __init__(self, factory):
def __init__(self, factory, *args, **kwargs):
self.factory = factory
super(BaseMCOCommunicator, self).__init__(*args, **kwargs)
@abc.abstractmethod
def receive_from_mco(self, model):
"""
......
......@@ -4,6 +4,11 @@ from traits.api import Float
from force_bdss.tests.probe_classes.factory_registry_plugin import \
ProbeFactoryRegistryPlugin
from force_bdss.tests.probe_classes.mco import (
ProbeMCOCommunicator, ProbeMCOFactory)
from force_bdss.tests.probe_classes.data_source import (
ProbeDataSourceFactory)
from force_bdss.core.input_slot_map import InputSlotMap
from force_bdss.core.data_value import DataValue
......@@ -12,25 +17,11 @@ from force_bdss.data_sources.base_data_source import BaseDataSource
from force_bdss.data_sources.base_data_source_factory import \
BaseDataSourceFactory
from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel
from force_bdss.ids import mco_parameter_id, factory_id
from force_bdss.ids import mco_parameter_id
from force_bdss.kpi.base_kpi_calculator import BaseKPICalculator
from force_bdss.kpi.base_kpi_calculator_factory import BaseKPICalculatorFactory
from force_bdss.kpi.base_kpi_calculator_model import BaseKPICalculatorModel
from force_bdss.mco.base_mco import BaseMCO
from force_bdss.mco.base_mco_factory import BaseMCOFactory
from force_bdss.mco.base_mco_communicator import BaseMCOCommunicator
from force_bdss.mco.base_mco_model import BaseMCOModel
from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter
from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory
from force_bdss.notification_listeners.base_notification_listener import \
BaseNotificationListener
from force_bdss.notification_listeners.base_notification_listener_factory \
import \
BaseNotificationListenerFactory
from force_bdss.notification_listeners.base_notification_listener_model \
import \
BaseNotificationListenerModel
from force_bdss.tests import fixtures
try:
......@@ -44,15 +35,6 @@ from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \
_bind_data_values, _compute_layer_results
class NullMCOModel(BaseMCOModel):
pass
class NullMCO(BaseMCO):
def run(self, model):
pass
class RangedParameter(BaseMCOParameter):
initial_value = Float()
lower_bound = Float()
......@@ -64,52 +46,10 @@ class RangedParameterFactory(BaseMCOParameterFactory):
model_class = RangedParameter
class NullMCOCommunicator(BaseMCOCommunicator):
def send_to_mco(self, model, kpi_results):
pass
def receive_from_mco(self, model):
return []
class OneDataValueMCOCommunicator(BaseMCOCommunicator):
class OneValueMCOCommunicator(ProbeMCOCommunicator):
"""A communicator that returns one single datavalue, for testing purposes.
"""
def send_to_mco(self, model, kpi_results):
pass
def receive_from_mco(self, model):
return [
DataValue()
]
class NullMCOFactory(BaseMCOFactory):
id = factory_id("enthought", "test_mco")
def create_model(self, model_data=None):
return NullMCOModel(self, **model_data)
def create_communicator(self):
return NullMCOCommunicator(self)
def create_optimizer(self):
return NullMCO(self)
def parameter_factories(self):
return []
class NullKPICalculatorModel(BaseKPICalculatorModel):
pass
class NullKPICalculator(BaseKPICalculator):
def run(self, model, data_source_results):
return []
def slots(self, model):
return (), ()
nb_output_data_values = 1
class BrokenOneValueKPICalculator(BaseKPICalculator):
......@@ -128,39 +68,10 @@ class OneValueKPICalculator(BaseKPICalculator):
return (), (Slot(), )
class NullKPICalculatorFactory(BaseKPICalculatorFactory):
id = factory_id("enthought", "test_kpi_calculator")
name = "test_kpi_calculator"
def create_model(self, model_data=None):
return NullKPICalculatorModel(self)
def create_kpi_calculator(self):
return NullKPICalculator(self)
class NullDataSourceModel(BaseDataSourceModel):
pass
class NullDataSource(BaseDataSource):
def run(self, model, parameters):
return []
def slots(self, model):
return (), ()
class BrokenOneValueDataSource(BaseDataSource):
"""Incorrect data source implementation whose run returns a data value
but no slot was specified for it."""
def run(self, model, parameters):
return [DataValue()]
def slots(self, model):
return (), ()
class OneValueDataSource(BaseDataSource):
"""Incorrect data source implementation whose run returns a data value
but no slot was specified for it."""
......@@ -186,43 +97,6 @@ class TwoInputsThreeOutputsDataSource(BaseDataSource):
)
class NullDataSourceFactory(BaseDataSourceFactory):
id = factory_id("enthought", "test_data_source")
name = "test_data_source"
def create_model(self, model_data=None):
return NullDataSourceModel(self)
def create_data_source(self):
return NullDataSource(self)
class NullNotificationListener(BaseNotificationListener):
def initialize(self, model):
pass
def deliver(self, event):
pass
def finalize(self):
pass
class NullNotificationListenerModel(BaseNotificationListenerModel):
pass
class NullNotificationListenerFactory(BaseNotificationListenerFactory):
id = factory_id("enthought", "null_nl")
name = "null_nl"
def create_listener(self):
return NullNotificationListener(self)
def create_model(self, model_data=None):
return NullNotificationListenerModel(self)
class TestCoreEvaluationDriver(unittest.TestCase):
def setUp(self):
self.factory_registry_plugin = ProbeFactoryRegistryPlugin()
......@@ -240,33 +114,34 @@ class TestCoreEvaluationDriver(unittest.TestCase):
driver.application_started()
def test_error_for_non_matching_mco_parameters(self):
factory = self.factory_registry_plugin.mco_factories[0]
with mock.patch.object(factory.__class__,
"create_communicator") as create_comm:
create_comm.return_value = OneDataValueMCOCommunicator(
factory)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values returned by the MCO"):
driver.application_started()
mco_factories = self.factory_registry_plugin.mco_factories
mco_factories[0] = ProbeMCOFactory(
None,
communicator_class=OneValueMCOCommunicator)
driver = CoreEvaluationDriver(
application=self.mock_application)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values returned by the MCO"):
driver.application_started()
def test_error_for_incorrect_output_slots(self):
factory = self.factory_registry_plugin.data_source_factories[0]
with mock.patch.object(factory.__class__,
"create_data_source") as create_ds:
create_ds.return_value = BrokenOneValueDataSource(factory)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_data_source' does not match"
" the number of output slots"):
driver.application_started()
data_source_factories = \
self.factory_registry_plugin.data_source_factories
def run(self, *args, **kwargs):
return [DataValue()]
data_source_factories[0] = ProbeDataSourceFactory(
None,
run_function=run)
driver = CoreEvaluationDriver(
application=self.mock_application)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_data_source' does not match"
" the number of output slots"):
driver.application_started()
def test_error_for_missing_ds_output_names(self):
factory = self.factory_registry_plugin.data_source_factories[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment