From a95aa5e1f640ae7292c262804a900a4af66e8a78 Mon Sep 17 00:00:00 2001 From: Stefano Borini <sborini@enthought.com> Date: Wed, 20 Jun 2018 10:55:58 +0100 Subject: [PATCH] Added stricter controls over the returned values by the datasource --- force_bdss/core_evaluation_driver.py | 29 ++++++++++++++++ .../tests/test_core_evaluation_driver.py | 34 +++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/force_bdss/core_evaluation_driver.py b/force_bdss/core_evaluation_driver.py index e89f724..1aae1ca 100644 --- a/force_bdss/core_evaluation_driver.py +++ b/force_bdss/core_evaluation_driver.py @@ -3,6 +3,7 @@ import logging from traits.api import on_trait_change +from force_bdss.core.data_value import DataValue from force_bdss.ids import InternalPluginID from .base_core_driver import BaseCoreDriver @@ -142,6 +143,17 @@ def _compute_layer_results(environment_data_values, "Run method raised exception.") raise + if not isinstance(res, list): + error_txt = ( + "The run method of data source {} must return a list." + " It returned instead {}. Fix the run() method to return" + " the appropriate entity.".format( + factory.name, + type(res) + )) + log.error(error_txt) + raise RuntimeError(error_txt) + if len(res) != len(out_slots): error_txt = ( "The number of data values ({} values) returned" @@ -169,6 +181,23 @@ def _compute_layer_results(environment_data_values, log.error(error_txt) raise RuntimeError(error_txt) + for idx, dv in enumerate(res): + if not isinstance(dv, DataValue): + error_txt = ( + "The result list returned by DataSource {} contains" + " an entry that is not a DataValue. An entry of type" + " {} was instead found in position {}." + " Fix the DataSource.run() method" + " to return the appropriate entity.".format( + factory.name, + type(dv), + idx + ) + ) + log.error(error_txt) + raise RuntimeError(error_txt) + + # At this point, the returned data values are unnamed. # Add the names as specified by the user. for dv, output_slot_info in zip(res, model.output_slot_info): diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py index 172ceaa..1b09402 100644 --- a/force_bdss/tests/test_core_evaluation_driver.py +++ b/force_bdss/tests/test_core_evaluation_driver.py @@ -77,6 +77,40 @@ class TestCoreEvaluationDriver(unittest.TestCase): " the number of output slots"): driver.application_started() + def test_error_for_incorrect_return_type(self): + def run(self, *args, **kwargs): + return "hello" + ds_factory = self.registry.data_source_factories[0] + ds_factory.run_function = run + driver = CoreEvaluationDriver(application=self.mock_application) + with testfixtures.LogCapture(): + with six.assertRaisesRegex( + self, + RuntimeError, + "The run method of data source test_data_source must" + " return a list. It returned instead <type 'str'>. Fix" + " the run\(\) method to return the appropriate entity."): + driver.application_started() + + def test_error_for_incorrect_data_value_entries(self): + def run(self, *args, **kwargs): + return ["hello"] + ds_factory = self.registry.data_source_factories[0] + ds_factory.run_function = run + driver = CoreEvaluationDriver(application=self.mock_application) + with testfixtures.LogCapture(): + with six.assertRaisesRegex( + self, + RuntimeError, + "The result list returned by DataSource test_data_source" + " contains an entry that is not a DataValue." + " An entry of type <type 'str'> was instead found" + " in position 0." + " Fix the DataSource.run\(\) method to" + " return the appropriate entity." + ): + driver.application_started() + def test_error_for_missing_ds_output_names(self): def run(self, *args, **kwargs): -- GitLab