Commit 8cf53e00 authored by Stefano Borini

Extracted methods as functions and tested their behavior with the new change

parent 993c115e
1 merge request: !78 Discard data value for empty name.
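The behavioural change behind this merge request is that data values left without a name are discarded instead of being passed along. A minimal sketch of that filtering (illustrative only, not part of the diff; it assumes DataValue defaults to an empty name, which the tests further down rely on):

from force_bdss.core.data_value import DataValue

values = [DataValue(name="pressure", value=1.2), DataValue(value=3.4)]
# A value without a user-assigned name is treated as uninteresting and dropped,
# mirroring the list comprehensions introduced in this commit.
named_only = [dv for dv in values if dv.name != ""]
assert [dv.name for dv in named_only] == ["pressure"]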
@@ -33,16 +33,16 @@ class CoreEvaluationDriver(BaseCoreDriver):
        mco_factory = mco_model.factory
        mco_communicator = mco_factory.create_communicator()

        mco_data_values = _get_data_values_from_mco(
            mco_model, mco_communicator)

        ds_results = _compute_layer_results(
            mco_data_values,
            workflow.data_sources,
            "create_data_source"
        )

        kpi_results = _compute_layer_results(
            ds_results + mco_data_values,
            workflow.kpi_calculators,
            "create_kpi_calculator"
@@ -50,153 +50,162 @@ class CoreEvaluationDriver(BaseCoreDriver):
        mco_communicator.send_to_mco(mco_model, kpi_results)


def _compute_layer_results(environment_data_values,
                           evaluator_models,
                           creator_method_name
                           ):
    """Helper routine.
    Performs the evaluation of a single layer.
    At the moment we have a single layer of DataSources followed
    by a single layer of KPI calculators.

    Parameters
    ----------
    environment_data_values: list
        A list of data values to submit to the evaluators.

    evaluator_models: list
        A list of the models for all the evaluators (data source
        or kpi calculator)

    creator_method_name: str
        A string of the creator method for the evaluator on the
        factory (e.g. create_kpi_calculator)

    NOTE: The above parameter is going to go away as soon as we move
    to unlimited layers and remove the distinction between data sources
    and KPI calculators.
    """
    results = []
    for model in evaluator_models:
        factory = model.factory
        evaluator = getattr(factory, creator_method_name)()

        # Get the slots for this data source. These must be matched to
        # the appropriate values in the environment data values.
        # Matching is by position.
        in_slots, out_slots = evaluator.slots(model)

        # Binding performs the extraction of the specified data values
        # satisfying the above input slots from the environment data values,
        # considering what the user specified in terms of names (which is
        # in the model input slot maps).
        # The resulting data are the ones picked by name from the
        # environment data values, and in the appropriate ordering as
        # needed by the input slots.
        passed_data_values = _bind_data_values(
            environment_data_values,
            model.input_slot_maps,
            in_slots)

        # Execute the data source, passing only relevant data values.
        logging.info("Evaluating for Data Source {}".format(
            factory.name))
        res = evaluator.run(model, passed_data_values)

        if len(res) != len(out_slots):
            error_txt = (
                "The number of data values ({} values) returned"
                " by '{}' does not match the number"
                " of output slots it specifies ({} values)."
                " This is likely a plugin error.").format(
                len(res), factory.name, len(out_slots)
            )
            logging.error(error_txt)
            raise RuntimeError(error_txt)

        if len(res) != len(model.output_slot_names):
            error_txt = (
                "The number of data values ({} values) returned"
                " by '{}' does not match the number"
                " of user-defined names specified ({} values)."
                " This is either a plugin error or a file"
                " error.").format(
                len(res),
                factory.name,
                len(model.output_slot_names)
            )
            logging.error(error_txt)
            raise RuntimeError(error_txt)

        # At this point, the returned data values are unnamed.
        # Add the names as specified by the user.
        for dv, output_slot_name in zip(res, model.output_slot_names):
            dv.name = output_slot_name

        # If the name was not specified, simply discard the value,
        # because apparently the user is not interested in it.
        results.extend([r for r in res if r.name != ""])

    # Finally, return all the computed data values from all evaluators,
    # properly named.
    return results


def _get_data_values_from_mco(model, communicator):
    """Helper routine.
    Receives the data (in order) from the MCO, and binds them to the
    names specified in the model.

    Parameters
    ----------
    model: BaseMCOModel
        The MCO model (where the user-defined variable names are specified)

    communicator: BaseMCOCommunicator
        The communicator that produces the (temporarily unnamed) datavalues
        from the MCO.
    """
    mco_data_values = communicator.receive_from_mco(model)

    if len(mco_data_values) != len(model.parameters):
        error_txt = ("The number of data values returned by"
                     " the MCO ({} values) does not match the"
                     " number of parameters specified ({} values)."
                     " This is either a MCO plugin error or the workflow"
                     " file is corrupted.").format(
            len(mco_data_values), len(model.parameters)
        )
        logging.error(error_txt)
        raise RuntimeError(error_txt)

    # The data values obtained by the communicator are unnamed.
    # Assign the name to each datavalue as specified by the user.
    for dv, param in zip(mco_data_values, model.parameters):
        dv.name = param.name

    # Exclude those which have no name set.
    return [dv for dv in mco_data_values if dv.name != ""]


def _bind_data_values(available_data_values,
                      model_slot_map,
                      slots):
    """
    Given the named data values in the environment, the slots a given
    data source expects, and the user-specified names for each of these
    slots, returns those data values with the requested names, ordered
    as specified by the slot map.
    """
    passed_data_values = []
    lookup_map = {dv.name: dv for dv in available_data_values}

    if len(slots) != len(model_slot_map):
        raise RuntimeError("The length of the slots is not equal to"
                           " the length of the slot map. This may"
                           " indicate a file error.")

    try:
        for slot, slot_map in zip(slots, model_slot_map):
            passed_data_values.append(lookup_map[slot_map.name])
    except KeyError:
        raise RuntimeError(
            "Unable to find requested name '{}' in available "
            "data values. Current data value names: {}".format(
                slot_map.name,
                list(lookup_map.keys())))

    return passed_data_values
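Because the helpers are now module-level functions rather than methods, they can be imported and exercised directly, which is what the updated tests below do. A small usage sketch of _bind_data_values (illustrative only; the DataValue, InputSlotMap and Slot constructors are used exactly as in the tests below):

from force_bdss.core.data_value import DataValue
from force_bdss.core.input_slot_map import InputSlotMap
from force_bdss.core.slot import Slot
from force_bdss.core_evaluation_driver import _bind_data_values

available = [DataValue(name="foo", value=10), DataValue(name="bar", value=20)]
# The returned values follow the order of the slot map, not the order of
# the available data values.
bound = _bind_data_values(available,
                          [InputSlotMap(name="bar"), InputSlotMap(name="foo")],
                          (Slot(), Slot()))
assert [dv.value for dv in bound] == [20, 10]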
import unittest

from traits.api import Float, List

from force_bdss.core.input_slot_map import InputSlotMap
from force_bdss.factory_registry_plugin import FactoryRegistryPlugin
from force_bdss.core.data_value import DataValue
from force_bdss.core.slot import Slot

@@ -27,7 +30,8 @@ except ImportError:
from envisage.api import Application

from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \
    _bind_data_values, _compute_layer_results
class NullMCOModel(BaseMCOModel):

@@ -159,6 +163,19 @@ class OneValueDataSource(BaseDataSource):
        )
class TwoInputsThreeOutputsDataSource(BaseDataSource):
    """Data source with two input slots and three output slots, used to
    check that output values whose user-defined name is left empty are
    discarded."""
    def run(self, model, parameters):
        return [DataValue(value=1), DataValue(value=2), DataValue(value=3)]

    def slots(self, model):
        return (
            (Slot(), Slot()),
            (Slot(), Slot(), Slot())
        )
class NullDataSourceFactory(BaseDataSourceFactory):
    id = factory_id("enthought", "null_ds")
    name = "null_ds"
@@ -276,3 +293,84 @@ class TestCoreEvaluationDriver(unittest.TestCase):
                " returned by 'null_kpic' does not match"
                " the number of user-defined names"):
            driver.application_started()
    def test_bind_data_values(self):
        data_values = [
            DataValue(name="foo"),
            DataValue(name="bar"),
            DataValue(name="baz")
        ]

        slot_map = (
            InputSlotMap(name="baz"),
            InputSlotMap(name="bar")
        )

        slots = (
            Slot(),
            Slot()
        )

        result = _bind_data_values(data_values, slot_map, slots)
        self.assertEqual(result[0], data_values[2])
        self.assertEqual(result[1], data_values[1])

        # Check the errors. Only one slot map for two slots.
        slot_map = (
            InputSlotMap(name="baz"),
        )

        with self.assertRaisesRegexp(
                RuntimeError,
                "The length of the slots is not equal to the length of"
                " the slot map"):
            _bind_data_values(data_values, slot_map, slots)

        # Missing value in the given data values.
        slot_map = (
            InputSlotMap(name="blap"),
            InputSlotMap(name="bar")
        )

        with self.assertRaisesRegexp(
                RuntimeError,
                "Unable to find requested name 'blap' in available"
                " data values."):
            _bind_data_values(data_values, slot_map, slots)

    def test_compute_layer_results(self):
        data_values = [
            DataValue(name="foo"),
            DataValue(name="bar"),
            DataValue(name="baz"),
            DataValue(name="quux")
        ]

        mock_ds_factory = mock.Mock(spec=BaseDataSourceFactory)
        mock_ds_factory.name = "mock factory"
        mock_ds_factory.create_data_source.return_value = \
            TwoInputsThreeOutputsDataSource(mock_ds_factory)
        evaluator_model = NullDataSourceModel(factory=mock_ds_factory)
        evaluator_model.input_slot_maps = [
            InputSlotMap(name="foo"),
            InputSlotMap(name="quux")
        ]
        evaluator_model.output_slot_names = ["one", "", "three"]

        res = _compute_layer_results(
            data_values,
            [evaluator_model],
            "create_data_source"
        )
        self.assertEqual(len(res), 2)
        self.assertEqual(res[0].name, "one")
        self.assertEqual(res[0].value, 1)
        self.assertEqual(res[1].name, "three")
        self.assertEqual(res[1].value, 3)

    def test_empty_slot_name_skips_data_value(self):
        """Checks if leaving a slot name empty will skip the data value
        in the final output.
        """