diff --git a/force_bdss/core/slot.py b/force_bdss/core/slot.py index 22182759c6be5ab0a05833916b7e03fbfdcbb09b..8db25aabb09670d51c2e997bd4e4185db324de6c 100644 --- a/force_bdss/core/slot.py +++ b/force_bdss/core/slot.py @@ -1,11 +1,16 @@ from traits.api import HasStrictTraits, String +from ..local_traits import CUBAType class Slot(HasStrictTraits): - """Describes an input or output slot in the DataSource or - KPICalculator""" + """ + Describes an input or output slot in the DataSource or + KPICalculator. If the DataSource and KPICalculator are functions, slots + define their argument number and types they need as input and what + they return as output. + """ #: A textual description of the slot description = String("No description") #: The CUBA key of the slot - type = String() + type = CUBAType() diff --git a/force_bdss/core_evaluation_driver.py b/force_bdss/core_evaluation_driver.py index f71d30c46ff3e81f5f0788ebf28ef9d68feaa88b..2568ddb66a3eba21f07b7098c3b99a76164ca045 100644 --- a/force_bdss/core_evaluation_driver.py +++ b/force_bdss/core_evaluation_driver.py @@ -1,6 +1,8 @@ from __future__ import print_function import sys +import logging + from traits.api import on_trait_change from .ids import plugin_id @@ -31,87 +33,185 @@ class CoreEvaluationDriver(BaseCoreDriver): mco_bundle = mco_model.bundle mco_communicator = mco_bundle.create_communicator() - # Receives the data from the MCO. These are technically unnamed. - # The names are then assigned. Order is important - mco_data_values = mco_communicator.receive_from_mco(mco_model) + mco_data_values = self._get_data_values_from_mco(mco_model, + mco_communicator) - if len(mco_data_values) != len(mco_model.parameters): - raise RuntimeError("The number of data values returned by" - " the MCO does not match the number of" - " parameters specified. This is likely a" - " MCO plugin error.") + ds_results = self._compute_ds_results( + mco_data_values, + workflow) - # Assign the name to the data value that was emitted. - for dv, param in zip(mco_data_values, mco_model.parameters): - dv.name = param.name + kpi_results = self._compute_kpi_results( + ds_results + mco_data_values, + workflow) + mco_communicator.send_to_mco(mco_model, kpi_results) + + def _compute_ds_results(self, environment_data_values, workflow): + """Helper routine. + Performs the evaluation of the DataSources, passing the current + environment data values (the MCO data) + """ ds_results = [] + for ds_model in workflow.data_sources: ds_bundle = ds_model.bundle data_source = ds_bundle.create_data_source() + # Get the slots for this data source. These must be matched to + # the appropriate values in the environment data values. + # Matching is by position. in_slots, out_slots = data_source.slots(ds_model) + + # Binding performs the extraction of the specified data values + # satisfying the above input slots from the environment data values + # considering what the user specified in terms of names (which is + # in the model input slot maps. + # The resulting data are the ones picked by name from the + # environment data values, and in the appropriate ordering as + # needed by the input slots. passed_data_values = self._bind_data_values( - mco_data_values, + environment_data_values, ds_model.input_slot_maps, in_slots) + # execute data source, passing only relevant data values. + logging.info("Evaluating for Data Source {}".format( + ds_bundle.name)) res = data_source.run(ds_model, passed_data_values) + if len(res) != len(out_slots): - raise RuntimeError("The number of data values returned by" - " the DataSource does not match the number" - " of parameters specified. This is likely a" - " DataSource plugin error.") + error_txt = ( + "The number of data values ({} values) returned" + " by the DataSource '{}' does not match the number" + " of output slots it specifies ({} values)." + " This is likely a DataSource plugin error.").format( + len(res), ds_bundle.name, len(out_slots) + ) - if len(res) != len(ds_model.output_slot_names): - raise RuntimeError("The number of data values returned by" - " the DataSource does not match the number" - " of names specified. This is either an" - " input file error or a plugin error.") + logging.error(error_txt) + raise RuntimeError(error_txt) + if len(res) != len(ds_model.output_slot_names): + error_txt = ( + "The number of data values ({} values) returned" + " by the DataSource '{}' does not match the number" + " of user-defined names specified ({} values)." + " This is likely a DataSource plugin error.").format( + len(res), + ds_bundle.name, + len(ds_model.output_slot_names) + ) + + logging.error(error_txt) + raise RuntimeError(error_txt) + + # At this point, the returned data values are unnamed. + # Add the names as specified by the user. for dv, output_slot_name in zip(res, ds_model.output_slot_names): dv.name = output_slot_name ds_results.extend(res) + # Finally, return all the computed data values from all data sources, + # properly named. + return ds_results + + def _compute_kpi_results(self, environment_data_values, workflow): + """Perform evaluation of all KPI calculators. + environment_data_values contains all data values provided from + the MCO and data sources. + """ kpi_results = [] + for kpic_model in workflow.kpi_calculators: kpic_bundle = kpic_model.bundle kpi_calculator = kpic_bundle.create_kpi_calculator() + in_slots, out_slots = kpi_calculator.slots(kpic_model) passed_data_values = self._bind_data_values( - mco_data_values+ds_results, + environment_data_values, kpic_model.input_slot_maps, in_slots) + logging.info("Evaluating for KPICalculator {}".format( + kpic_bundle.name)) + res = kpi_calculator.run(kpic_model, passed_data_values) + if len(res) != len(out_slots): - raise RuntimeError("The number of data values returned by" - " the KPICalculator does not match the" - " number of parameters specified. This is" - " likely a KPICalculator plugin error.") + error_txt = ( + "The number of data values ({} values) returned by" + " the KPICalculator '{}' does not match the" + " number of output slots ({} values). This is" + " likely a KPICalculator plugin error." + ).format(len(res), kpic_bundle.name, len(out_slots)) + logging.error(error_txt) + raise RuntimeError(error_txt) if len(res) != len(kpic_model.output_slot_names): - raise RuntimeError("The number of data values returned by" - " the KPICalculator does not match the" - " number of names specified. This is" - " either an input file error or a plugin" - " error.") - - for kpi, output_slot_name in zip(res, - kpic_model.output_slot_names): + error_txt = ( + "The number of data values ({} values) returned by" + " the KPICalculator '{}' does not match the" + " number of user-defined names specified ({} values)." + " This is either an input file error or a plugin" + " error." + ).format(len(res), kpic_bundle.name, + len(kpic_model.output_slot_names)) + logging.error(error_txt) + raise RuntimeError(error_txt) + + for kpi, output_slot_name in zip( + res, kpic_model.output_slot_names): kpi.name = output_slot_name kpi_results.extend(res) - mco_communicator.send_to_mco(mco_model, kpi_results) + return kpi_results + + def _get_data_values_from_mco(self, model, communicator): + """Helper method. + Receives the data (in order) from the MCO, and bind them to the + specified names as from the model. + + Parameters + ---------- + model: BaseMCOModel + the MCO model (where the user-defined variable names are specified) + communicator: BaseMCOCommunicator + The communicator that produces the (temporarily unnamed) datavalues + from the MCO. + """ + mco_data_values = communicator.receive_from_mco(model) + + if len(mco_data_values) != len(model.parameters): + error_txt = ("The number of data values returned by" + " the MCO ({} values) does not match the" + " number of parameters specified ({} values)." + " This is either a MCO plugin error or the workflow" + " file is corrupted.").format( + len(mco_data_values), len(model.parameters) + ) + logging.error(error_txt) + raise RuntimeError(error_txt) + + # The data values obtained by the communicator are unnamed. + # Assign the name to each datavalue as specified by the user. + for dv, param in zip(mco_data_values, model.parameters): + dv.name = param.name + + return mco_data_values def _bind_data_values(self, available_data_values, model_slot_map, slots): - + """ + Given the named data values in the environment, the slots a given + data source expects, and the user-specified names for each of these + slots, returns those data values with the requested names, ordered + in the correct order as specified by the slot map. + """ passed_data_values = [] lookup_map = {dv.name: dv for dv in available_data_values} diff --git a/force_bdss/local_traits.py b/force_bdss/local_traits.py index f68b5d909a9b2421fb35e204a604d38a80dbcd59..1c5e3f157e4409f2b7cbb61d60877cc6974ef9f4 100644 --- a/force_bdss/local_traits.py +++ b/force_bdss/local_traits.py @@ -1,6 +1,11 @@ -from traits.api import Regex +from traits.api import Regex, String #: Used for variable names, but allow also empty string as it's the default #: case and it will be present if the workflow is saved before actually #: specifying the value. Identifier = Regex(regex="(^[^\d\W]\w*\Z|^\Z)") + + +#: Identifies a CUBA type with its key. At the moment a String with +#: no validation, but will come later. +CUBAType = String() diff --git a/force_bdss/tests/fixtures/test_null.json b/force_bdss/tests/fixtures/test_null.json new file mode 100644 index 0000000000000000000000000000000000000000..b10ffe97ff8f6d850d9f21e25cbc3aa0df0e0cc7 --- /dev/null +++ b/force_bdss/tests/fixtures/test_null.json @@ -0,0 +1,34 @@ +{ + "version": "1", + "workflow": { + "mco": { + "id": "force.bdss.enthought.bundle.null_mco", + "model_data": { + "parameters" : [ + ] + } + }, + "data_sources": [ + { + "id": "force.bdss.enthought.bundle.null_ds", + "model_data": { + "input_slot_maps": [ + ], + "output_slot_names": [ + ] + } + } + ], + "kpi_calculators": [ + { + "id": "force.bdss.enthought.bundle.null_kpic", + "model_data": { + "input_slot_maps": [ + ], + "output_slot_names": [ + ] + } + } + ] + } +} diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py index fc79b484871b7e065956eccfe2a532a008c6db48..9f680fa52ac94c16226889fc215a21337273c584 100644 --- a/force_bdss/tests/test_core_evaluation_driver.py +++ b/force_bdss/tests/test_core_evaluation_driver.py @@ -1,6 +1,8 @@ import unittest -from traits.api import Float +from traits.api import Float, List from force_bdss.bundle_registry_plugin import BundleRegistryPlugin +from force_bdss.core.data_value import DataValue +from force_bdss.core.slot import Slot from force_bdss.data_sources.base_data_source import BaseDataSource from force_bdss.data_sources.base_data_source_bundle import \ BaseDataSourceBundle @@ -37,15 +39,15 @@ class NullMCO(BaseMCO): pass -class NullParameter(BaseMCOParameter): +class RangedParameter(BaseMCOParameter): initial_value = Float() lower_bound = Float() upper_bound = Float() -class NullParameterFactory(BaseMCOParameterFactory): - id = mco_parameter_id("enthought", "dummy_dakota", "ranged") - model_class = NullParameter +class RangedParameterFactory(BaseMCOParameterFactory): + id = mco_parameter_id("enthought", "null_mco", "null") + model_class = RangedParameter class NullMCOCommunicator(BaseMCOCommunicator): @@ -56,11 +58,23 @@ class NullMCOCommunicator(BaseMCOCommunicator): return [] +class OneDataValueMCOCommunicator(BaseMCOCommunicator): + """A communicator that returns one single datavalue, for testing purposes. + """ + def send_to_mco(self, model, kpi_results): + pass + + def receive_from_mco(self, model): + return [ + DataValue() + ] + + class NullMCOBundle(BaseMCOBundle): - id = bundle_id("enthought", "dummy_dakota") + id = bundle_id("enthought", "null_mco") def create_model(self, model_data=None): - return NullMCOModel(self) + return NullMCOModel(self, **model_data) def create_communicator(self): return NullMCOCommunicator(self) @@ -69,7 +83,7 @@ class NullMCOBundle(BaseMCOBundle): return NullMCO(self) def parameter_factories(self): - return [NullParameterFactory(self)] + return [] class NullKPICalculatorModel(BaseKPICalculatorModel): @@ -84,7 +98,26 @@ class NullKPICalculator(BaseKPICalculator): return (), () +class BrokenOneValueKPICalculator(BaseKPICalculator): + def run(self, model, data_source_results): + return [DataValue()] + + def slots(self, model): + return (), () + + +class OneValueKPICalculator(BaseKPICalculator): + def run(self, model, data_source_results): + return [DataValue()] + + def slots(self, model): + return (), (Slot(), ) + + class NullKPICalculatorBundle(BaseKPICalculatorBundle): + id = bundle_id("enthought", "null_kpic") + name = "null_kpic" + def create_model(self, model_data=None): return NullKPICalculatorModel(self) @@ -104,7 +137,32 @@ class NullDataSource(BaseDataSource): return (), () +class BrokenOneValueDataSource(BaseDataSource): + """Incorrect data source implementation whose run returns a data value + but no slot was specified for it.""" + def run(self, model, parameters): + return [DataValue()] + + def slots(self, model): + return (), () + + +class OneValueDataSource(BaseDataSource): + """Incorrect data source implementation whose run returns a data value + but no slot was specified for it.""" + def run(self, model, parameters): + return [DataValue()] + + def slots(self, model): + return (), ( + Slot(), + ) + + class NullDataSourceBundle(BaseDataSourceBundle): + id = bundle_id("enthought", "null_ds") + name = "null_ds" + def create_model(self, model_data=None): return NullDataSourceModel(self) @@ -112,16 +170,20 @@ class NullDataSourceBundle(BaseDataSourceBundle): return NullDataSource(self) +class DummyBundleRegistryPlugin(BundleRegistryPlugin): + mco_bundles = List() + kpi_calculator_bundles = List() + data_source_bundles = List() + + def mock_bundle_registry_plugin(): - bundle_registry_plugin = mock.Mock(spec=BundleRegistryPlugin) + bundle_registry_plugin = DummyBundleRegistryPlugin() bundle_registry_plugin.mco_bundles = [ NullMCOBundle(bundle_registry_plugin)] - bundle_registry_plugin.mco_bundle_by_id = mock.Mock( - return_value=NullMCOBundle(bundle_registry_plugin)) - bundle_registry_plugin.kpi_calculator_bundle_by_id = mock.Mock( - return_value=NullKPICalculatorBundle(bundle_registry_plugin)) - bundle_registry_plugin.data_source_bundle_by_id = mock.Mock( - return_value=NullDataSourceBundle(bundle_registry_plugin)) + bundle_registry_plugin.kpi_calculator_bundles = [ + NullKPICalculatorBundle(bundle_registry_plugin)] + bundle_registry_plugin.data_source_bundles = [ + NullDataSourceBundle(bundle_registry_plugin)] return bundle_registry_plugin @@ -132,7 +194,7 @@ class TestCoreEvaluationDriver(unittest.TestCase): application.get_plugin = mock.Mock( return_value=self.mock_bundle_registry_plugin ) - application.workflow_filepath = fixtures.get("test_csv.json") + application.workflow_filepath = fixtures.get("test_null.json") self.mock_application = application def test_initialization(self): @@ -140,3 +202,77 @@ class TestCoreEvaluationDriver(unittest.TestCase): application=self.mock_application, ) driver.application_started() + + def test_error_for_non_matching_mco_parameters(self): + bundle = self.mock_bundle_registry_plugin.mco_bundles[0] + with mock.patch.object(bundle.__class__, + "create_communicator") as create_comm: + create_comm.return_value = OneDataValueMCOCommunicator( + bundle) + driver = CoreEvaluationDriver( + application=self.mock_application, + ) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values returned by the MCO"): + driver.application_started() + + def test_error_for_incorrect_output_slots(self): + bundle = self.mock_bundle_registry_plugin.data_source_bundles[0] + with mock.patch.object(bundle.__class__, + "create_data_source") as create_ds: + create_ds.return_value = BrokenOneValueDataSource(bundle) + driver = CoreEvaluationDriver( + application=self.mock_application, + ) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values \(1 values\)" + " returned by the DataSource 'null_ds' does not match" + " the number of output slots"): + driver.application_started() + + def test_error_for_missing_ds_output_names(self): + bundle = self.mock_bundle_registry_plugin.data_source_bundles[0] + with mock.patch.object(bundle.__class__, + "create_data_source") as create_ds: + create_ds.return_value = OneValueDataSource(bundle) + driver = CoreEvaluationDriver( + application=self.mock_application, + ) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values \(1 values\)" + " returned by the DataSource 'null_ds' does not match" + " the number of user-defined names"): + driver.application_started() + + def test_error_for_incorrect_kpic_output_slots(self): + bundle = self.mock_bundle_registry_plugin.kpi_calculator_bundles[0] + with mock.patch.object(bundle.__class__, + "create_kpi_calculator") as create_kpic: + create_kpic.return_value = BrokenOneValueKPICalculator(bundle) + driver = CoreEvaluationDriver( + application=self.mock_application, + ) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values \(1 values\)" + " returned by the KPICalculator 'null_kpic' does not match" + " the number of output slots"): + driver.application_started() + + def test_error_for_missing_kpic_output_names(self): + bundle = self.mock_bundle_registry_plugin.kpi_calculator_bundles[0] + with mock.patch.object(bundle.__class__, + "create_kpi_calculator") as create_kpic: + create_kpic.return_value = OneValueKPICalculator(bundle) + driver = CoreEvaluationDriver( + application=self.mock_application, + ) + with self.assertRaisesRegexp( + RuntimeError, + "The number of data values \(1 values\)" + " returned by the KPICalculator 'null_kpic' does not match" + " the number of user-defined names"): + driver.application_started() diff --git a/force_bdss/tests/test_local_traits.py b/force_bdss/tests/test_local_traits.py index ace5943e779d138349c9a595961be29eb4720b74..d61e6d2c89520653d65431a0030ad89b362f18be 100644 --- a/force_bdss/tests/test_local_traits.py +++ b/force_bdss/tests/test_local_traits.py @@ -1,11 +1,12 @@ import unittest from traits.api import HasStrictTraits, TraitError -from force_bdss.local_traits import Identifier +from force_bdss.local_traits import Identifier, CUBAType class Traited(HasStrictTraits): val = Identifier() + cuba = CUBAType() class TestLocalTraits(unittest.TestCase): @@ -19,3 +20,8 @@ class TestLocalTraits(unittest.TestCase): for broken in ["0", None, 123, "0hello", "hi$", "hi%"]: with self.assertRaises(TraitError): c.val = broken + + def test_cuba_type(self): + c = Traited() + c.cuba = "PRESSURE" + self.assertEqual(c.cuba, "PRESSURE")