Skip to content
Snippets Groups Projects
Commit 33273b70 authored by martinRenou's avatar martinRenou
Browse files

Clean test_core_evaluation_driver

parent 31ccc154
No related branches found
No related tags found
1 merge request!101Create probe classes for tests
import unittest import unittest
from traits.api import Float
from force_bdss.tests.probe_classes.factory_registry_plugin import \ from force_bdss.tests.probe_classes.factory_registry_plugin import \
ProbeFactoryRegistryPlugin ProbeFactoryRegistryPlugin
from force_bdss.tests.probe_classes.mco import ( from force_bdss.tests.probe_classes.mco import (
ProbeMCOCommunicator, ProbeMCOFactory) ProbeMCOCommunicator, ProbeMCOFactory)
from force_bdss.tests.probe_classes.data_source import ( from force_bdss.tests.probe_classes.data_source import (
ProbeDataSourceFactory) ProbeDataSourceFactory)
from force_bdss.tests.probe_classes.kpi_calculator import (
ProbeKPICalculatorFactory)
from force_bdss.core.input_slot_map import InputSlotMap from force_bdss.core.input_slot_map import InputSlotMap
from force_bdss.core.data_value import DataValue from force_bdss.core.data_value import DataValue
from force_bdss.core.slot import Slot from force_bdss.core.slot import Slot
from force_bdss.data_sources.base_data_source import BaseDataSource
from force_bdss.data_sources.base_data_source_factory import \
BaseDataSourceFactory
from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel
from force_bdss.ids import mco_parameter_id
from force_bdss.kpi.base_kpi_calculator import BaseKPICalculator
from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter
from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory
from force_bdss.tests import fixtures from force_bdss.tests import fixtures
try: try:
...@@ -35,68 +26,12 @@ from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \ ...@@ -35,68 +26,12 @@ from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \
_bind_data_values, _compute_layer_results _bind_data_values, _compute_layer_results
class RangedParameter(BaseMCOParameter):
initial_value = Float()
lower_bound = Float()
upper_bound = Float()
class RangedParameterFactory(BaseMCOParameterFactory):
id = mco_parameter_id("enthought", "null_mco", "null")
model_class = RangedParameter
class OneValueMCOCommunicator(ProbeMCOCommunicator): class OneValueMCOCommunicator(ProbeMCOCommunicator):
"""A communicator that returns one single datavalue, for testing purposes. """A communicator that returns one single datavalue, for testing purposes.
""" """
nb_output_data_values = 1 nb_output_data_values = 1
class BrokenOneValueKPICalculator(BaseKPICalculator):
def run(self, model, data_source_results):
return [DataValue()]
def slots(self, model):
return (), ()
class OneValueKPICalculator(BaseKPICalculator):
def run(self, model, data_source_results):
return [DataValue()]
def slots(self, model):
return (), (Slot(), )
class NullDataSourceModel(BaseDataSourceModel):
pass
class OneValueDataSource(BaseDataSource):
"""Incorrect data source implementation whose run returns a data value
but no slot was specified for it."""
def run(self, model, parameters):
return [DataValue()]
def slots(self, model):
return (), (
Slot(),
)
class TwoInputsThreeOutputsDataSource(BaseDataSource):
"""Incorrect data source implementation whose run returns a data value
but no slot was specified for it."""
def run(self, model, parameters):
return [DataValue(value=1), DataValue(value=2), DataValue(value=3)]
def slots(self, model):
return (
(Slot(), Slot()),
(Slot(), Slot(), Slot())
)
class TestCoreEvaluationDriver(unittest.TestCase): class TestCoreEvaluationDriver(unittest.TestCase):
def setUp(self): def setUp(self):
self.factory_registry_plugin = ProbeFactoryRegistryPlugin() self.factory_registry_plugin = ProbeFactoryRegistryPlugin()
...@@ -144,49 +79,63 @@ class TestCoreEvaluationDriver(unittest.TestCase): ...@@ -144,49 +79,63 @@ class TestCoreEvaluationDriver(unittest.TestCase):
driver.application_started() driver.application_started()
def test_error_for_missing_ds_output_names(self): def test_error_for_missing_ds_output_names(self):
factory = self.factory_registry_plugin.data_source_factories[0] data_source_factories = \
with mock.patch.object(factory.__class__, self.factory_registry_plugin.data_source_factories
"create_data_source") as create_ds:
create_ds.return_value = OneValueDataSource(factory) def run(self, *args, **kwargs):
driver = CoreEvaluationDriver( return [DataValue()]
application=self.mock_application, data_source_factories[0] = ProbeDataSourceFactory(
) None,
with self.assertRaisesRegexp( run_function=run,
RuntimeError, output_slots_size=1)
"The number of data values \(1 values\)" driver = CoreEvaluationDriver(
" returned by 'test_data_source' does not match" application=self.mock_application,
" the number of user-defined names"): )
driver.application_started() with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_data_source' does not match"
" the number of user-defined names"):
driver.application_started()
def test_error_for_incorrect_kpic_output_slots(self): def test_error_for_incorrect_kpic_output_slots(self):
factory = self.factory_registry_plugin.kpi_calculator_factories[0] kpi_calculator_factories = \
with mock.patch.object(factory.__class__, self.factory_registry_plugin.kpi_calculator_factories
"create_kpi_calculator") as create_kpic:
create_kpic.return_value = BrokenOneValueKPICalculator(factory) def run(self, *args, **kwargs):
driver = CoreEvaluationDriver( return [DataValue()]
application=self.mock_application, kpi_calculator_factories[0] = ProbeKPICalculatorFactory(
) None,
with self.assertRaisesRegexp( run_function=run)
RuntimeError, driver = CoreEvaluationDriver(
"The number of data values \(1 values\)" application=self.mock_application,
" returned by 'test_kpi_calculator' does not match" )
" the number of output slots"): with self.assertRaisesRegexp(
driver.application_started() RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_kpi_calculator' does not match"
" the number of output slots"):
driver.application_started()
def test_error_for_missing_kpic_output_names(self): def test_error_for_missing_kpic_output_names(self):
factory = self.factory_registry_plugin.kpi_calculator_factories[0] kpi_calculator_factories = \
with mock.patch.object(factory.__class__, self.factory_registry_plugin.kpi_calculator_factories
"create_kpi_calculator") as create_kpic:
create_kpic.return_value = OneValueKPICalculator(factory) def run(self, *args, **kwargs):
driver = CoreEvaluationDriver( return [DataValue()]
application=self.mock_application, kpi_calculator_factories[0] = ProbeKPICalculatorFactory(
) None,
with self.assertRaisesRegexp( run_function=run,
RuntimeError, output_slots_size=1)
"The number of data values \(1 values\)" driver = CoreEvaluationDriver(
" returned by 'test_kpi_calculator' does not match" application=self.mock_application,
" the number of user-defined names"): )
driver.application_started() with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_kpi_calculator' does not match"
" the number of user-defined names"):
driver.application_started()
def test_bind_data_values(self): def test_bind_data_values(self):
data_values = [ data_values = [
...@@ -233,7 +182,6 @@ class TestCoreEvaluationDriver(unittest.TestCase): ...@@ -233,7 +182,6 @@ class TestCoreEvaluationDriver(unittest.TestCase):
_bind_data_values(data_values, slot_map, slots) _bind_data_values(data_values, slot_map, slots)
def test_compute_layer_results(self): def test_compute_layer_results(self):
data_values = [ data_values = [
DataValue(name="foo"), DataValue(name="foo"),
DataValue(name="bar"), DataValue(name="bar"),
...@@ -241,11 +189,14 @@ class TestCoreEvaluationDriver(unittest.TestCase): ...@@ -241,11 +189,14 @@ class TestCoreEvaluationDriver(unittest.TestCase):
DataValue(name="quux") DataValue(name="quux")
] ]
mock_ds_factory = mock.Mock(spec=BaseDataSourceFactory) def run(self, *args, **kwargs):
mock_ds_factory.name = "mock factory" return [DataValue(value=1), DataValue(value=2), DataValue(value=3)]
mock_ds_factory.create_data_source.return_value = \ ds_factory = ProbeDataSourceFactory(
TwoInputsThreeOutputsDataSource(mock_ds_factory) None,
evaluator_model = NullDataSourceModel(factory=mock_ds_factory) input_slots_size=2,
output_slots_size=3,
run_function=run)
evaluator_model = ds_factory.create_model()
evaluator_model.input_slot_maps = [ evaluator_model.input_slot_maps = [
InputSlotMap(name="foo"), InputSlotMap(name="foo"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment