diff --git a/force_bdss/core_evaluation_driver.py b/force_bdss/core_evaluation_driver.py index e89f72446a5f4d9e995213b6bc6ed05544ed4084..1aae1ca29496be8d333a813f2275f1775b68ac28 100644 --- a/force_bdss/core_evaluation_driver.py +++ b/force_bdss/core_evaluation_driver.py @@ -3,6 +3,7 @@ import logging from traits.api import on_trait_change +from force_bdss.core.data_value import DataValue from force_bdss.ids import InternalPluginID from .base_core_driver import BaseCoreDriver @@ -142,6 +143,17 @@ def _compute_layer_results(environment_data_values, "Run method raised exception.") raise + if not isinstance(res, list): + error_txt = ( + "The run method of data source {} must return a list." + " It returned instead {}. Fix the run() method to return" + " the appropriate entity.".format( + factory.name, + type(res) + )) + log.error(error_txt) + raise RuntimeError(error_txt) + if len(res) != len(out_slots): error_txt = ( "The number of data values ({} values) returned" @@ -169,6 +181,23 @@ def _compute_layer_results(environment_data_values, log.error(error_txt) raise RuntimeError(error_txt) + for idx, dv in enumerate(res): + if not isinstance(dv, DataValue): + error_txt = ( + "The result list returned by DataSource {} contains" + " an entry that is not a DataValue. An entry of type" + " {} was instead found in position {}." + " Fix the DataSource.run() method" + " to return the appropriate entity.".format( + factory.name, + type(dv), + idx + ) + ) + log.error(error_txt) + raise RuntimeError(error_txt) + + # At this point, the returned data values are unnamed. # Add the names as specified by the user. for dv, output_slot_info in zip(res, model.output_slot_info): diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py index 172ceaaf381c60c3263679e3bc733b68a2129c58..1b09402d1643b117e7caa69d7da94d627a2edca2 100644 --- a/force_bdss/tests/test_core_evaluation_driver.py +++ b/force_bdss/tests/test_core_evaluation_driver.py @@ -77,6 +77,40 @@ class TestCoreEvaluationDriver(unittest.TestCase): " the number of output slots"): driver.application_started() + def test_error_for_incorrect_return_type(self): + def run(self, *args, **kwargs): + return "hello" + ds_factory = self.registry.data_source_factories[0] + ds_factory.run_function = run + driver = CoreEvaluationDriver(application=self.mock_application) + with testfixtures.LogCapture(): + with six.assertRaisesRegex( + self, + RuntimeError, + "The run method of data source test_data_source must" + " return a list. It returned instead <type 'str'>. Fix" + " the run\(\) method to return the appropriate entity."): + driver.application_started() + + def test_error_for_incorrect_data_value_entries(self): + def run(self, *args, **kwargs): + return ["hello"] + ds_factory = self.registry.data_source_factories[0] + ds_factory.run_function = run + driver = CoreEvaluationDriver(application=self.mock_application) + with testfixtures.LogCapture(): + with six.assertRaisesRegex( + self, + RuntimeError, + "The result list returned by DataSource test_data_source" + " contains an entry that is not a DataValue." + " An entry of type <type 'str'> was instead found" + " in position 0." + " Fix the DataSource.run\(\) method to" + " return the appropriate entity." + ): + driver.application_started() + def test_error_for_missing_ds_output_names(self): def run(self, *args, **kwargs):