From 8cf53e004ee0fd9d66c219613c9177d26e9189b8 Mon Sep 17 00:00:00 2001
From: Stefano Borini <sborini@enthought.com>
Date: Tue, 8 Aug 2017 15:52:08 +0100
Subject: [PATCH] Extracted methods as functions and tested their behavior with
 the new change

---
 force_bdss/core_evaluation_driver.py          | 303 +++++++++---------
 .../tests/test_core_evaluation_driver.py      | 100 +++++-
 2 files changed, 255 insertions(+), 148 deletions(-)

diff --git a/force_bdss/core_evaluation_driver.py b/force_bdss/core_evaluation_driver.py
index a9779ea..f0347cf 100644
--- a/force_bdss/core_evaluation_driver.py
+++ b/force_bdss/core_evaluation_driver.py
@@ -33,16 +33,16 @@ class CoreEvaluationDriver(BaseCoreDriver):
         mco_factory = mco_model.factory
         mco_communicator = mco_factory.create_communicator()
 
-        mco_data_values = self._get_data_values_from_mco(mco_model,
-                                                         mco_communicator)
+        mco_data_values = _get_data_values_from_mco(
+            mco_model, mco_communicator)
 
-        ds_results = self._compute_layer_results(
+        ds_results = _compute_layer_results(
             mco_data_values,
             workflow.data_sources,
             "create_data_source"
         )
 
-        kpi_results = self._compute_layer_results(
+        kpi_results = _compute_layer_results(
             ds_results + mco_data_values,
             workflow.kpi_calculators,
             "create_kpi_calculator"
@@ -50,153 +50,162 @@ class CoreEvaluationDriver(BaseCoreDriver):
 
         mco_communicator.send_to_mco(mco_model, kpi_results)
 
-    def _compute_layer_results(self,
-                               environment_data_values,
-                               evaluator_models,
-                               creator_method_name
-                               ):
-        """Helper routine.
-        Performs the evaluation of a single layer.
-        At the moment we have a single layer of DataSources followed
-        by a single layer of KPI calculators.
-
-        Parameters
-        ----------
-        environment_data_values: list
-            A list of data values to submit to the evaluators.
-
-        evaluator_models: list
-            A list of the models for all the evaluators (data source
-            or kpi calculator)
-
-        creator_method_name: str
-            A string of the creator method for the evaluator on the
-            factory (e.g. create_kpi_calculator)
-
-        NOTE: The above parameter is going to go away as soon as we move
-        to unlimited layers and remove the distinction between data sources
-        and KPI calculators.
-        """
-        results = []
-
-        for model in evaluator_models:
-            factory = model.factory
-            evaluator = getattr(factory, creator_method_name)()
-
-            # Get the slots for this data source. These must be matched to
-            # the appropriate values in the environment data values.
-            # Matching is by position.
-            in_slots, out_slots = evaluator.slots(model)
-
-            # Binding performs the extraction of the specified data values
-            # satisfying the above input slots from the environment data values
-            # considering what the user specified in terms of names (which is
-            # in the model input slot maps.
-            # The resulting data are the ones picked by name from the
-            # environment data values, and in the appropriate ordering as
-            # needed by the input slots.
-            passed_data_values = self._bind_data_values(
-                environment_data_values,
-                model.input_slot_maps,
-                in_slots)
-
-            # execute data source, passing only relevant data values.
-            logging.info("Evaluating for Data Source {}".format(
-                factory.name))
-            res = evaluator.run(model, passed_data_values)
-
-            if len(res) != len(out_slots):
-                error_txt = (
-                    "The number of data values ({} values) returned"
-                    " by '{}' does not match the number"
-                    " of output slots it specifies ({} values)."
-                    " This is likely a plugin error.").format(
-                    len(res), factory.name, len(out_slots)
-                )
-
-                logging.error(error_txt)
-                raise RuntimeError(error_txt)
-
-            if len(res) != len(model.output_slot_names):
-                error_txt = (
-                    "The number of data values ({} values) returned"
-                    " by '{}' does not match the number"
-                    " of user-defined names specified ({} values)."
-                    " This is either a plugin error or a file"
-                    " error.").format(
-                    len(res),
-                    factory.name,
-                    len(model.output_slot_names)
-                )
-
-                logging.error(error_txt)
-                raise RuntimeError(error_txt)
-
-            # At this point, the returned data values are unnamed.
-            # Add the names as specified by the user.
-            for dv, output_slot_name in zip(res, model.output_slot_names):
-                dv.name = output_slot_name
-
-            # If the name was not specified, simply discard the value,
-            # because apparently the user is not interested in it.
-            results.extend([r for r in res if r.name != ""])
-
-        # Finally, return all the computed data values from all evaluators,
-        # properly named.
-        return results
-
-    def _get_data_values_from_mco(self, model, communicator):
-        """Helper method.
-        Receives the data (in order) from the MCO, and bind them to the
-        specified names as from the model.
-
-        Parameters
-        ----------
-        model: BaseMCOModel
-            the MCO model (where the user-defined variable names are specified)
-        communicator: BaseMCOCommunicator
-            The communicator that produces the (temporarily unnamed) datavalues
-            from the MCO.
-        """
-        mco_data_values = communicator.receive_from_mco(model)
-
-        if len(mco_data_values) != len(model.parameters):
-            error_txt = ("The number of data values returned by"
-                         " the MCO ({} values) does not match the"
-                         " number of parameters specified ({} values)."
-                         " This is either a MCO plugin error or the workflow"
-                         " file is corrupted.").format(
-                len(mco_data_values), len(model.parameters)
+
+def _compute_layer_results(environment_data_values,
+                           evaluator_models,
+                           creator_method_name
+                           ):
+    """Helper routine.
+    Performs the evaluation of a single layer.
+    At the moment we have a single layer of DataSources followed
+    by a single layer of KPI calculators.
+
+    Parameters
+    ----------
+    environment_data_values: list
+        A list of data values to submit to the evaluators.
+
+    evaluator_models: list
+        A list of the models for all the evaluators (data source
+        or kpi calculator)
+
+    creator_method_name: str
+        A string of the creator method for the evaluator on the
+        factory (e.g. create_kpi_calculator)
+
+    NOTE: The above parameter is going to go away as soon as we move
+    to unlimited layers and remove the distinction between data sources
+    and KPI calculators.
+    """
+    results = []
+
+    for model in evaluator_models:
+        factory = model.factory
+        evaluator = getattr(factory, creator_method_name)()
+
+        # Get the slots for this data source. These must be matched to
+        # the appropriate values in the environment data values.
+        # Matching is by position.
+        in_slots, out_slots = evaluator.slots(model)
+
+        # Binding performs the extraction of the specified data values
+        # satisfying the above input slots from the environment data values
+        # considering what the user specified in terms of names (which is
+        # in the model input slot maps).
+        # The resulting data are the ones picked by name from the
+        # environment data values, and in the appropriate ordering as
+        # needed by the input slots.
+        passed_data_values = _bind_data_values(
+            environment_data_values,
+            model.input_slot_maps,
+            in_slots)
+
+        # execute data source, passing only relevant data values.
+        logging.info("Evaluating for Data Source {}".format(
+            factory.name))
+        res = evaluator.run(model, passed_data_values)
+
+        if len(res) != len(out_slots):
+            error_txt = (
+                "The number of data values ({} values) returned"
+                " by '{}' does not match the number"
+                " of output slots it specifies ({} values)."
+                " This is likely a plugin error.").format(
+                len(res), factory.name, len(out_slots)
+            )
+
+            logging.error(error_txt)
+            raise RuntimeError(error_txt)
+
+        if len(res) != len(model.output_slot_names):
+            error_txt = (
+                "The number of data values ({} values) returned"
+                " by '{}' does not match the number"
+                " of user-defined names specified ({} values)."
+                " This is either a plugin error or a file"
+                " error.").format(
+                len(res),
+                factory.name,
+                len(model.output_slot_names)
             )
+
             logging.error(error_txt)
             raise RuntimeError(error_txt)
 
-        # The data values obtained by the communicator are unnamed.
-        # Assign the name to each datavalue as specified by the user.
-        for dv, param in zip(mco_data_values, model.parameters):
-            dv.name = param.name
-
-        return mco_data_values
-
-    def _bind_data_values(self,
-                          available_data_values,
-                          model_slot_map,
-                          slots):
-        """
-        Given the named data values in the environment, the slots a given
-        data source expects, and the user-specified names for each of these
-        slots, returns those data values with the requested names, ordered
-        in the correct order as specified by the slot map.
-        """
-        passed_data_values = []
-        lookup_map = {dv.name: dv for dv in available_data_values}
-
-        if len(slots) != len(model_slot_map):
-            raise RuntimeError("The length of the slots is not equal to"
-                               " the length of the slot map. This may"
-                               " indicate a file error.")
+        # At this point, the returned data values are unnamed.
+        # Add the names as specified by the user.
+        for dv, output_slot_name in zip(res, model.output_slot_names):
+            dv.name = output_slot_name
+
+        # If the name was not specified, simply discard the value,
+        # because apparently the user is not interested in it.
+        results.extend([r for r in res if r.name != ""])
+
+    # Finally, return all the computed data values from all evaluators,
+    # properly named.
+    return results
 
+
+def _get_data_values_from_mco(model, communicator):
+    """Helper method.
+    Receives the data (in order) from the MCO, and binds them to the
+    names specified in the model.
+
+    Parameters
+    ----------
+    model: BaseMCOModel
+        the MCO model (where the user-defined variable names are specified)
+    communicator: BaseMCOCommunicator
+        The communicator that produces the (temporarily unnamed) datavalues
+        from the MCO.
+    """
+    mco_data_values = communicator.receive_from_mco(model)
+
+    if len(mco_data_values) != len(model.parameters):
+        error_txt = ("The number of data values returned by"
+                     " the MCO ({} values) does not match the"
+                     " number of parameters specified ({} values)."
+                     " This is either a MCO plugin error or the workflow"
+                     " file is corrupted.").format(
+            len(mco_data_values), len(model.parameters)
+        )
+        logging.error(error_txt)
+        raise RuntimeError(error_txt)
+
+    # The data values obtained by the communicator are unnamed.
+    # Assign the name to each datavalue as specified by the user.
+    for dv, param in zip(mco_data_values, model.parameters):
+        dv.name = param.name
+
+    # Exclude data values that have no name set.
+    return [dv for dv in mco_data_values if dv.name != ""]
+
+
+def _bind_data_values(available_data_values,
+                      model_slot_map,
+                      slots):
+    """
+    Given the named data values in the environment, the slots a given
+    data source expects, and the user-specified names for each of these
+    slots, returns those data values with the requested names, ordered
+    in the correct order as specified by the slot map.
+    """
+    passed_data_values = []
+    lookup_map = {dv.name: dv for dv in available_data_values}
+
+    if len(slots) != len(model_slot_map):
+        raise RuntimeError("The length of the slots is not equal to"
+                           " the length of the slot map. This may"
+                           " indicate a file error.")
+
+    try:
         for slot, slot_map in zip(slots, model_slot_map):
             passed_data_values.append(lookup_map[slot_map.name])
-
-        return passed_data_values
+    except KeyError:
+        raise RuntimeError(
+            "Unable to find requested name '{}' in available "
+            "data values. Current data value names: {}".format(
+                slot_map.name,
+                list(lookup_map.keys())))
+
+    return passed_data_values
diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py
index 5bd4a14..07ac470 100644
--- a/force_bdss/tests/test_core_evaluation_driver.py
+++ b/force_bdss/tests/test_core_evaluation_driver.py
@@ -1,5 +1,8 @@
 import unittest
+
 from traits.api import Float, List
+
+from force_bdss.core.input_slot_map import InputSlotMap
 from force_bdss.factory_registry_plugin import FactoryRegistryPlugin
 from force_bdss.core.data_value import DataValue
 from force_bdss.core.slot import Slot
@@ -27,7 +30,8 @@ except ImportError:
 
 from envisage.api import Application
 
-from force_bdss.core_evaluation_driver import CoreEvaluationDriver
+from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \
+    _bind_data_values, _compute_layer_results
 
 
 class NullMCOModel(BaseMCOModel):
@@ -159,6 +163,19 @@ class OneValueDataSource(BaseDataSource):
         )
 
 
+class TwoInputsThreeOutputsDataSource(BaseDataSource):
+    """Data source implementation with two input slots and three output
+    slots. Returns one data value for each of its output slots."""
+    def run(self, model, parameters):
+        return [DataValue(value=1), DataValue(value=2), DataValue(value=3)]
+
+    def slots(self, model):
+        return (
+            (Slot(), Slot()),
+            (Slot(), Slot(), Slot())
+        )
+
+
 class NullDataSourceFactory(BaseDataSourceFactory):
     id = factory_id("enthought", "null_ds")
     name = "null_ds"
@@ -276,3 +293,84 @@ class TestCoreEvaluationDriver(unittest.TestCase):
                     " returned by 'null_kpic' does not match"
                     " the number of user-defined names"):
                 driver.application_started()
+
+    def test_bind_data_values(self):
+        data_values = [
+            DataValue(name="foo"),
+            DataValue(name="bar"),
+            DataValue(name="baz")
+        ]
+
+        slot_map = (
+            InputSlotMap(name="baz"),
+            InputSlotMap(name="bar")
+        )
+
+        slots = (
+            Slot(),
+            Slot()
+        )
+
+        result = _bind_data_values(data_values, slot_map, slots)
+        self.assertEqual(result[0], data_values[2])
+        self.assertEqual(result[1], data_values[1])
+
+        # Check the errors. Only one slot map for two slots.
+        slot_map = (
+            InputSlotMap(name="baz"),
+        )
+
+        with self.assertRaisesRegexp(
+                RuntimeError,
+                "The length of the slots is not equal to the length of"
+                " the slot map"):
+            _bind_data_values(data_values, slot_map, slots)
+
+        # missing value in the given data values.
+        slot_map = (
+            InputSlotMap(name="blap"),
+            InputSlotMap(name="bar")
+        )
+
+        with self.assertRaisesRegexp(
+                RuntimeError,
+                "Unable to find requested name 'blap' in available"
+                " data values."):
+            _bind_data_values(data_values, slot_map, slots)
+
+    def test_compute_layer_results(self):
+
+        data_values = [
+            DataValue(name="foo"),
+            DataValue(name="bar"),
+            DataValue(name="baz"),
+            DataValue(name="quux")
+        ]
+
+        mock_ds_factory = mock.Mock(spec=BaseDataSourceFactory)
+        mock_ds_factory.name = "mock factory"
+        mock_ds_factory.create_data_source.return_value = \
+            TwoInputsThreeOutputsDataSource(mock_ds_factory)
+        evaluator_model = NullDataSourceModel(factory=mock_ds_factory)
+
+        evaluator_model.input_slot_maps = [
+            InputSlotMap(name="foo"),
+            InputSlotMap(name="quux")
+        ]
+        evaluator_model.output_slot_names = ["one", "", "three"]
+
+        res = _compute_layer_results(
+            data_values,
+            [evaluator_model],
+            "create_data_source"
+        )
+        self.assertEqual(len(res), 2)
+        self.assertEqual(res[0].name, "one")
+        self.assertEqual(res[0].value, 1)
+        self.assertEqual(res[1].name, "three")
+        self.assertEqual(res[1].value, 3)
+
+    def test_empty_slot_name_skips_data_value(self):
+        """Checks if leaving a slot name empty will skip the data value
+        in the final output
+        """
-- 
GitLab