Skip to content
Snippets Groups Projects
Commit a95aa5e1 authored by Stefano Borini's avatar Stefano Borini
Browse files

Added stricter controls over the returned values by the datasource

parent 1acbe751
No related branches found
No related tags found
1 merge request!155Added stricter controls over the returned values by the datasource.
...@@ -3,6 +3,7 @@ import logging ...@@ -3,6 +3,7 @@ import logging
from traits.api import on_trait_change from traits.api import on_trait_change
from force_bdss.core.data_value import DataValue
from force_bdss.ids import InternalPluginID from force_bdss.ids import InternalPluginID
from .base_core_driver import BaseCoreDriver from .base_core_driver import BaseCoreDriver
...@@ -142,6 +143,17 @@ def _compute_layer_results(environment_data_values, ...@@ -142,6 +143,17 @@ def _compute_layer_results(environment_data_values,
"Run method raised exception.") "Run method raised exception.")
raise raise
if not isinstance(res, list):
error_txt = (
"The run method of data source {} must return a list."
" It returned instead {}. Fix the run() method to return"
" the appropriate entity.".format(
factory.name,
type(res)
))
log.error(error_txt)
raise RuntimeError(error_txt)
if len(res) != len(out_slots): if len(res) != len(out_slots):
error_txt = ( error_txt = (
"The number of data values ({} values) returned" "The number of data values ({} values) returned"
...@@ -169,6 +181,23 @@ def _compute_layer_results(environment_data_values, ...@@ -169,6 +181,23 @@ def _compute_layer_results(environment_data_values,
log.error(error_txt) log.error(error_txt)
raise RuntimeError(error_txt) raise RuntimeError(error_txt)
for idx, dv in enumerate(res):
if not isinstance(dv, DataValue):
error_txt = (
"The result list returned by DataSource {} contains"
" an entry that is not a DataValue. An entry of type"
" {} was instead found in position {}."
" Fix the DataSource.run() method"
" to return the appropriate entity.".format(
factory.name,
type(dv),
idx
)
)
log.error(error_txt)
raise RuntimeError(error_txt)
# At this point, the returned data values are unnamed. # At this point, the returned data values are unnamed.
# Add the names as specified by the user. # Add the names as specified by the user.
for dv, output_slot_info in zip(res, model.output_slot_info): for dv, output_slot_info in zip(res, model.output_slot_info):
......
...@@ -77,6 +77,40 @@ class TestCoreEvaluationDriver(unittest.TestCase): ...@@ -77,6 +77,40 @@ class TestCoreEvaluationDriver(unittest.TestCase):
" the number of output slots"): " the number of output slots"):
driver.application_started() driver.application_started()
def test_error_for_incorrect_return_type(self):
def run(self, *args, **kwargs):
return "hello"
ds_factory = self.registry.data_source_factories[0]
ds_factory.run_function = run
driver = CoreEvaluationDriver(application=self.mock_application)
with testfixtures.LogCapture():
with six.assertRaisesRegex(
self,
RuntimeError,
"The run method of data source test_data_source must"
" return a list. It returned instead <type 'str'>. Fix"
" the run\(\) method to return the appropriate entity."):
driver.application_started()
def test_error_for_incorrect_data_value_entries(self):
def run(self, *args, **kwargs):
return ["hello"]
ds_factory = self.registry.data_source_factories[0]
ds_factory.run_function = run
driver = CoreEvaluationDriver(application=self.mock_application)
with testfixtures.LogCapture():
with six.assertRaisesRegex(
self,
RuntimeError,
"The result list returned by DataSource test_data_source"
" contains an entry that is not a DataValue."
" An entry of type <type 'str'> was instead found"
" in position 0."
" Fix the DataSource.run\(\) method to"
" return the appropriate entity."
):
driver.application_started()
def test_error_for_missing_ds_output_names(self): def test_error_for_missing_ds_output_names(self):
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment