Skip to content
Snippets Groups Projects
Commit 2da2873c authored by Stefano Borini's avatar Stefano Borini
Browse files

Testing of the core engine

parent 63573d48
No related branches found
No related tags found
1 merge request!69Introduce slots and resolution of named variables
from traits.api import HasStrictTraits, String from traits.api import HasStrictTraits, String
from ..local_traits import CUBAType
class Slot(HasStrictTraits): class Slot(HasStrictTraits):
"""Describes an input or output slot in the DataSource or """
KPICalculator""" Describes an input or output slot in the DataSource or
KPICalculator. If the DataSource and KPICalculator are functions, slots
define their argument number and types they need as input and what
they return as output.
"""
#: A textual description of the slot #: A textual description of the slot
description = String("No description") description = String("No description")
#: The CUBA key of the slot #: The CUBA key of the slot
type = String() type = CUBAType()
from __future__ import print_function from __future__ import print_function
import sys import sys
import logging
from traits.api import on_trait_change from traits.api import on_trait_change
from .ids import plugin_id from .ids import plugin_id
...@@ -31,87 +33,185 @@ class CoreEvaluationDriver(BaseCoreDriver): ...@@ -31,87 +33,185 @@ class CoreEvaluationDriver(BaseCoreDriver):
mco_bundle = mco_model.bundle mco_bundle = mco_model.bundle
mco_communicator = mco_bundle.create_communicator() mco_communicator = mco_bundle.create_communicator()
# Receives the data from the MCO. These are technically unnamed. mco_data_values = self._get_data_values_from_mco(mco_model,
# The names are then assigned. Order is important mco_communicator)
mco_data_values = mco_communicator.receive_from_mco(mco_model)
if len(mco_data_values) != len(mco_model.parameters): ds_results = self._compute_ds_results(
raise RuntimeError("The number of data values returned by" mco_data_values,
" the MCO does not match the number of" workflow)
" parameters specified. This is likely a"
" MCO plugin error.")
# Assign the name to the data value that was emitted. kpi_results = self._compute_kpi_results(
for dv, param in zip(mco_data_values, mco_model.parameters): ds_results + mco_data_values,
dv.name = param.name workflow)
mco_communicator.send_to_mco(mco_model, kpi_results)
def _compute_ds_results(self, environment_data_values, workflow):
"""Helper routine.
Performs the evaluation of the DataSources, passing the current
environment data values (the MCO data)
"""
ds_results = [] ds_results = []
for ds_model in workflow.data_sources: for ds_model in workflow.data_sources:
ds_bundle = ds_model.bundle ds_bundle = ds_model.bundle
data_source = ds_bundle.create_data_source() data_source = ds_bundle.create_data_source()
# Get the slots for this data source. These must be matched to
# the appropriate values in the environment data values.
# Matching is by position.
in_slots, out_slots = data_source.slots(ds_model) in_slots, out_slots = data_source.slots(ds_model)
# Binding performs the extraction of the specified data values
# satisfying the above input slots from the environment data values
# considering what the user specified in terms of names (which is
# in the model input slot maps.
# The resulting data are the ones picked by name from the
# environment data values, and in the appropriate ordering as
# needed by the input slots.
passed_data_values = self._bind_data_values( passed_data_values = self._bind_data_values(
mco_data_values, environment_data_values,
ds_model.input_slot_maps, ds_model.input_slot_maps,
in_slots) in_slots)
# execute data source, passing only relevant data values.
logging.info("Evaluating for Data Source {}".format(
ds_bundle.name))
res = data_source.run(ds_model, passed_data_values) res = data_source.run(ds_model, passed_data_values)
if len(res) != len(out_slots): if len(res) != len(out_slots):
raise RuntimeError("The number of data values returned by" error_txt = (
" the DataSource does not match the number" "The number of data values ({} values) returned"
" of parameters specified. This is likely a" " by the DataSource '{}' does not match the number"
" DataSource plugin error.") " of output slots it specifies ({} values)."
" This is likely a DataSource plugin error.").format(
len(res), ds_bundle.name, len(out_slots)
)
if len(res) != len(ds_model.output_slot_names): logging.error(error_txt)
raise RuntimeError("The number of data values returned by" raise RuntimeError(error_txt)
" the DataSource does not match the number"
" of names specified. This is either an"
" input file error or a plugin error.")
if len(res) != len(ds_model.output_slot_names):
error_txt = (
"The number of data values ({} values) returned"
" by the DataSource '{}' does not match the number"
" of user-defined names specified ({} values)."
" This is likely a DataSource plugin error.").format(
len(res),
ds_bundle.name,
len(ds_model.output_slot_names)
)
logging.error(error_txt)
raise RuntimeError(error_txt)
# At this point, the returned data values are unnamed.
# Add the names as specified by the user.
for dv, output_slot_name in zip(res, ds_model.output_slot_names): for dv, output_slot_name in zip(res, ds_model.output_slot_names):
dv.name = output_slot_name dv.name = output_slot_name
ds_results.extend(res) ds_results.extend(res)
# Finally, return all the computed data values from all data sources,
# properly named.
return ds_results
def _compute_kpi_results(self, environment_data_values, workflow):
"""Perform evaluation of all KPI calculators.
environment_data_values contains all data values provided from
the MCO and data sources.
"""
kpi_results = [] kpi_results = []
for kpic_model in workflow.kpi_calculators: for kpic_model in workflow.kpi_calculators:
kpic_bundle = kpic_model.bundle kpic_bundle = kpic_model.bundle
kpi_calculator = kpic_bundle.create_kpi_calculator() kpi_calculator = kpic_bundle.create_kpi_calculator()
in_slots, out_slots = kpi_calculator.slots(kpic_model) in_slots, out_slots = kpi_calculator.slots(kpic_model)
passed_data_values = self._bind_data_values( passed_data_values = self._bind_data_values(
mco_data_values+ds_results, environment_data_values,
kpic_model.input_slot_maps, kpic_model.input_slot_maps,
in_slots) in_slots)
logging.info("Evaluating for KPICalculator {}".format(
kpic_bundle.name))
res = kpi_calculator.run(kpic_model, passed_data_values) res = kpi_calculator.run(kpic_model, passed_data_values)
if len(res) != len(out_slots): if len(res) != len(out_slots):
raise RuntimeError("The number of data values returned by" error_txt = (
" the KPICalculator does not match the" "The number of data values ({} values) returned by"
" number of parameters specified. This is" " the KPICalculator '{}' does not match the"
" likely a KPICalculator plugin error.") " number of output slots ({} values). This is"
" likely a KPICalculator plugin error."
).format(len(res), kpic_bundle.name, len(out_slots))
logging.error(error_txt)
raise RuntimeError(error_txt)
if len(res) != len(kpic_model.output_slot_names): if len(res) != len(kpic_model.output_slot_names):
raise RuntimeError("The number of data values returned by" error_txt = (
" the KPICalculator does not match the" "The number of data values ({} values) returned by"
" number of names specified. This is" " the KPICalculator '{}' does not match the"
" either an input file error or a plugin" " number of user-defined names specified ({} values)."
" error.") " This is either an input file error or a plugin"
" error."
for kpi, output_slot_name in zip(res, ).format(len(res), kpic_bundle.name,
kpic_model.output_slot_names): len(kpic_model.output_slot_names))
logging.error(error_txt)
raise RuntimeError(error_txt)
for kpi, output_slot_name in zip(
res, kpic_model.output_slot_names):
kpi.name = output_slot_name kpi.name = output_slot_name
kpi_results.extend(res) kpi_results.extend(res)
mco_communicator.send_to_mco(mco_model, kpi_results) return kpi_results
def _get_data_values_from_mco(self, model, communicator):
"""Helper method.
Receives the data (in order) from the MCO, and bind them to the
specified names as from the model.
Parameters
----------
model: BaseMCOModel
the MCO model (where the user-defined variable names are specified)
communicator: BaseMCOCommunicator
The communicator that produces the (temporarily unnamed) datavalues
from the MCO.
"""
mco_data_values = communicator.receive_from_mco(model)
if len(mco_data_values) != len(model.parameters):
error_txt = ("The number of data values returned by"
" the MCO ({} values) does not match the"
" number of parameters specified ({} values)."
" This is either a MCO plugin error or the workflow"
" file is corrupted.").format(
len(mco_data_values), len(model.parameters)
)
logging.error(error_txt)
raise RuntimeError(error_txt)
# The data values obtained by the communicator are unnamed.
# Assign the name to each datavalue as specified by the user.
for dv, param in zip(mco_data_values, model.parameters):
dv.name = param.name
return mco_data_values
def _bind_data_values(self, def _bind_data_values(self,
available_data_values, available_data_values,
model_slot_map, model_slot_map,
slots): slots):
"""
Given the named data values in the environment, the slots a given
data source expects, and the user-specified names for each of these
slots, returns those data values with the requested names, ordered
in the correct order as specified by the slot map.
"""
passed_data_values = [] passed_data_values = []
lookup_map = {dv.name: dv for dv in available_data_values} lookup_map = {dv.name: dv for dv in available_data_values}
......
from traits.api import Regex from traits.api import Regex, String
#: Used for variable names, but allow also empty string as it's the default #: Used for variable names, but allow also empty string as it's the default
#: case and it will be present if the workflow is saved before actually #: case and it will be present if the workflow is saved before actually
#: specifying the value. #: specifying the value.
Identifier = Regex(regex="(^[^\d\W]\w*\Z|^\Z)") Identifier = Regex(regex="(^[^\d\W]\w*\Z|^\Z)")
#: Identifies a CUBA type with its key. At the moment a String with
#: no validation, but will come later.
CUBAType = String()
{
"version": "1",
"workflow": {
"mco": {
"id": "force.bdss.enthought.bundle.null_mco",
"model_data": {
"parameters" : [
]
}
},
"data_sources": [
{
"id": "force.bdss.enthought.bundle.null_ds",
"model_data": {
"input_slot_maps": [
],
"output_slot_names": [
]
}
}
],
"kpi_calculators": [
{
"id": "force.bdss.enthought.bundle.null_kpic",
"model_data": {
"input_slot_maps": [
],
"output_slot_names": [
]
}
}
]
}
}
import unittest import unittest
from traits.api import Float from traits.api import Float, List
from force_bdss.bundle_registry_plugin import BundleRegistryPlugin from force_bdss.bundle_registry_plugin import BundleRegistryPlugin
from force_bdss.core.data_value import DataValue
from force_bdss.core.slot import Slot
from force_bdss.data_sources.base_data_source import BaseDataSource from force_bdss.data_sources.base_data_source import BaseDataSource
from force_bdss.data_sources.base_data_source_bundle import \ from force_bdss.data_sources.base_data_source_bundle import \
BaseDataSourceBundle BaseDataSourceBundle
...@@ -37,15 +39,15 @@ class NullMCO(BaseMCO): ...@@ -37,15 +39,15 @@ class NullMCO(BaseMCO):
pass pass
class NullParameter(BaseMCOParameter): class RangedParameter(BaseMCOParameter):
initial_value = Float() initial_value = Float()
lower_bound = Float() lower_bound = Float()
upper_bound = Float() upper_bound = Float()
class NullParameterFactory(BaseMCOParameterFactory): class RangedParameterFactory(BaseMCOParameterFactory):
id = mco_parameter_id("enthought", "dummy_dakota", "ranged") id = mco_parameter_id("enthought", "null_mco", "null")
model_class = NullParameter model_class = RangedParameter
class NullMCOCommunicator(BaseMCOCommunicator): class NullMCOCommunicator(BaseMCOCommunicator):
...@@ -56,11 +58,23 @@ class NullMCOCommunicator(BaseMCOCommunicator): ...@@ -56,11 +58,23 @@ class NullMCOCommunicator(BaseMCOCommunicator):
return [] return []
class OneDataValueMCOCommunicator(BaseMCOCommunicator):
"""A communicator that returns one single datavalue, for testing purposes.
"""
def send_to_mco(self, model, kpi_results):
pass
def receive_from_mco(self, model):
return [
DataValue()
]
class NullMCOBundle(BaseMCOBundle): class NullMCOBundle(BaseMCOBundle):
id = bundle_id("enthought", "dummy_dakota") id = bundle_id("enthought", "null_mco")
def create_model(self, model_data=None): def create_model(self, model_data=None):
return NullMCOModel(self) return NullMCOModel(self, **model_data)
def create_communicator(self): def create_communicator(self):
return NullMCOCommunicator(self) return NullMCOCommunicator(self)
...@@ -69,7 +83,7 @@ class NullMCOBundle(BaseMCOBundle): ...@@ -69,7 +83,7 @@ class NullMCOBundle(BaseMCOBundle):
return NullMCO(self) return NullMCO(self)
def parameter_factories(self): def parameter_factories(self):
return [NullParameterFactory(self)] return []
class NullKPICalculatorModel(BaseKPICalculatorModel): class NullKPICalculatorModel(BaseKPICalculatorModel):
...@@ -84,7 +98,26 @@ class NullKPICalculator(BaseKPICalculator): ...@@ -84,7 +98,26 @@ class NullKPICalculator(BaseKPICalculator):
return (), () return (), ()
class BrokenOneValueKPICalculator(BaseKPICalculator):
def run(self, model, data_source_results):
return [DataValue()]
def slots(self, model):
return (), ()
class OneValueKPICalculator(BaseKPICalculator):
def run(self, model, data_source_results):
return [DataValue()]
def slots(self, model):
return (), (Slot(), )
class NullKPICalculatorBundle(BaseKPICalculatorBundle): class NullKPICalculatorBundle(BaseKPICalculatorBundle):
id = bundle_id("enthought", "null_kpic")
name = "null_kpic"
def create_model(self, model_data=None): def create_model(self, model_data=None):
return NullKPICalculatorModel(self) return NullKPICalculatorModel(self)
...@@ -104,7 +137,32 @@ class NullDataSource(BaseDataSource): ...@@ -104,7 +137,32 @@ class NullDataSource(BaseDataSource):
return (), () return (), ()
class BrokenOneValueDataSource(BaseDataSource):
"""Incorrect data source implementation whose run returns a data value
but no slot was specified for it."""
def run(self, model, parameters):
return [DataValue()]
def slots(self, model):
return (), ()
class OneValueDataSource(BaseDataSource):
"""Incorrect data source implementation whose run returns a data value
but no slot was specified for it."""
def run(self, model, parameters):
return [DataValue()]
def slots(self, model):
return (), (
Slot(),
)
class NullDataSourceBundle(BaseDataSourceBundle): class NullDataSourceBundle(BaseDataSourceBundle):
id = bundle_id("enthought", "null_ds")
name = "null_ds"
def create_model(self, model_data=None): def create_model(self, model_data=None):
return NullDataSourceModel(self) return NullDataSourceModel(self)
...@@ -112,16 +170,20 @@ class NullDataSourceBundle(BaseDataSourceBundle): ...@@ -112,16 +170,20 @@ class NullDataSourceBundle(BaseDataSourceBundle):
return NullDataSource(self) return NullDataSource(self)
class DummyBundleRegistryPlugin(BundleRegistryPlugin):
mco_bundles = List()
kpi_calculator_bundles = List()
data_source_bundles = List()
def mock_bundle_registry_plugin(): def mock_bundle_registry_plugin():
bundle_registry_plugin = mock.Mock(spec=BundleRegistryPlugin) bundle_registry_plugin = DummyBundleRegistryPlugin()
bundle_registry_plugin.mco_bundles = [ bundle_registry_plugin.mco_bundles = [
NullMCOBundle(bundle_registry_plugin)] NullMCOBundle(bundle_registry_plugin)]
bundle_registry_plugin.mco_bundle_by_id = mock.Mock( bundle_registry_plugin.kpi_calculator_bundles = [
return_value=NullMCOBundle(bundle_registry_plugin)) NullKPICalculatorBundle(bundle_registry_plugin)]
bundle_registry_plugin.kpi_calculator_bundle_by_id = mock.Mock( bundle_registry_plugin.data_source_bundles = [
return_value=NullKPICalculatorBundle(bundle_registry_plugin)) NullDataSourceBundle(bundle_registry_plugin)]
bundle_registry_plugin.data_source_bundle_by_id = mock.Mock(
return_value=NullDataSourceBundle(bundle_registry_plugin))
return bundle_registry_plugin return bundle_registry_plugin
...@@ -132,7 +194,7 @@ class TestCoreEvaluationDriver(unittest.TestCase): ...@@ -132,7 +194,7 @@ class TestCoreEvaluationDriver(unittest.TestCase):
application.get_plugin = mock.Mock( application.get_plugin = mock.Mock(
return_value=self.mock_bundle_registry_plugin return_value=self.mock_bundle_registry_plugin
) )
application.workflow_filepath = fixtures.get("test_csv.json") application.workflow_filepath = fixtures.get("test_null.json")
self.mock_application = application self.mock_application = application
def test_initialization(self): def test_initialization(self):
...@@ -140,3 +202,77 @@ class TestCoreEvaluationDriver(unittest.TestCase): ...@@ -140,3 +202,77 @@ class TestCoreEvaluationDriver(unittest.TestCase):
application=self.mock_application, application=self.mock_application,
) )
driver.application_started() driver.application_started()
def test_error_for_non_matching_mco_parameters(self):
bundle = self.mock_bundle_registry_plugin.mco_bundles[0]
with mock.patch.object(bundle.__class__,
"create_communicator") as create_comm:
create_comm.return_value = OneDataValueMCOCommunicator(
bundle)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values returned by the MCO"):
driver.application_started()
def test_error_for_incorrect_output_slots(self):
bundle = self.mock_bundle_registry_plugin.data_source_bundles[0]
with mock.patch.object(bundle.__class__,
"create_data_source") as create_ds:
create_ds.return_value = BrokenOneValueDataSource(bundle)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by the DataSource 'null_ds' does not match"
" the number of output slots"):
driver.application_started()
def test_error_for_missing_ds_output_names(self):
bundle = self.mock_bundle_registry_plugin.data_source_bundles[0]
with mock.patch.object(bundle.__class__,
"create_data_source") as create_ds:
create_ds.return_value = OneValueDataSource(bundle)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by the DataSource 'null_ds' does not match"
" the number of user-defined names"):
driver.application_started()
def test_error_for_incorrect_kpic_output_slots(self):
bundle = self.mock_bundle_registry_plugin.kpi_calculator_bundles[0]
with mock.patch.object(bundle.__class__,
"create_kpi_calculator") as create_kpic:
create_kpic.return_value = BrokenOneValueKPICalculator(bundle)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by the KPICalculator 'null_kpic' does not match"
" the number of output slots"):
driver.application_started()
def test_error_for_missing_kpic_output_names(self):
bundle = self.mock_bundle_registry_plugin.kpi_calculator_bundles[0]
with mock.patch.object(bundle.__class__,
"create_kpi_calculator") as create_kpic:
create_kpic.return_value = OneValueKPICalculator(bundle)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by the KPICalculator 'null_kpic' does not match"
" the number of user-defined names"):
driver.application_started()
import unittest import unittest
from traits.api import HasStrictTraits, TraitError from traits.api import HasStrictTraits, TraitError
from force_bdss.local_traits import Identifier from force_bdss.local_traits import Identifier, CUBAType
class Traited(HasStrictTraits): class Traited(HasStrictTraits):
val = Identifier() val = Identifier()
cuba = CUBAType()
class TestLocalTraits(unittest.TestCase): class TestLocalTraits(unittest.TestCase):
...@@ -19,3 +20,8 @@ class TestLocalTraits(unittest.TestCase): ...@@ -19,3 +20,8 @@ class TestLocalTraits(unittest.TestCase):
for broken in ["0", None, 123, "0hello", "hi$", "hi%"]: for broken in ["0", None, 123, "0hello", "hi$", "hi%"]:
with self.assertRaises(TraitError): with self.assertRaises(TraitError):
c.val = broken c.val = broken
def test_cuba_type(self):
c = Traited()
c.cuba = "PRESSURE"
self.assertEqual(c.cuba, "PRESSURE")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment