From 31ccc154fa37772a05f3bf5ab734da6c4608fd58 Mon Sep 17 00:00:00 2001
From: martinRenou <martin.renou@gmail.com>
Date: Thu, 24 Aug 2017 15:40:12 +0100
Subject: [PATCH] Start to clean the tests

---
 force_bdss/mco/base_mco_communicator.py       |   4 +-
 .../tests/test_core_evaluation_driver.py      | 193 +++---------------
 2 files changed, 37 insertions(+), 160 deletions(-)

diff --git a/force_bdss/mco/base_mco_communicator.py b/force_bdss/mco/base_mco_communicator.py
index 22f0233..0f53bec 100644
--- a/force_bdss/mco/base_mco_communicator.py
+++ b/force_bdss/mco/base_mco_communicator.py
@@ -19,9 +19,11 @@ class BaseMCOCommunicator(ABCHasStrictTraits):
     #: A reference to the factory
     factory = Instance(IMCOFactory)
 
-    def __init__(self, factory):
+    def __init__(self, factory, *args, **kwargs):
         self.factory = factory
 
+        super(BaseMCOCommunicator, self).__init__(*args, **kwargs)
+
     @abc.abstractmethod
     def receive_from_mco(self, model):
         """
diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py
index 78d044d..59b23f4 100644
--- a/force_bdss/tests/test_core_evaluation_driver.py
+++ b/force_bdss/tests/test_core_evaluation_driver.py
@@ -4,6 +4,11 @@ from traits.api import Float
 
 from force_bdss.tests.probe_classes.factory_registry_plugin import \
     ProbeFactoryRegistryPlugin
+from force_bdss.tests.probe_classes.mco import (
+    ProbeMCOCommunicator, ProbeMCOFactory)
+from force_bdss.tests.probe_classes.data_source import (
+    ProbeDataSourceFactory)
+
 
 from force_bdss.core.input_slot_map import InputSlotMap
 from force_bdss.core.data_value import DataValue
@@ -12,25 +17,11 @@ from force_bdss.data_sources.base_data_source import BaseDataSource
 from force_bdss.data_sources.base_data_source_factory import \
     BaseDataSourceFactory
 from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel
-from force_bdss.ids import mco_parameter_id, factory_id
+from force_bdss.ids import mco_parameter_id
 from force_bdss.kpi.base_kpi_calculator import BaseKPICalculator
-from force_bdss.kpi.base_kpi_calculator_factory import BaseKPICalculatorFactory
-from force_bdss.kpi.base_kpi_calculator_model import BaseKPICalculatorModel
-from force_bdss.mco.base_mco import BaseMCO
-from force_bdss.mco.base_mco_factory import BaseMCOFactory
-from force_bdss.mco.base_mco_communicator import BaseMCOCommunicator
-from force_bdss.mco.base_mco_model import BaseMCOModel
 from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter
 from force_bdss.mco.parameters.base_mco_parameter_factory import \
     BaseMCOParameterFactory
-from force_bdss.notification_listeners.base_notification_listener import \
-    BaseNotificationListener
-from force_bdss.notification_listeners.base_notification_listener_factory \
-    import \
-    BaseNotificationListenerFactory
-from force_bdss.notification_listeners.base_notification_listener_model \
-    import \
-    BaseNotificationListenerModel
 from force_bdss.tests import fixtures
 
 try:
@@ -44,15 +35,6 @@ from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \
     _bind_data_values, _compute_layer_results
 
 
-class NullMCOModel(BaseMCOModel):
-    pass
-
-
-class NullMCO(BaseMCO):
-    def run(self, model):
-        pass
-
-
 class RangedParameter(BaseMCOParameter):
     initial_value = Float()
     lower_bound = Float()
@@ -64,52 +46,10 @@ class RangedParameterFactory(BaseMCOParameterFactory):
     model_class = RangedParameter
 
 
-class NullMCOCommunicator(BaseMCOCommunicator):
-    def send_to_mco(self, model, kpi_results):
-        pass
-
-    def receive_from_mco(self, model):
-        return []
-
-
-class OneDataValueMCOCommunicator(BaseMCOCommunicator):
+class OneValueMCOCommunicator(ProbeMCOCommunicator):
     """A communicator that returns one single datavalue, for testing purposes.
     """
-    def send_to_mco(self, model, kpi_results):
-        pass
-
-    def receive_from_mco(self, model):
-        return [
-            DataValue()
-        ]
-
-
-class NullMCOFactory(BaseMCOFactory):
-    id = factory_id("enthought", "test_mco")
-
-    def create_model(self, model_data=None):
-        return NullMCOModel(self, **model_data)
-
-    def create_communicator(self):
-        return NullMCOCommunicator(self)
-
-    def create_optimizer(self):
-        return NullMCO(self)
-
-    def parameter_factories(self):
-        return []
-
-
-class NullKPICalculatorModel(BaseKPICalculatorModel):
-    pass
-
-
-class NullKPICalculator(BaseKPICalculator):
-    def run(self, model, data_source_results):
-        return []
-
-    def slots(self, model):
-        return (), ()
+    nb_output_data_values = 1
 
 
 class BrokenOneValueKPICalculator(BaseKPICalculator):
@@ -128,39 +68,10 @@ class OneValueKPICalculator(BaseKPICalculator):
         return (), (Slot(), )
 
 
-class NullKPICalculatorFactory(BaseKPICalculatorFactory):
-    id = factory_id("enthought", "test_kpi_calculator")
-    name = "test_kpi_calculator"
-
-    def create_model(self, model_data=None):
-        return NullKPICalculatorModel(self)
-
-    def create_kpi_calculator(self):
-        return NullKPICalculator(self)
-
-
 class NullDataSourceModel(BaseDataSourceModel):
     pass
 
 
-class NullDataSource(BaseDataSource):
-    def run(self, model, parameters):
-        return []
-
-    def slots(self, model):
-        return (), ()
-
-
-class BrokenOneValueDataSource(BaseDataSource):
-    """Incorrect data source implementation whose run returns a data value
-    but no slot was specified for it."""
-    def run(self, model, parameters):
-        return [DataValue()]
-
-    def slots(self, model):
-        return (), ()
-
-
 class OneValueDataSource(BaseDataSource):
     """Incorrect data source implementation whose run returns a data value
     but no slot was specified for it."""
@@ -186,43 +97,6 @@ class TwoInputsThreeOutputsDataSource(BaseDataSource):
         )
 
 
-class NullDataSourceFactory(BaseDataSourceFactory):
-    id = factory_id("enthought", "test_data_source")
-    name = "test_data_source"
-
-    def create_model(self, model_data=None):
-        return NullDataSourceModel(self)
-
-    def create_data_source(self):
-        return NullDataSource(self)
-
-
-class NullNotificationListener(BaseNotificationListener):
-    def initialize(self, model):
-        pass
-
-    def deliver(self, event):
-        pass
-
-    def finalize(self):
-        pass
-
-
-class NullNotificationListenerModel(BaseNotificationListenerModel):
-    pass
-
-
-class NullNotificationListenerFactory(BaseNotificationListenerFactory):
-    id = factory_id("enthought", "null_nl")
-    name = "null_nl"
-
-    def create_listener(self):
-        return NullNotificationListener(self)
-
-    def create_model(self, model_data=None):
-        return NullNotificationListenerModel(self)
-
-
 class TestCoreEvaluationDriver(unittest.TestCase):
     def setUp(self):
         self.factory_registry_plugin = ProbeFactoryRegistryPlugin()
@@ -240,33 +114,34 @@ class TestCoreEvaluationDriver(unittest.TestCase):
         driver.application_started()
 
     def test_error_for_non_matching_mco_parameters(self):
-        factory = self.factory_registry_plugin.mco_factories[0]
-        with mock.patch.object(factory.__class__,
-                               "create_communicator") as create_comm:
-            create_comm.return_value = OneDataValueMCOCommunicator(
-                factory)
-            driver = CoreEvaluationDriver(
-                application=self.mock_application,
-            )
-            with self.assertRaisesRegexp(
-                    RuntimeError,
-                    "The number of data values returned by the MCO"):
-                driver.application_started()
+        mco_factories = self.factory_registry_plugin.mco_factories
+        mco_factories[0] = ProbeMCOFactory(
+            None,
+            communicator_class=OneValueMCOCommunicator)
+        driver = CoreEvaluationDriver(
+            application=self.mock_application)
+        with self.assertRaisesRegexp(
+                RuntimeError,
+                "The number of data values returned by the MCO"):
+            driver.application_started()
 
     def test_error_for_incorrect_output_slots(self):
-        factory = self.factory_registry_plugin.data_source_factories[0]
-        with mock.patch.object(factory.__class__,
-                               "create_data_source") as create_ds:
-            create_ds.return_value = BrokenOneValueDataSource(factory)
-            driver = CoreEvaluationDriver(
-                application=self.mock_application,
-            )
-            with self.assertRaisesRegexp(
-                    RuntimeError,
-                    "The number of data values \(1 values\)"
-                    " returned by 'test_data_source' does not match"
-                    " the number of output slots"):
-                driver.application_started()
+        data_source_factories = \
+            self.factory_registry_plugin.data_source_factories
+
+        def run(self, *args, **kwargs):
+            return [DataValue()]
+        data_source_factories[0] = ProbeDataSourceFactory(
+            None,
+            run_function=run)
+        driver = CoreEvaluationDriver(
+            application=self.mock_application)
+        with self.assertRaisesRegexp(
+                RuntimeError,
+                "The number of data values \(1 values\)"
+                " returned by 'test_data_source' does not match"
+                " the number of output slots"):
+            driver.application_started()
 
     def test_error_for_missing_ds_output_names(self):
         factory = self.factory_registry_plugin.data_source_factories[0]
-- 
GitLab