diff --git a/force_bdss/api.py b/force_bdss/api.py index 165dd9bdc1869e369d2f0e80a2669224947941d4..d5567d689ee41c7e0c88db1aa19c5048f6d3e4fc 100644 --- a/force_bdss/api.py +++ b/force_bdss/api.py @@ -1,6 +1,7 @@ from .base_extension_plugin import BaseExtensionPlugin # noqa from .ids import plugin_id, factory_id # noqa +from .core.base_factory import BaseFactory # noqa from .core.data_value import DataValue # noqa from .core.workflow import Workflow # noqa from .core.slot import Slot # noqa @@ -11,6 +12,7 @@ from .core.kpi_specification import KPISpecification # noqa from .core.execution_layer import ExecutionLayer # noqa from .core.verifier import verify_workflow # noqa from .core.verifier import VerifierError # noqa +from .core.execution import execute_layer, execute_workflow # noqa from .data_sources.base_data_source_model import BaseDataSourceModel # noqa from .data_sources.base_data_source import BaseDataSource # noqa @@ -44,4 +46,4 @@ from .ui_hooks.i_ui_hooks_factory import IUIHooksFactory # noqa from .ui_hooks.base_ui_hooks_factory import BaseUIHooksFactory # noqa from .ui_hooks.base_ui_hooks_manager import BaseUIHooksManager # noqa -from .local_traits import Identifier # noqa +from .local_traits import Identifier, PositiveInt # noqa diff --git a/force_bdss/bdss_application.py b/force_bdss/bdss_application.py index a822528477d75c8a36ca856b41a5c753d46106b2..4e5d49b188c33e50ee6e364d78fd95afc8b374ad 100644 --- a/force_bdss/bdss_application.py +++ b/force_bdss/bdss_application.py @@ -6,8 +6,9 @@ from stevedore.exception import NoMatches from envisage.api import Application from envisage.core_plugin import CorePlugin -from traits.api import Unicode, Bool +from traits.api import Unicode, Bool, Property +from force_bdss.ids import InternalPluginID from .factory_registry_plugin import FactoryRegistryPlugin from .core_evaluation_driver import CoreEvaluationDriver from .core_mco_driver import CoreMCODriver @@ -28,6 +29,9 @@ class BDSSApplication(Application): #: coordination of the MCO itself. See design notes for more details. evaluate = Bool() + #: Gives the currently opened workflow + workflow = Property() + def __init__(self, evaluate, workflow_filepath): self.evaluate = evaluate self.workflow_filepath = workflow_filepath @@ -53,6 +57,15 @@ class BDSSApplication(Application): super(BDSSApplication, self).__init__(plugins=plugins) + def _get_workflow(self): + if self.evaluate: + plugin = self.get_plugin( + InternalPluginID.CORE_EVALUATION_DRIVER_ID) + else: + plugin = self.get_plugin(InternalPluginID.CORE_MCO_DRIVER_ID) + + return plugin.workflow + def _import_extensions(plugins, ext): """Service routine extracted for testing. diff --git a/force_bdss/core/base_factory.py b/force_bdss/core/base_factory.py index e8daa730591b65ebe6cc89398cc95e59d1cd50af..f69db3f1d97c1da7848c1a40bc7895ae51ef4226 100644 --- a/force_bdss/core/base_factory.py +++ b/force_bdss/core/base_factory.py @@ -1,5 +1,5 @@ from envisage.plugin import Plugin -from traits.api import HasStrictTraits, Str, Instance +from traits.api import HasStrictTraits, Str, Unicode, Instance from force_bdss.ids import factory_id @@ -12,6 +12,9 @@ class BaseFactory(HasStrictTraits): #: A human readable name of the factory. Spaces allowed name = Str() + #: A long description of the factory. + description = Unicode() + #: Reference to the plugin that carries this factory #: This is automatically set by the system. you should not define it #: in your subclass. @@ -21,6 +24,7 @@ class BaseFactory(HasStrictTraits): super(BaseFactory, self).__init__(plugin=plugin, *args, **kwargs) self.name = self.get_name() + self.description = self.get_description() identifier = self.get_identifier() try: id = self._global_id(identifier) @@ -51,5 +55,8 @@ class BaseFactory(HasStrictTraits): "get_identifier was not implemented in factory {}".format( self.__class__)) + def get_description(self): + return u"No description available." + def _global_id(self, identifier): return factory_id(self.plugin.id, identifier) diff --git a/force_bdss/core/execution.py b/force_bdss/core/execution.py new file mode 100644 index 0000000000000000000000000000000000000000..5b13be218cd8b12f4e1613288d7746645803ec56 --- /dev/null +++ b/force_bdss/core/execution.py @@ -0,0 +1,209 @@ +import logging +from force_bdss.core.data_value import DataValue + +log = logging.getLogger(__name__) + + +def execute_workflow(workflow, data_values): + """Executes the given workflow using the list of data values. + Returns a list of data values for the KPI results + + Parameters + ---------- + workflow: Workflow + The instance of the workflow + + data_values: List + The data values that the MCO generally provides. + + Returns + ------- + list: A list of DataValues containing the KPI results. + """ + + available_data_values = data_values[:] + for index, layer in enumerate(workflow.execution_layers): + log.info("Computing data layer {}".format(index)) + ds_results = execute_layer(layer, available_data_values) + available_data_values += ds_results + + log.info("Aggregating KPI data") + + kpi_results = [] + kpi_names = [kpi.name for kpi in workflow.mco.kpis] + + kpi_results = [ + dv + for dv in available_data_values + if dv.name in kpi_names + ] + + return kpi_results + + +def execute_layer(layer, environment_data_values): + """Helper routine. + Performs the evaluation of a single layer. + At the moment we have a single layer of DataSources followed + by a single layer of KPI calculators. + + Parameters + ---------- + layer: ExecutionLayer + A list of the models for all the data sources + + environment_data_values: list + A list of data values to submit to the evaluators. + + NOTE: The above parameter is going to go away as soon as we move + to unlimited layers and remove the distinction between data sources + and KPI calculators. + """ + results = [] + + for model in layer.data_sources: + factory = model.factory + try: + data_source = factory.create_data_source() + except Exception: + log.exception( + "Unable to create data source from factory '{}' " + "in plugin '{}'. This may indicate a programming " + "error in the plugin".format( + factory.id, + factory.plugin.id)) + raise + + # Get the slots for this data source. These must be matched to + # the appropriate values in the environment data values. + # Matching is by position. + in_slots, out_slots = data_source.slots(model) + + # Binding performs the extraction of the specified data values + # satisfying the above input slots from the environment data values + # considering what the user specified in terms of names (which is + # in the model input slot info + # The resulting data are the ones picked by name from the + # environment data values, and in the appropriate ordering as + # needed by the input slots. + passed_data_values = _bind_data_values( + environment_data_values, + model.input_slot_info, + in_slots) + + # execute data source, passing only relevant data values. + log.info("Evaluating for Data Source {}".format( + factory.name)) + log.info("Passed values:") + for idx, dv in enumerate(passed_data_values): + log.info("{}: {}".format(idx, dv)) + + try: + res = data_source.run(model, passed_data_values) + except Exception: + log.exception( + "Evaluation could not be performed. " + "Run method raised exception.") + raise + + if not isinstance(res, list): + error_txt = ( + "The run method of data source {} must return a list." + " It returned instead {}. Fix the run() method to return" + " the appropriate entity.".format( + factory.name, + type(res) + )) + log.error(error_txt) + raise RuntimeError(error_txt) + + if len(res) != len(out_slots): + error_txt = ( + "The number of data values ({} values) returned" + " by '{}' does not match the number" + " of output slots it specifies ({} values)." + " This is likely a plugin error.").format( + len(res), factory.name, len(out_slots) + ) + + log.error(error_txt) + raise RuntimeError(error_txt) + + if len(res) != len(model.output_slot_info): + error_txt = ( + "The number of data values ({} values) returned" + " by '{}' does not match the number" + " of user-defined names specified ({} values)." + " This is either a plugin error or a file" + " error.").format( + len(res), + factory.name, + len(model.output_slot_info) + ) + + log.error(error_txt) + raise RuntimeError(error_txt) + + for idx, dv in enumerate(res): + if not isinstance(dv, DataValue): + error_txt = ( + "The result list returned by DataSource {} contains" + " an entry that is not a DataValue. An entry of type" + " {} was instead found in position {}." + " Fix the DataSource.run() method" + " to return the appropriate entity.".format( + factory.name, + type(dv), + idx + ) + ) + log.error(error_txt) + raise RuntimeError(error_txt) + + # At this point, the returned data values are unnamed. + # Add the names as specified by the user. + for dv, output_slot_info in zip(res, model.output_slot_info): + dv.name = output_slot_info.name + + # If the name was not specified, simply discard the value, + # because apparently the user is not interested in it. + res = [r for r in res if r.name != ""] + results.extend(res) + + log.info("Returned values:") + for idx, dv in enumerate(res): + log.info("{}: {}".format(idx, dv)) + + # Finally, return all the computed data values from all evaluators, + # properly named. + return results + + +def _bind_data_values(available_data_values, + model_slot_map, + slots): + """ + Given the named data values in the environment, the slots a given + data source expects, and the user-specified names for each of these + slots, returns those data values with the requested names, ordered + in the correct order as specified by the slot map. + """ + passed_data_values = [] + lookup_map = {dv.name: dv for dv in available_data_values} + + if len(slots) != len(model_slot_map): + raise RuntimeError("The length of the slots is not equal to" + " the length of the slot map. This may" + " indicate a file error.") + + try: + for slot, slot_map in zip(slots, model_slot_map): + passed_data_values.append(lookup_map[slot_map.name]) + except KeyError: + raise RuntimeError( + "Unable to find requested name '{}' in available " + "data values. Current data value names: {}".format( + slot_map.name, + list(lookup_map.keys()))) + + return passed_data_values diff --git a/force_bdss/core/tests/test_execution.py b/force_bdss/core/tests/test_execution.py new file mode 100644 index 0000000000000000000000000000000000000000..ceed26f80066dc11a9186cd6fa126b2117ecb3b9 --- /dev/null +++ b/force_bdss/core/tests/test_execution.py @@ -0,0 +1,227 @@ +import unittest + +import testfixtures +import six + +from force_bdss.core.execution_layer import ExecutionLayer +from force_bdss.core.kpi_specification import KPISpecification +from force_bdss.core.output_slot_info import OutputSlotInfo +from force_bdss.core.workflow import Workflow +from force_bdss.tests.probe_classes.data_source import ProbeDataSourceFactory + +from force_bdss.core.input_slot_info import InputSlotInfo +from force_bdss.core.data_value import DataValue +from force_bdss.core.slot import Slot +from force_bdss.tests.probe_classes.factory_registry_plugin import \ + ProbeFactoryRegistryPlugin +from force_bdss.tests.probe_classes.mco import ProbeMCOFactory + +from force_bdss.core.execution import execute_workflow, execute_layer, \ + _bind_data_values + + +class TestExecution(unittest.TestCase): + def setUp(self): + self.registry = ProbeFactoryRegistryPlugin() + self.plugin = self.registry.plugin + + def test_bind_data_values(self): + data_values = [ + DataValue(name="foo"), + DataValue(name="bar"), + DataValue(name="baz") + ] + + slot_map = ( + InputSlotInfo(name="baz"), + InputSlotInfo(name="bar") + ) + + slots = ( + Slot(), + Slot() + ) + + result = _bind_data_values(data_values, slot_map, slots) + self.assertEqual(result[0], data_values[2]) + self.assertEqual(result[1], data_values[1]) + + # Check the errors. Only one slot map for two slots. + slot_map = ( + InputSlotInfo(name="baz"), + ) + + with testfixtures.LogCapture(): + with six.assertRaisesRegex( + self, + RuntimeError, + "The length of the slots is not equal to the length of" + " the slot map"): + _bind_data_values(data_values, slot_map, slots) + + # missing value in the given data values. + slot_map = ( + InputSlotInfo(name="blap"), + InputSlotInfo(name="bar") + ) + + with testfixtures.LogCapture(): + with six.assertRaisesRegex( + self, + RuntimeError, + "Unable to find requested name 'blap' in available" + " data values."): + _bind_data_values(data_values, slot_map, slots) + + def test_compute_layer_results(self): + data_values = [ + DataValue(name="foo"), + DataValue(name="bar"), + DataValue(name="baz"), + DataValue(name="quux") + ] + + def run(self, *args, **kwargs): + return [DataValue(value=1), DataValue(value=2), DataValue(value=3)] + + ds_factory = self.registry.data_source_factories[0] + ds_factory.input_slots_size = 2 + ds_factory.output_slots_size = 3 + ds_factory.run_function = run + evaluator_model = ds_factory.create_model() + + evaluator_model.input_slot_info = [ + InputSlotInfo(name="foo"), + InputSlotInfo(name="quux") + ] + evaluator_model.output_slot_info = [ + OutputSlotInfo(name="one"), + OutputSlotInfo(name=""), + OutputSlotInfo(name="three") + ] + + res = execute_layer( + ExecutionLayer(data_sources=[evaluator_model]), + data_values, + ) + self.assertEqual(len(res), 2) + self.assertEqual(res[0].name, "one") + self.assertEqual(res[0].value, 1) + self.assertEqual(res[1].name, "three") + self.assertEqual(res[1].value, 3) + + def test_multilayer_execution(self): + # The multilayer peforms the following execution + # layer 0: in1 + in2 | in3 + in4 + # res1 res2 + # layer 1: res1 + res2 + # res3 + # layer 2: res3 * res1 + # res4 + # layer 3: res4 * res2 + # out1 + # Final result should be + # out1 = ((in1 + in2 + in3 + in4) * (in1 + in2) * (in3 + in4) + + data_values = [ + DataValue(value=10, name="in1"), + DataValue(value=15, name="in2"), + DataValue(value=3, name="in3"), + DataValue(value=7, name="in4") + ] + + def adder(model, parameters): + + first = parameters[0].value + second = parameters[1].value + return [DataValue(value=(first+second))] + + adder_factory = ProbeDataSourceFactory( + self.plugin, + input_slots_size=2, + output_slots_size=1, + run_function=adder) + + def multiplier(model, parameters): + first = parameters[0].value + second = parameters[1].value + return [DataValue(value=(first*second))] + + multiplier_factory = ProbeDataSourceFactory( + self.plugin, + input_slots_size=2, + output_slots_size=1, + run_function=multiplier) + + mco_factory = ProbeMCOFactory(self.plugin) + mco_model = mco_factory.create_model() + mco_model.kpis = [ + KPISpecification(name="out1") + ] + + wf = Workflow( + mco=mco_model, + execution_layers=[ + ExecutionLayer(), + ExecutionLayer(), + ExecutionLayer(), + ExecutionLayer() + ] + ) + # Layer 0 + model = adder_factory.create_model() + model.input_slot_info = [ + InputSlotInfo(name="in1"), + InputSlotInfo(name="in2") + ] + model.output_slot_info = [ + OutputSlotInfo(name="res1") + ] + wf.execution_layers[0].data_sources.append(model) + + model = adder_factory.create_model() + model.input_slot_info = [ + InputSlotInfo(name="in3"), + InputSlotInfo(name="in4") + ] + model.output_slot_info = [ + OutputSlotInfo(name="res2") + ] + wf.execution_layers[0].data_sources.append(model) + + # layer 1 + model = adder_factory.create_model() + model.input_slot_info = [ + InputSlotInfo(name="res1"), + InputSlotInfo(name="res2") + ] + model.output_slot_info = [ + OutputSlotInfo(name="res3") + ] + wf.execution_layers[1].data_sources.append(model) + + # layer 2 + model = multiplier_factory.create_model() + model.input_slot_info = [ + InputSlotInfo(name="res3"), + InputSlotInfo(name="res1") + ] + model.output_slot_info = [ + OutputSlotInfo(name="res4") + ] + wf.execution_layers[2].data_sources.append(model) + + # layer 3 + model = multiplier_factory.create_model() + model.input_slot_info = [ + InputSlotInfo(name="res4"), + InputSlotInfo(name="res2") + ] + model.output_slot_info = [ + OutputSlotInfo(name="out1") + ] + wf.execution_layers[3].data_sources.append(model) + + kpi_results = execute_workflow(wf, data_values) + self.assertEqual(len(kpi_results), 1) + self.assertEqual(kpi_results[0].value, 8750) diff --git a/force_bdss/core_evaluation_driver.py b/force_bdss/core_evaluation_driver.py index be6677e0c1a8c49fe207c2c1a0ab5b042d5cefa3..34bff5e176ccb510361ac3bad3949b43f1b243cf 100644 --- a/force_bdss/core_evaluation_driver.py +++ b/force_bdss/core_evaluation_driver.py @@ -3,7 +3,7 @@ import logging from traits.api import on_trait_change -from force_bdss.core.data_value import DataValue +from force_bdss.core.execution import execute_workflow from force_bdss.ids import InternalPluginID from .base_core_driver import BaseCoreDriver @@ -51,174 +51,6 @@ class CoreEvaluationDriver(BaseCoreDriver): mco_communicator.send_to_mco(mco_model, kpi_results) -def execute_workflow(workflow, data_values): - """Executes the given workflow using the list of data values. - Returns a list of data values for the KPI results - """ - - available_data_values = data_values[:] - for index, layer in enumerate(workflow.execution_layers): - log.info("Computing data layer {}".format(index)) - ds_results = _compute_layer_results( - available_data_values, - layer, - ) - available_data_values += ds_results - - log.info("Aggregating KPI data") - - kpi_results = [] - kpi_names = [kpi.name for kpi in workflow.mco.kpis] - - kpi_results = [ - dv - for dv in available_data_values - if dv.name in kpi_names - ] - - return kpi_results - - -def _compute_layer_results(environment_data_values, - layer, - ): - """Helper routine. - Performs the evaluation of a single layer. - At the moment we have a single layer of DataSources followed - by a single layer of KPI calculators. - - Parameters - ---------- - environment_data_values: list - A list of data values to submit to the evaluators. - - layer: ExecutionLayer - A list of the models for all the data sources - - NOTE: The above parameter is going to go away as soon as we move - to unlimited layers and remove the distinction between data sources - and KPI calculators. - """ - results = [] - - for model in layer.data_sources: - factory = model.factory - try: - data_source = factory.create_data_source() - except Exception: - log.exception( - "Unable to create data source from factory '{}' " - "in plugin '{}'. This may indicate a programming " - "error in the plugin".format( - factory.id, - factory.plugin.id)) - raise - - # Get the slots for this data source. These must be matched to - # the appropriate values in the environment data values. - # Matching is by position. - in_slots, out_slots = data_source.slots(model) - - # Binding performs the extraction of the specified data values - # satisfying the above input slots from the environment data values - # considering what the user specified in terms of names (which is - # in the model input slot info - # The resulting data are the ones picked by name from the - # environment data values, and in the appropriate ordering as - # needed by the input slots. - passed_data_values = _bind_data_values( - environment_data_values, - model.input_slot_info, - in_slots) - - # execute data source, passing only relevant data values. - log.info("Evaluating for Data Source {}".format( - factory.name)) - log.info("Passed values:") - for idx, dv in enumerate(passed_data_values): - log.info("{}: {}".format(idx, dv)) - - try: - res = data_source.run(model, passed_data_values) - except Exception: - log.exception( - "Evaluation could not be performed. " - "Run method raised exception.") - raise - - if not isinstance(res, list): - error_txt = ( - "The run method of data source {} must return a list." - " It returned instead {}. Fix the run() method to return" - " the appropriate entity.".format( - factory.name, - type(res) - )) - log.error(error_txt) - raise RuntimeError(error_txt) - - if len(res) != len(out_slots): - error_txt = ( - "The number of data values ({} values) returned" - " by '{}' does not match the number" - " of output slots it specifies ({} values)." - " This is likely a plugin error.").format( - len(res), factory.name, len(out_slots) - ) - - log.error(error_txt) - raise RuntimeError(error_txt) - - if len(res) != len(model.output_slot_info): - error_txt = ( - "The number of data values ({} values) returned" - " by '{}' does not match the number" - " of user-defined names specified ({} values)." - " This is either a plugin error or a file" - " error.").format( - len(res), - factory.name, - len(model.output_slot_info) - ) - - log.error(error_txt) - raise RuntimeError(error_txt) - - for idx, dv in enumerate(res): - if not isinstance(dv, DataValue): - error_txt = ( - "The result list returned by DataSource {} contains" - " an entry that is not a DataValue. An entry of type" - " {} was instead found in position {}." - " Fix the DataSource.run() method" - " to return the appropriate entity.".format( - factory.name, - type(dv), - idx - ) - ) - log.error(error_txt) - raise RuntimeError(error_txt) - - # At this point, the returned data values are unnamed. - # Add the names as specified by the user. - for dv, output_slot_info in zip(res, model.output_slot_info): - dv.name = output_slot_info.name - - # If the name was not specified, simply discard the value, - # because apparently the user is not interested in it. - res = [r for r in res if r.name != ""] - results.extend(res) - - log.info("Returned values:") - for idx, dv in enumerate(res): - log.info("{}: {}".format(idx, dv)) - - # Finally, return all the computed data values from all evaluators, - # properly named. - return results - - def _get_data_values_from_mco(model, communicator): """Helper method. Receives the data (in order) from the MCO, and bind them to the @@ -255,33 +87,3 @@ def _get_data_values_from_mco(model, communicator): # Exclude those who have no name set. return [dv for dv in mco_data_values if dv.name != ""] - - -def _bind_data_values(available_data_values, - model_slot_map, - slots): - """ - Given the named data values in the environment, the slots a given - data source expects, and the user-specified names for each of these - slots, returns those data values with the requested names, ordered - in the correct order as specified by the slot map. - """ - passed_data_values = [] - lookup_map = {dv.name: dv for dv in available_data_values} - - if len(slots) != len(model_slot_map): - raise RuntimeError("The length of the slots is not equal to" - " the length of the slot map. This may" - " indicate a file error.") - - try: - for slot, slot_map in zip(slots, model_slot_map): - passed_data_values.append(lookup_map[slot_map.name]) - except KeyError: - raise RuntimeError( - "Unable to find requested name '{}' in available " - "data values. Current data value names: {}".format( - slot_map.name, - list(lookup_map.keys()))) - - return passed_data_values diff --git a/force_bdss/data_sources/base_data_source_model.py b/force_bdss/data_sources/base_data_source_model.py index d210f9492fb32518916802200028e273b286fd72..a5df537c9faf7a7c8768dfc7d9e4c5b7a788ad97 100644 --- a/force_bdss/data_sources/base_data_source_model.py +++ b/force_bdss/data_sources/base_data_source_model.py @@ -1,4 +1,6 @@ -from traits.api import ABCHasStrictTraits, Instance, List, Event +from traits.api import ( + ABCHasStrictTraits, Instance, List, Event, on_trait_change +) from force_bdss.core.input_slot_info import InputSlotInfo from force_bdss.core.output_slot_info import OutputSlotInfo @@ -49,3 +51,10 @@ class BaseDataSourceModel(ABCHasStrictTraits): x.__getstate__() for x in self.output_slot_info ] return state + + @on_trait_change("+changes_slots") + def _trigger_changes_slots(self, obj, name, new): + changes_slots = self.traits()[name].changes_slots + + if changes_slots: + self.changes_slots = True diff --git a/force_bdss/data_sources/tests/test_base_data_source_factory.py b/force_bdss/data_sources/tests/test_base_data_source_factory.py index e2aa949db237b64a47a70e5ad4cd746a2642704a..8716ac78e4f0a246ae18b1f2a6d61666ca1a77b1 100644 --- a/force_bdss/data_sources/tests/test_base_data_source_factory.py +++ b/force_bdss/data_sources/tests/test_base_data_source_factory.py @@ -25,6 +25,7 @@ class TestBaseDataSourceFactory(unittest.TestCase): factory = DummyDataSourceFactory(self.plugin) self.assertEqual(factory.id, 'pid.factory.dummy_data_source') self.assertEqual(factory.name, 'Dummy data source') + self.assertEqual(factory.description, u"No description available.") self.assertEqual(factory.model_class, DummyDataSourceModel) self.assertEqual(factory.data_source_class, DummyDataSource) self.assertIsInstance(factory.create_data_source(), DummyDataSource) diff --git a/force_bdss/data_sources/tests/test_base_data_source_model.py b/force_bdss/data_sources/tests/test_base_data_source_model.py index a9be54b713f94d073c2659aa001a5000758b41f9..dde26ea0cd99e43c66747f793842acdb468445b8 100644 --- a/force_bdss/data_sources/tests/test_base_data_source_model.py +++ b/force_bdss/data_sources/tests/test_base_data_source_model.py @@ -1,7 +1,10 @@ import unittest +from traits.api import Int +from traits.testing.api import UnittestTools from force_bdss.core.input_slot_info import InputSlotInfo from force_bdss.core.output_slot_info import OutputSlotInfo +from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel from force_bdss.tests.dummy_classes.data_source import DummyDataSourceModel try: @@ -13,9 +16,18 @@ from force_bdss.data_sources.base_data_source_factory import \ BaseDataSourceFactory -class TestBaseDataSourceModel(unittest.TestCase): +class ChangesSlotsModel(BaseDataSourceModel): + a = Int() + b = Int(changes_slots=True) + c = Int(changes_slots=False) + + +class TestBaseDataSourceModel(unittest.TestCase, UnittestTools): + def setUp(self): + self.mock_factory = mock.Mock(spec=BaseDataSourceFactory) + def test_getstate(self): - model = DummyDataSourceModel(mock.Mock(spec=BaseDataSourceFactory)) + model = DummyDataSourceModel(self.mock_factory) self.assertEqual( model.__getstate__(), { @@ -64,3 +76,15 @@ class TestBaseDataSourceModel(unittest.TestCase): } ] }) + + def test_changes_slots(self): + model = ChangesSlotsModel(self.mock_factory) + + with self.assertTraitDoesNotChange(model, "changes_slots"): + model.a = 5 + + with self.assertTraitChanges(model, "changes_slots"): + model.b = 5 + + with self.assertTraitDoesNotChange(model, "changes_slots"): + model.c = 5 diff --git a/force_bdss/local_traits.py b/force_bdss/local_traits.py index f759aed8596558d9783bae4e0ce1f35671f3059c..12c982542bad386d9874386c3df6f2ec1077ee45 100644 --- a/force_bdss/local_traits.py +++ b/force_bdss/local_traits.py @@ -1,4 +1,4 @@ -from traits.api import Regex, String +from traits.api import Regex, String, BaseInt #: Used for variable names, but allow also empty string as it's the default #: case and it will be present if the workflow is saved before actually @@ -8,3 +8,19 @@ Identifier = Regex(regex="(^[^\d\W]\w*\Z|^\Z)") #: Identifies a CUBA type with its key. At the moment a String with #: no validation, but will come later. CUBAType = String() + + +class PositiveInt(BaseInt): + """A positive integer trait.""" + + info_text = 'a positive integer' + + default_value = 1 + + def validate(self, object, name, value): + int_value = super(PositiveInt, self).validate(object, name, value) + + if int_value > 0: + return int_value + + self.error(object, name, value) diff --git a/force_bdss/mco/parameters/base_mco_parameter_factory.py b/force_bdss/mco/parameters/base_mco_parameter_factory.py index bc3a0fd078c0001f12267f426a22e363231815fa..03598138df2b680a0557a55727eae4acb67994e2 100644 --- a/force_bdss/mco/parameters/base_mco_parameter_factory.py +++ b/force_bdss/mco/parameters/base_mco_parameter_factory.py @@ -1,4 +1,4 @@ -from traits.api import Str, Type, Instance, provides +from traits.api import Type, Instance, provides from force_bdss.core.base_factory import BaseFactory from force_bdss.ids import mco_parameter_id @@ -20,20 +20,12 @@ class BaseMCOParameterFactory(BaseFactory): mco_factory = Instance('force_bdss.mco.base_mco_factory.BaseMCOFactory', allow_none=False) - #: A long description of the parameter - description = Str() - # The model class to instantiate when create_model is called. model_class = Type( "force_bdss.mco.parameters.base_mco_parameter.BaseMCOParameter", allow_none=False ) - def get_description(self): - raise NotImplementedError( - "get_description was not implemented in factory {}".format( - self.__class__)) - def get_model_class(self): raise NotImplementedError( "get_model_class was not implemented in factory {}".format( @@ -46,7 +38,6 @@ class BaseMCOParameterFactory(BaseFactory): *args, **kwargs) - self.description = self.get_description() self.model_class = self.get_model_class() def create_model(self, data_values=None): diff --git a/force_bdss/mco/parameters/tests/test_base_mco_parameter_factory.py b/force_bdss/mco/parameters/tests/test_base_mco_parameter_factory.py index abc6ee6f778eba999888f1bdbbc3b8ada0a704c6..6456ecef5c565b5aa26805a1926ab42f3f33abe2 100644 --- a/force_bdss/mco/parameters/tests/test_base_mco_parameter_factory.py +++ b/force_bdss/mco/parameters/tests/test_base_mco_parameter_factory.py @@ -28,7 +28,7 @@ class TestBaseMCOParameterFactory(unittest.TestCase): factory = DummyMCOParameterFactory(mco_factory=self.mco_factory) self.assertEqual(factory.id, "mcoid.parameter.dummy_mco_parameter") self.assertEqual(factory.name, "Dummy MCO parameter") - self.assertEqual(factory.description, "description") + self.assertEqual(factory.description, u"description") self.assertEqual(factory.model_class, DummyMCOParameter) self.assertIsInstance(factory.create_model(), DummyMCOParameter) diff --git a/force_bdss/mco/tests/test_base_mco_factory.py b/force_bdss/mco/tests/test_base_mco_factory.py index e070daf6841b0e899ce4eef1218c972a6875317a..03982dc5da2bdca09c514504d01781199f170a88 100644 --- a/force_bdss/mco/tests/test_base_mco_factory.py +++ b/force_bdss/mco/tests/test_base_mco_factory.py @@ -2,6 +2,7 @@ import unittest from traits.trait_errors import TraitError +from force_bdss.mco.base_mco_factory import BaseMCOFactory from force_bdss.mco.tests.test_base_mco import DummyMCO from force_bdss.mco.tests.test_base_mco_communicator import \ DummyMCOCommunicator @@ -15,6 +16,23 @@ except ImportError: from envisage.plugin import Plugin +class MCOFactory(BaseMCOFactory): + def get_identifier(self): + return "dummy_mco_2" + + def get_name(self): + return "Dummy MCO 2" + + def get_model_class(self): + return DummyMCOModel + + def get_communicator_class(self): + return DummyMCOCommunicator + + def get_optimizer_class(self): + return DummyMCO + + class TestBaseMCOFactory(unittest.TestCase): def setUp(self): self.plugin = mock.Mock(spec=Plugin, id="pid") @@ -30,6 +48,10 @@ class TestBaseMCOFactory(unittest.TestCase): self.assertIsInstance(factory.create_model(), DummyMCOModel) + def test_base_object_parameter_factories(self): + factory = MCOFactory(self.plugin) + self.assertEqual(factory.parameter_factories(), []) + def test_broken_get_identifier(self): class Broken(DummyMCOFactory): def get_identifier(self): diff --git a/force_bdss/tests/dummy_classes/mco.py b/force_bdss/tests/dummy_classes/mco.py index 289d88efc5d016ae36e9a825edea2ed3b7617bde..e65b182693604d3206c9a4c2cf222ea00e80d56f 100644 --- a/force_bdss/tests/dummy_classes/mco.py +++ b/force_bdss/tests/dummy_classes/mco.py @@ -37,7 +37,7 @@ class DummyMCOParameterFactory(BaseMCOParameterFactory): return "Dummy MCO parameter" def get_description(self): - return "description" + return u"description" def get_model_class(self): return DummyMCOParameter diff --git a/force_bdss/tests/test_bdss_application.py b/force_bdss/tests/test_bdss_application.py index 53e5a5cf2e3c3b82f0d7260d7a47018c2d672c68..b985316a62ff76e462601232f23baffac7bbe6ec 100644 --- a/force_bdss/tests/test_bdss_application.py +++ b/force_bdss/tests/test_bdss_application.py @@ -8,6 +8,8 @@ from force_bdss.bdss_application import ( _load_failure_callback, _import_extensions ) +from force_bdss.core.workflow import Workflow +from force_bdss.tests import fixtures try: import mock @@ -17,9 +19,10 @@ except ImportError: class TestBDSSApplication(unittest.TestCase): def test_initialization(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - app = BDSSApplication(False, "foo/bar") + with testfixtures.LogCapture(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + app = BDSSApplication(False, "foo/bar") self.assertFalse(app.evaluate) self.assertEqual(app.workflow_filepath, "foo/bar") @@ -47,3 +50,18 @@ class TestBDSSApplication(unittest.TestCase): _import_extensions(plugins, ext) self.assertEqual(plugins[0], plugin) + + def test_workflow(self): + with testfixtures.LogCapture(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + app = BDSSApplication(False, fixtures.get("test_empty.json")) + + self.assertIsInstance(app.workflow, Workflow) + + with testfixtures.LogCapture(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + app = BDSSApplication(True, fixtures.get("test_empty.json")) + + self.assertIsInstance(app.workflow, Workflow) diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py index 734a08b173a09e35af26cc4a51d4593ced39600d..3f78414977f7853fab1e760d150f4f4843518251 100644 --- a/force_bdss/tests/test_core_evaluation_driver.py +++ b/force_bdss/tests/test_core_evaluation_driver.py @@ -3,19 +3,11 @@ import unittest import testfixtures import six -from force_bdss.core.execution_layer import ExecutionLayer -from force_bdss.core.kpi_specification import KPISpecification -from force_bdss.core.output_slot_info import OutputSlotInfo -from force_bdss.core.workflow import Workflow from force_bdss.tests.probe_classes.factory_registry_plugin import \ ProbeFactoryRegistryPlugin -from force_bdss.tests.probe_classes.data_source import ProbeDataSourceFactory -from force_bdss.core.input_slot_info import InputSlotInfo from force_bdss.core.data_value import DataValue -from force_bdss.core.slot import Slot from force_bdss.tests import fixtures -from force_bdss.tests.probe_classes.mco import ProbeMCOFactory try: import mock @@ -25,10 +17,7 @@ except ImportError: from envisage.api import Application from force_bdss.core_evaluation_driver import ( - CoreEvaluationDriver, - execute_workflow, - _bind_data_values, - _compute_layer_results + CoreEvaluationDriver ) @@ -131,207 +120,6 @@ class TestCoreEvaluationDriver(unittest.TestCase): " the number of user-defined names"): driver.application_started() - def test_bind_data_values(self): - data_values = [ - DataValue(name="foo"), - DataValue(name="bar"), - DataValue(name="baz") - ] - - slot_map = ( - InputSlotInfo(name="baz"), - InputSlotInfo(name="bar") - ) - - slots = ( - Slot(), - Slot() - ) - - result = _bind_data_values(data_values, slot_map, slots) - self.assertEqual(result[0], data_values[2]) - self.assertEqual(result[1], data_values[1]) - - # Check the errors. Only one slot map for two slots. - slot_map = ( - InputSlotInfo(name="baz"), - ) - - with testfixtures.LogCapture(): - with six.assertRaisesRegex( - self, - RuntimeError, - "The length of the slots is not equal to the length of" - " the slot map"): - _bind_data_values(data_values, slot_map, slots) - - # missing value in the given data values. - slot_map = ( - InputSlotInfo(name="blap"), - InputSlotInfo(name="bar") - ) - - with testfixtures.LogCapture(): - with six.assertRaisesRegex( - self, - RuntimeError, - "Unable to find requested name 'blap' in available" - " data values."): - _bind_data_values(data_values, slot_map, slots) - - def test_compute_layer_results(self): - data_values = [ - DataValue(name="foo"), - DataValue(name="bar"), - DataValue(name="baz"), - DataValue(name="quux") - ] - - def run(self, *args, **kwargs): - return [DataValue(value=1), DataValue(value=2), DataValue(value=3)] - - ds_factory = self.registry.data_source_factories[0] - ds_factory.input_slots_size = 2 - ds_factory.output_slots_size = 3 - ds_factory.run_function = run - evaluator_model = ds_factory.create_model() - - evaluator_model.input_slot_info = [ - InputSlotInfo(name="foo"), - InputSlotInfo(name="quux") - ] - evaluator_model.output_slot_info = [ - OutputSlotInfo(name="one"), - OutputSlotInfo(name=""), - OutputSlotInfo(name="three") - ] - - res = _compute_layer_results( - data_values, - ExecutionLayer(data_sources=[evaluator_model]), - ) - self.assertEqual(len(res), 2) - self.assertEqual(res[0].name, "one") - self.assertEqual(res[0].value, 1) - self.assertEqual(res[1].name, "three") - self.assertEqual(res[1].value, 3) - - def test_multilayer_execution(self): - # The multilayer peforms the following execution - # layer 0: in1 + in2 | in3 + in4 - # res1 res2 - # layer 1: res1 + res2 - # res3 - # layer 2: res3 * res1 - # res4 - # layer 3: res4 * res2 - # out1 - # Final result should be - # out1 = ((in1 + in2 + in3 + in4) * (in1 + in2) * (in3 + in4) - - data_values = [ - DataValue(value=10, name="in1"), - DataValue(value=15, name="in2"), - DataValue(value=3, name="in3"), - DataValue(value=7, name="in4") - ] - - def adder(model, parameters): - - first = parameters[0].value - second = parameters[1].value - return [DataValue(value=(first+second))] - - adder_factory = ProbeDataSourceFactory( - self.plugin, - input_slots_size=2, - output_slots_size=1, - run_function=adder) - - def multiplier(model, parameters): - first = parameters[0].value - second = parameters[1].value - return [DataValue(value=(first*second))] - - multiplier_factory = ProbeDataSourceFactory( - self.plugin, - input_slots_size=2, - output_slots_size=1, - run_function=multiplier) - - mco_factory = ProbeMCOFactory(self.plugin) - mco_model = mco_factory.create_model() - mco_model.kpis = [ - KPISpecification(name="out1") - ] - - wf = Workflow( - mco=mco_model, - execution_layers=[ - ExecutionLayer(), - ExecutionLayer(), - ExecutionLayer(), - ExecutionLayer() - ] - ) - # Layer 0 - model = adder_factory.create_model() - model.input_slot_info = [ - InputSlotInfo(name="in1"), - InputSlotInfo(name="in2") - ] - model.output_slot_info = [ - OutputSlotInfo(name="res1") - ] - wf.execution_layers[0].data_sources.append(model) - - model = adder_factory.create_model() - model.input_slot_info = [ - InputSlotInfo(name="in3"), - InputSlotInfo(name="in4") - ] - model.output_slot_info = [ - OutputSlotInfo(name="res2") - ] - wf.execution_layers[0].data_sources.append(model) - - # layer 1 - model = adder_factory.create_model() - model.input_slot_info = [ - InputSlotInfo(name="res1"), - InputSlotInfo(name="res2") - ] - model.output_slot_info = [ - OutputSlotInfo(name="res3") - ] - wf.execution_layers[1].data_sources.append(model) - - # layer 2 - model = multiplier_factory.create_model() - model.input_slot_info = [ - InputSlotInfo(name="res3"), - InputSlotInfo(name="res1") - ] - model.output_slot_info = [ - OutputSlotInfo(name="res4") - ] - wf.execution_layers[2].data_sources.append(model) - - # layer 3 - model = multiplier_factory.create_model() - model.input_slot_info = [ - InputSlotInfo(name="res4"), - InputSlotInfo(name="res2") - ] - model.output_slot_info = [ - OutputSlotInfo(name="out1") - ] - wf.execution_layers[3].data_sources.append(model) - - kpi_results = execute_workflow(wf, data_values) - self.assertEqual(len(kpi_results), 1) - self.assertEqual(kpi_results[0].value, 8750) - def test_mco_communicator_broken(self): self.registry.mco_factories[0].raises_on_create_communicator = True driver = CoreEvaluationDriver( @@ -367,9 +155,9 @@ class TestCoreEvaluationDriver(unittest.TestCase): 'Creating communicator'), ('force_bdss.core_evaluation_driver', 'INFO', 'Received data from MCO: \n whatever = 1.0 (AVERAGE)'), - ('force_bdss.core_evaluation_driver', 'INFO', + ('force_bdss.core.execution', 'INFO', 'Computing data layer 0'), - ('force_bdss.core_evaluation_driver', 'ERROR', + ('force_bdss.core.execution', 'ERROR', 'Unable to create data source from factory ' "'force.bdss.enthought.plugin.test.v0" ".factory.probe_data_source' in plugin " diff --git a/force_bdss/tests/test_local_traits.py b/force_bdss/tests/test_local_traits.py index d61e6d2c89520653d65431a0030ad89b362f18be..7d0c59c991be882835312c0ce661e24761f7affc 100644 --- a/force_bdss/tests/test_local_traits.py +++ b/force_bdss/tests/test_local_traits.py @@ -1,12 +1,13 @@ import unittest from traits.api import HasStrictTraits, TraitError -from force_bdss.local_traits import Identifier, CUBAType +from force_bdss.local_traits import Identifier, CUBAType, PositiveInt class Traited(HasStrictTraits): val = Identifier() cuba = CUBAType() + positive_int = PositiveInt() class TestLocalTraits(unittest.TestCase): @@ -25,3 +26,13 @@ class TestLocalTraits(unittest.TestCase): c = Traited() c.cuba = "PRESSURE" self.assertEqual(c.cuba, "PRESSURE") + + def test_positive_int(self): + c = Traited() + with self.assertRaises(TraitError): + c.positive_int = 0 + + with self.assertRaises(TraitError): + c.positive_int = -1 + + c.positive_int = 3