diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py index 59b23f4fa2f2c936b987eef7f1ebe55b4a879008..1bbd5356b0b0f673aa32a0f9d944f361f5cd9e0b 100644 --- a/force_bdss/tests/test_core_evaluation_driver.py +++ b/force_bdss/tests/test_core_evaluation_driver.py @@ -1,27 +1,18 @@ import unittest -from traits.api import Float - from force_bdss.tests.probe_classes.factory_registry_plugin import \ ProbeFactoryRegistryPlugin from force_bdss.tests.probe_classes.mco import ( ProbeMCOCommunicator, ProbeMCOFactory) from force_bdss.tests.probe_classes.data_source import ( ProbeDataSourceFactory) +from force_bdss.tests.probe_classes.kpi_calculator import ( + ProbeKPICalculatorFactory) from force_bdss.core.input_slot_map import InputSlotMap from force_bdss.core.data_value import DataValue from force_bdss.core.slot import Slot -from force_bdss.data_sources.base_data_source import BaseDataSource -from force_bdss.data_sources.base_data_source_factory import \ - BaseDataSourceFactory -from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel -from force_bdss.ids import mco_parameter_id -from force_bdss.kpi.base_kpi_calculator import BaseKPICalculator -from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter -from force_bdss.mco.parameters.base_mco_parameter_factory import \ - BaseMCOParameterFactory from force_bdss.tests import fixtures try: @@ -35,68 +26,12 @@ from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \ _bind_data_values, _compute_layer_results -class RangedParameter(BaseMCOParameter): - initial_value = Float() - lower_bound = Float() - upper_bound = Float() - - -class RangedParameterFactory(BaseMCOParameterFactory): - id = mco_parameter_id("enthought", "null_mco", "null") - model_class = RangedParameter - - class OneValueMCOCommunicator(ProbeMCOCommunicator): """A communicator that returns one single datavalue, for testing purposes. """ nb_output_data_values = 1 -class BrokenOneValueKPICalculator(BaseKPICalculator): - def run(self, model, data_source_results): - return [DataValue()] - - def slots(self, model): - return (), () - - -class OneValueKPICalculator(BaseKPICalculator): - def run(self, model, data_source_results): - return [DataValue()] - - def slots(self, model): - return (), (Slot(), ) - - -class NullDataSourceModel(BaseDataSourceModel): - pass - - -class OneValueDataSource(BaseDataSource): - """Incorrect data source implementation whose run returns a data value - but no slot was specified for it.""" - def run(self, model, parameters): - return [DataValue()] - - def slots(self, model): - return (), ( - Slot(), - ) - - -class TwoInputsThreeOutputsDataSource(BaseDataSource): - """Incorrect data source implementation whose run returns a data value - but no slot was specified for it.""" - def run(self, model, parameters): - return [DataValue(value=1), DataValue(value=2), DataValue(value=3)] - - def slots(self, model): - return ( - (Slot(), Slot()), - (Slot(), Slot(), Slot()) - ) - - class TestCoreEvaluationDriver(unittest.TestCase): def setUp(self): self.factory_registry_plugin = ProbeFactoryRegistryPlugin() @@ -144,49 +79,63 @@ class TestCoreEvaluationDriver(unittest.TestCase): driver.application_started() def test_error_for_missing_ds_output_names(self): - factory = self.factory_registry_plugin.data_source_factories[0] - with mock.patch.object(factory.__class__, - "create_data_source") as create_ds: - create_ds.return_value = OneValueDataSource(factory) - driver = CoreEvaluationDriver( - application=self.mock_application, - ) - with self.assertRaisesRegexp( - RuntimeError, - "The number of data values \(1 values\)" - " returned by 'test_data_source' does not match" - " the number of user-defined names"): - driver.application_started() + data_source_factories = \ + self.factory_registry_plugin.data_source_factories + + def run(self, *args, **kwargs): + return [DataValue()] + data_source_factories[0] = ProbeDataSourceFactory( + None, + run_function=run, + output_slots_size=1) + driver = CoreEvaluationDriver( + application=self.mock_application, + ) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values \(1 values\)" + " returned by 'test_data_source' does not match" + " the number of user-defined names"): + driver.application_started() def test_error_for_incorrect_kpic_output_slots(self): - factory = self.factory_registry_plugin.kpi_calculator_factories[0] - with mock.patch.object(factory.__class__, - "create_kpi_calculator") as create_kpic: - create_kpic.return_value = BrokenOneValueKPICalculator(factory) - driver = CoreEvaluationDriver( - application=self.mock_application, - ) - with self.assertRaisesRegexp( - RuntimeError, - "The number of data values \(1 values\)" - " returned by 'test_kpi_calculator' does not match" - " the number of output slots"): - driver.application_started() + kpi_calculator_factories = \ + self.factory_registry_plugin.kpi_calculator_factories + + def run(self, *args, **kwargs): + return [DataValue()] + kpi_calculator_factories[0] = ProbeKPICalculatorFactory( + None, + run_function=run) + driver = CoreEvaluationDriver( + application=self.mock_application, + ) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values \(1 values\)" + " returned by 'test_kpi_calculator' does not match" + " the number of output slots"): + driver.application_started() def test_error_for_missing_kpic_output_names(self): - factory = self.factory_registry_plugin.kpi_calculator_factories[0] - with mock.patch.object(factory.__class__, - "create_kpi_calculator") as create_kpic: - create_kpic.return_value = OneValueKPICalculator(factory) - driver = CoreEvaluationDriver( - application=self.mock_application, - ) - with self.assertRaisesRegexp( - RuntimeError, - "The number of data values \(1 values\)" - " returned by 'test_kpi_calculator' does not match" - " the number of user-defined names"): - driver.application_started() + kpi_calculator_factories = \ + self.factory_registry_plugin.kpi_calculator_factories + + def run(self, *args, **kwargs): + return [DataValue()] + kpi_calculator_factories[0] = ProbeKPICalculatorFactory( + None, + run_function=run, + output_slots_size=1) + driver = CoreEvaluationDriver( + application=self.mock_application, + ) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values \(1 values\)" + " returned by 'test_kpi_calculator' does not match" + " the number of user-defined names"): + driver.application_started() def test_bind_data_values(self): data_values = [ @@ -233,7 +182,6 @@ class TestCoreEvaluationDriver(unittest.TestCase): _bind_data_values(data_values, slot_map, slots) def test_compute_layer_results(self): - data_values = [ DataValue(name="foo"), DataValue(name="bar"), @@ -241,11 +189,14 @@ class TestCoreEvaluationDriver(unittest.TestCase): DataValue(name="quux") ] - mock_ds_factory = mock.Mock(spec=BaseDataSourceFactory) - mock_ds_factory.name = "mock factory" - mock_ds_factory.create_data_source.return_value = \ - TwoInputsThreeOutputsDataSource(mock_ds_factory) - evaluator_model = NullDataSourceModel(factory=mock_ds_factory) + def run(self, *args, **kwargs): + return [DataValue(value=1), DataValue(value=2), DataValue(value=3)] + ds_factory = ProbeDataSourceFactory( + None, + input_slots_size=2, + output_slots_size=3, + run_function=run) + evaluator_model = ds_factory.create_model() evaluator_model.input_slot_maps = [ InputSlotMap(name="foo"),