diff --git a/force_bdss/core_evaluation_driver.py b/force_bdss/core_evaluation_driver.py
index a9779ea0be879b56b6d21f2d31e8696127160326..f0347cff034a10ea8fe54339a9ce41048507fff9 100644
--- a/force_bdss/core_evaluation_driver.py
+++ b/force_bdss/core_evaluation_driver.py
@@ -33,16 +33,16 @@ class CoreEvaluationDriver(BaseCoreDriver):
         mco_factory = mco_model.factory
         mco_communicator = mco_factory.create_communicator()
 
-        mco_data_values = self._get_data_values_from_mco(mco_model,
-                                                         mco_communicator)
+        mco_data_values = _get_data_values_from_mco(
+            mco_model, mco_communicator)
 
-        ds_results = self._compute_layer_results(
+        ds_results = _compute_layer_results(
             mco_data_values,
             workflow.data_sources,
             "create_data_source"
         )
-        kpi_results = self._compute_layer_results(
+        kpi_results = _compute_layer_results(
             ds_results + mco_data_values,
             workflow.kpi_calculators,
             "create_kpi_calculator"
@@ -50,153 +50,162 @@ class CoreEvaluationDriver(BaseCoreDriver):
 
         mco_communicator.send_to_mco(mco_model, kpi_results)
 
-    def _compute_layer_results(self,
-                               environment_data_values,
-                               evaluator_models,
-                               creator_method_name
-                               ):
-        """Helper routine.
-        Performs the evaluation of a single layer.
-        At the moment we have a single layer of DataSources followed
-        by a single layer of KPI calculators.
-
-        Parameters
-        ----------
-        environment_data_values: list
-            A list of data values to submit to the evaluators.
-
-        evaluator_models: list
-            A list of the models for all the evaluators (data source
-            or kpi calculator)
-
-        creator_method_name: str
-            A string of the creator method for the evaluator on the
-            factory (e.g. create_kpi_calculator)
-
-        NOTE: The above parameter is going to go away as soon as we move
-        to unlimited layers and remove the distinction between data sources
-        and KPI calculators.
-        """
-        results = []
-
-        for model in evaluator_models:
-            factory = model.factory
-            evaluator = getattr(factory, creator_method_name)()
-
-            # Get the slots for this data source. These must be matched to
-            # the appropriate values in the environment data values.
-            # Matching is by position.
-            in_slots, out_slots = evaluator.slots(model)
-
-            # Binding performs the extraction of the specified data values
-            # satisfying the above input slots from the environment data values
-            # considering what the user specified in terms of names (which is
-            # in the model input slot maps.
-            # The resulting data are the ones picked by name from the
-            # environment data values, and in the appropriate ordering as
-            # needed by the input slots.
-            passed_data_values = self._bind_data_values(
-                environment_data_values,
-                model.input_slot_maps,
-                in_slots)
-
-            # execute data source, passing only relevant data values.
-            logging.info("Evaluating for Data Source {}".format(
-                factory.name))
-            res = evaluator.run(model, passed_data_values)
-
-            if len(res) != len(out_slots):
-                error_txt = (
-                    "The number of data values ({} values) returned"
-                    " by '{}' does not match the number"
-                    " of output slots it specifies ({} values)."
-                    " This is likely a plugin error.").format(
-                    len(res), factory.name, len(out_slots)
-                )
-
-                logging.error(error_txt)
-                raise RuntimeError(error_txt)
-
-            if len(res) != len(model.output_slot_names):
-                error_txt = (
-                    "The number of data values ({} values) returned"
-                    " by '{}' does not match the number"
-                    " of user-defined names specified ({} values)."
- " This is either a plugin error or a file" - " error.").format( - len(res), - factory.name, - len(model.output_slot_names) - ) - - logging.error(error_txt) - raise RuntimeError(error_txt) - - # At this point, the returned data values are unnamed. - # Add the names as specified by the user. - for dv, output_slot_name in zip(res, model.output_slot_names): - dv.name = output_slot_name - - # If the name was not specified, simply discard the value, - # because apparently the user is not interested in it. - results.extend([r for r in res if r.name != ""]) - - # Finally, return all the computed data values from all evaluators, - # properly named. - return results - - def _get_data_values_from_mco(self, model, communicator): - """Helper method. - Receives the data (in order) from the MCO, and bind them to the - specified names as from the model. - - Parameters - ---------- - model: BaseMCOModel - the MCO model (where the user-defined variable names are specified) - communicator: BaseMCOCommunicator - The communicator that produces the (temporarily unnamed) datavalues - from the MCO. - """ - mco_data_values = communicator.receive_from_mco(model) - - if len(mco_data_values) != len(model.parameters): - error_txt = ("The number of data values returned by" - " the MCO ({} values) does not match the" - " number of parameters specified ({} values)." - " This is either a MCO plugin error or the workflow" - " file is corrupted.").format( - len(mco_data_values), len(model.parameters) + +def _compute_layer_results(environment_data_values, + evaluator_models, + creator_method_name + ): + """Helper routine. + Performs the evaluation of a single layer. + At the moment we have a single layer of DataSources followed + by a single layer of KPI calculators. + + Parameters + ---------- + environment_data_values: list + A list of data values to submit to the evaluators. + + evaluator_models: list + A list of the models for all the evaluators (data source + or kpi calculator) + + creator_method_name: str + A string of the creator method for the evaluator on the + factory (e.g. create_kpi_calculator) + + NOTE: The above parameter is going to go away as soon as we move + to unlimited layers and remove the distinction between data sources + and KPI calculators. + """ + results = [] + + for model in evaluator_models: + factory = model.factory + evaluator = getattr(factory, creator_method_name)() + + # Get the slots for this data source. These must be matched to + # the appropriate values in the environment data values. + # Matching is by position. + in_slots, out_slots = evaluator.slots(model) + + # Binding performs the extraction of the specified data values + # satisfying the above input slots from the environment data values + # considering what the user specified in terms of names (which is + # in the model input slot maps. + # The resulting data are the ones picked by name from the + # environment data values, and in the appropriate ordering as + # needed by the input slots. + passed_data_values = _bind_data_values( + environment_data_values, + model.input_slot_maps, + in_slots) + + # execute data source, passing only relevant data values. + logging.info("Evaluating for Data Source {}".format( + factory.name)) + res = evaluator.run(model, passed_data_values) + + if len(res) != len(out_slots): + error_txt = ( + "The number of data values ({} values) returned" + " by '{}' does not match the number" + " of output slots it specifies ({} values)." 
+ " This is likely a plugin error.").format( + len(res), factory.name, len(out_slots) + ) + + logging.error(error_txt) + raise RuntimeError(error_txt) + + if len(res) != len(model.output_slot_names): + error_txt = ( + "The number of data values ({} values) returned" + " by '{}' does not match the number" + " of user-defined names specified ({} values)." + " This is either a plugin error or a file" + " error.").format( + len(res), + factory.name, + len(model.output_slot_names) ) + logging.error(error_txt) raise RuntimeError(error_txt) - # The data values obtained by the communicator are unnamed. - # Assign the name to each datavalue as specified by the user. - for dv, param in zip(mco_data_values, model.parameters): - dv.name = param.name - - return mco_data_values - - def _bind_data_values(self, - available_data_values, - model_slot_map, - slots): - """ - Given the named data values in the environment, the slots a given - data source expects, and the user-specified names for each of these - slots, returns those data values with the requested names, ordered - in the correct order as specified by the slot map. - """ - passed_data_values = [] - lookup_map = {dv.name: dv for dv in available_data_values} - - if len(slots) != len(model_slot_map): - raise RuntimeError("The length of the slots is not equal to" - " the length of the slot map. This may" - " indicate a file error.") + # At this point, the returned data values are unnamed. + # Add the names as specified by the user. + for dv, output_slot_name in zip(res, model.output_slot_names): + dv.name = output_slot_name + + # If the name was not specified, simply discard the value, + # because apparently the user is not interested in it. + results.extend([r for r in res if r.name != ""]) + + # Finally, return all the computed data values from all evaluators, + # properly named. + return results + +def _get_data_values_from_mco(model, communicator): + """Helper method. + Receives the data (in order) from the MCO, and bind them to the + specified names as from the model. + + Parameters + ---------- + model: BaseMCOModel + the MCO model (where the user-defined variable names are specified) + communicator: BaseMCOCommunicator + The communicator that produces the (temporarily unnamed) datavalues + from the MCO. + """ + mco_data_values = communicator.receive_from_mco(model) + + if len(mco_data_values) != len(model.parameters): + error_txt = ("The number of data values returned by" + " the MCO ({} values) does not match the" + " number of parameters specified ({} values)." + " This is either a MCO plugin error or the workflow" + " file is corrupted.").format( + len(mco_data_values), len(model.parameters) + ) + logging.error(error_txt) + raise RuntimeError(error_txt) + + # The data values obtained by the communicator are unnamed. + # Assign the name to each datavalue as specified by the user. + for dv, param in zip(mco_data_values, model.parameters): + dv.name = param.name + + # Exclude those who have no name set. + return [dv for dv in mco_data_values if dv.name != ""] + + +def _bind_data_values(available_data_values, + model_slot_map, + slots): + """ + Given the named data values in the environment, the slots a given + data source expects, and the user-specified names for each of these + slots, returns those data values with the requested names, ordered + in the correct order as specified by the slot map. 
+ """ + passed_data_values = [] + lookup_map = {dv.name: dv for dv in available_data_values} + + if len(slots) != len(model_slot_map): + raise RuntimeError("The length of the slots is not equal to" + " the length of the slot map. This may" + " indicate a file error.") + + try: for slot, slot_map in zip(slots, model_slot_map): passed_data_values.append(lookup_map[slot_map.name]) - - return passed_data_values + except KeyError: + raise RuntimeError( + "Unable to find requested name '{}' in available " + "data values. Current data value names: {}".format( + slot_map.name, + list(lookup_map.keys()))) + + return passed_data_values diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py index 5bd4a148f7ba97d45dc332071606c953c343600c..07ac470fb855681b3afe64339cd6c72cb4049e4a 100644 --- a/force_bdss/tests/test_core_evaluation_driver.py +++ b/force_bdss/tests/test_core_evaluation_driver.py @@ -1,5 +1,8 @@ import unittest + from traits.api import Float, List + +from force_bdss.core.input_slot_map import InputSlotMap from force_bdss.factory_registry_plugin import FactoryRegistryPlugin from force_bdss.core.data_value import DataValue from force_bdss.core.slot import Slot @@ -27,7 +30,8 @@ except ImportError: from envisage.api import Application -from force_bdss.core_evaluation_driver import CoreEvaluationDriver +from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \ + _bind_data_values, _compute_layer_results class NullMCOModel(BaseMCOModel): @@ -159,6 +163,19 @@ class OneValueDataSource(BaseDataSource): ) +class TwoInputsThreeOutputsDataSource(BaseDataSource): + """Incorrect data source implementation whose run returns a data value + but no slot was specified for it.""" + def run(self, model, parameters): + return [DataValue(value=1), DataValue(value=2), DataValue(value=3)] + + def slots(self, model): + return ( + (Slot(), Slot()), + (Slot(), Slot(), Slot()) + ) + + class NullDataSourceFactory(BaseDataSourceFactory): id = factory_id("enthought", "null_ds") name = "null_ds" @@ -276,3 +293,84 @@ class TestCoreEvaluationDriver(unittest.TestCase): " returned by 'null_kpic' does not match" " the number of user-defined names"): driver.application_started() + + def test_bind_data_values(self): + data_values = [ + DataValue(name="foo"), + DataValue(name="bar"), + DataValue(name="baz") + ] + + slot_map = ( + InputSlotMap(name="baz"), + InputSlotMap(name="bar") + ) + + slots = ( + Slot(), + Slot() + ) + + result = _bind_data_values(data_values, slot_map, slots) + self.assertEqual(result[0], data_values[2]) + self.assertEqual(result[1], data_values[1]) + + # Check the errors. Only one slot map for two slots. + slot_map = ( + InputSlotMap(name="baz"), + ) + + with self.assertRaisesRegexp( + RuntimeError, + "The length of the slots is not equal to the length of" + " the slot map"): + _bind_data_values(data_values, slot_map, slots) + + # missing value in the given data values. 
+        slot_map = (
+            InputSlotMap(name="blap"),
+            InputSlotMap(name="bar")
+        )
+
+        with self.assertRaisesRegexp(
+                RuntimeError,
+                "Unable to find requested name 'blap' in available"
+                " data values."):
+            _bind_data_values(data_values, slot_map, slots)
+
+    def test_compute_layer_results(self):
+
+        data_values = [
+            DataValue(name="foo"),
+            DataValue(name="bar"),
+            DataValue(name="baz"),
+            DataValue(name="quux")
+        ]
+
+        mock_ds_factory = mock.Mock(spec=BaseDataSourceFactory)
+        mock_ds_factory.name = "mock factory"
+        mock_ds_factory.create_data_source.return_value = \
+            TwoInputsThreeOutputsDataSource(mock_ds_factory)
+        evaluator_model = NullDataSourceModel(factory=mock_ds_factory)
+
+        evaluator_model.input_slot_maps = [
+            InputSlotMap(name="foo"),
+            InputSlotMap(name="quux")
+        ]
+        evaluator_model.output_slot_names = ["one", "", "three"]
+
+        res = _compute_layer_results(
+            data_values,
+            [evaluator_model],
+            "create_data_source"
+        )
+        self.assertEqual(len(res), 2)
+        self.assertEqual(res[0].name, "one")
+        self.assertEqual(res[0].value, 1)
+        self.assertEqual(res[1].name, "three")
+        self.assertEqual(res[1].value, 3)
+
+    def test_empty_slot_name_skips_data_value(self):
+        """Checks if leaving a slot name empty will skip the data value
+        in the final output
+        """