From 31ccc154fa37772a05f3bf5ab734da6c4608fd58 Mon Sep 17 00:00:00 2001 From: martinRenou <martin.renou@gmail.com> Date: Thu, 24 Aug 2017 15:40:12 +0100 Subject: [PATCH] Start to clean the tests --- force_bdss/mco/base_mco_communicator.py | 4 +- .../tests/test_core_evaluation_driver.py | 193 +++--------------- 2 files changed, 37 insertions(+), 160 deletions(-) diff --git a/force_bdss/mco/base_mco_communicator.py b/force_bdss/mco/base_mco_communicator.py index 22f0233..0f53bec 100644 --- a/force_bdss/mco/base_mco_communicator.py +++ b/force_bdss/mco/base_mco_communicator.py @@ -19,9 +19,11 @@ class BaseMCOCommunicator(ABCHasStrictTraits): #: A reference to the factory factory = Instance(IMCOFactory) - def __init__(self, factory): + def __init__(self, factory, *args, **kwargs): self.factory = factory + super(BaseMCOCommunicator, self).__init__(*args, **kwargs) + @abc.abstractmethod def receive_from_mco(self, model): """ diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py index 78d044d..59b23f4 100644 --- a/force_bdss/tests/test_core_evaluation_driver.py +++ b/force_bdss/tests/test_core_evaluation_driver.py @@ -4,6 +4,11 @@ from traits.api import Float from force_bdss.tests.probe_classes.factory_registry_plugin import \ ProbeFactoryRegistryPlugin +from force_bdss.tests.probe_classes.mco import ( + ProbeMCOCommunicator, ProbeMCOFactory) +from force_bdss.tests.probe_classes.data_source import ( + ProbeDataSourceFactory) + from force_bdss.core.input_slot_map import InputSlotMap from force_bdss.core.data_value import DataValue @@ -12,25 +17,11 @@ from force_bdss.data_sources.base_data_source import BaseDataSource from force_bdss.data_sources.base_data_source_factory import \ BaseDataSourceFactory from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel -from force_bdss.ids import mco_parameter_id, factory_id +from force_bdss.ids import mco_parameter_id from force_bdss.kpi.base_kpi_calculator import BaseKPICalculator -from force_bdss.kpi.base_kpi_calculator_factory import BaseKPICalculatorFactory -from force_bdss.kpi.base_kpi_calculator_model import BaseKPICalculatorModel -from force_bdss.mco.base_mco import BaseMCO -from force_bdss.mco.base_mco_factory import BaseMCOFactory -from force_bdss.mco.base_mco_communicator import BaseMCOCommunicator -from force_bdss.mco.base_mco_model import BaseMCOModel from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter from force_bdss.mco.parameters.base_mco_parameter_factory import \ BaseMCOParameterFactory -from force_bdss.notification_listeners.base_notification_listener import \ - BaseNotificationListener -from force_bdss.notification_listeners.base_notification_listener_factory \ - import \ - BaseNotificationListenerFactory -from force_bdss.notification_listeners.base_notification_listener_model \ - import \ - BaseNotificationListenerModel from force_bdss.tests import fixtures try: @@ -44,15 +35,6 @@ from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \ _bind_data_values, _compute_layer_results -class NullMCOModel(BaseMCOModel): - pass - - -class NullMCO(BaseMCO): - def run(self, model): - pass - - class RangedParameter(BaseMCOParameter): initial_value = Float() lower_bound = Float() @@ -64,52 +46,10 @@ class RangedParameterFactory(BaseMCOParameterFactory): model_class = RangedParameter -class NullMCOCommunicator(BaseMCOCommunicator): - def send_to_mco(self, model, kpi_results): - pass - - def receive_from_mco(self, model): - return [] - - -class OneDataValueMCOCommunicator(BaseMCOCommunicator): +class OneValueMCOCommunicator(ProbeMCOCommunicator): """A communicator that returns one single datavalue, for testing purposes. """ - def send_to_mco(self, model, kpi_results): - pass - - def receive_from_mco(self, model): - return [ - DataValue() - ] - - -class NullMCOFactory(BaseMCOFactory): - id = factory_id("enthought", "test_mco") - - def create_model(self, model_data=None): - return NullMCOModel(self, **model_data) - - def create_communicator(self): - return NullMCOCommunicator(self) - - def create_optimizer(self): - return NullMCO(self) - - def parameter_factories(self): - return [] - - -class NullKPICalculatorModel(BaseKPICalculatorModel): - pass - - -class NullKPICalculator(BaseKPICalculator): - def run(self, model, data_source_results): - return [] - - def slots(self, model): - return (), () + nb_output_data_values = 1 class BrokenOneValueKPICalculator(BaseKPICalculator): @@ -128,39 +68,10 @@ class OneValueKPICalculator(BaseKPICalculator): return (), (Slot(), ) -class NullKPICalculatorFactory(BaseKPICalculatorFactory): - id = factory_id("enthought", "test_kpi_calculator") - name = "test_kpi_calculator" - - def create_model(self, model_data=None): - return NullKPICalculatorModel(self) - - def create_kpi_calculator(self): - return NullKPICalculator(self) - - class NullDataSourceModel(BaseDataSourceModel): pass -class NullDataSource(BaseDataSource): - def run(self, model, parameters): - return [] - - def slots(self, model): - return (), () - - -class BrokenOneValueDataSource(BaseDataSource): - """Incorrect data source implementation whose run returns a data value - but no slot was specified for it.""" - def run(self, model, parameters): - return [DataValue()] - - def slots(self, model): - return (), () - - class OneValueDataSource(BaseDataSource): """Incorrect data source implementation whose run returns a data value but no slot was specified for it.""" @@ -186,43 +97,6 @@ class TwoInputsThreeOutputsDataSource(BaseDataSource): ) -class NullDataSourceFactory(BaseDataSourceFactory): - id = factory_id("enthought", "test_data_source") - name = "test_data_source" - - def create_model(self, model_data=None): - return NullDataSourceModel(self) - - def create_data_source(self): - return NullDataSource(self) - - -class NullNotificationListener(BaseNotificationListener): - def initialize(self, model): - pass - - def deliver(self, event): - pass - - def finalize(self): - pass - - -class NullNotificationListenerModel(BaseNotificationListenerModel): - pass - - -class NullNotificationListenerFactory(BaseNotificationListenerFactory): - id = factory_id("enthought", "null_nl") - name = "null_nl" - - def create_listener(self): - return NullNotificationListener(self) - - def create_model(self, model_data=None): - return NullNotificationListenerModel(self) - - class TestCoreEvaluationDriver(unittest.TestCase): def setUp(self): self.factory_registry_plugin = ProbeFactoryRegistryPlugin() @@ -240,33 +114,34 @@ class TestCoreEvaluationDriver(unittest.TestCase): driver.application_started() def test_error_for_non_matching_mco_parameters(self): - factory = self.factory_registry_plugin.mco_factories[0] - with mock.patch.object(factory.__class__, - "create_communicator") as create_comm: - create_comm.return_value = OneDataValueMCOCommunicator( - factory) - driver = CoreEvaluationDriver( - application=self.mock_application, - ) - with self.assertRaisesRegexp( - RuntimeError, - "The number of data values returned by the MCO"): - driver.application_started() + mco_factories = self.factory_registry_plugin.mco_factories + mco_factories[0] = ProbeMCOFactory( + None, + communicator_class=OneValueMCOCommunicator) + driver = CoreEvaluationDriver( + application=self.mock_application) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values returned by the MCO"): + driver.application_started() def test_error_for_incorrect_output_slots(self): - factory = self.factory_registry_plugin.data_source_factories[0] - with mock.patch.object(factory.__class__, - "create_data_source") as create_ds: - create_ds.return_value = BrokenOneValueDataSource(factory) - driver = CoreEvaluationDriver( - application=self.mock_application, - ) - with self.assertRaisesRegexp( - RuntimeError, - "The number of data values \(1 values\)" - " returned by 'test_data_source' does not match" - " the number of output slots"): - driver.application_started() + data_source_factories = \ + self.factory_registry_plugin.data_source_factories + + def run(self, *args, **kwargs): + return [DataValue()] + data_source_factories[0] = ProbeDataSourceFactory( + None, + run_function=run) + driver = CoreEvaluationDriver( + application=self.mock_application) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values \(1 values\)" + " returned by 'test_data_source' does not match" + " the number of output slots"): + driver.application_started() def test_error_for_missing_ds_output_names(self): factory = self.factory_registry_plugin.data_source_factories[0] -- GitLab