Skip to content
Snippets Groups Projects
Commit 33273b70 authored by martinRenou's avatar martinRenou
Browse files

Clean test_core_evaluation_driver

parent 31ccc154
No related branches found
No related tags found
1 merge request!101Create probe classes for tests
import unittest
from traits.api import Float
from force_bdss.tests.probe_classes.factory_registry_plugin import \
ProbeFactoryRegistryPlugin
from force_bdss.tests.probe_classes.mco import (
ProbeMCOCommunicator, ProbeMCOFactory)
from force_bdss.tests.probe_classes.data_source import (
ProbeDataSourceFactory)
from force_bdss.tests.probe_classes.kpi_calculator import (
ProbeKPICalculatorFactory)
from force_bdss.core.input_slot_map import InputSlotMap
from force_bdss.core.data_value import DataValue
from force_bdss.core.slot import Slot
from force_bdss.data_sources.base_data_source import BaseDataSource
from force_bdss.data_sources.base_data_source_factory import \
BaseDataSourceFactory
from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel
from force_bdss.ids import mco_parameter_id
from force_bdss.kpi.base_kpi_calculator import BaseKPICalculator
from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter
from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory
from force_bdss.tests import fixtures
try:
......@@ -35,68 +26,12 @@ from force_bdss.core_evaluation_driver import CoreEvaluationDriver, \
_bind_data_values, _compute_layer_results
class RangedParameter(BaseMCOParameter):
initial_value = Float()
lower_bound = Float()
upper_bound = Float()
class RangedParameterFactory(BaseMCOParameterFactory):
id = mco_parameter_id("enthought", "null_mco", "null")
model_class = RangedParameter
class OneValueMCOCommunicator(ProbeMCOCommunicator):
"""A communicator that returns one single datavalue, for testing purposes.
"""
nb_output_data_values = 1
class BrokenOneValueKPICalculator(BaseKPICalculator):
def run(self, model, data_source_results):
return [DataValue()]
def slots(self, model):
return (), ()
class OneValueKPICalculator(BaseKPICalculator):
def run(self, model, data_source_results):
return [DataValue()]
def slots(self, model):
return (), (Slot(), )
class NullDataSourceModel(BaseDataSourceModel):
pass
class OneValueDataSource(BaseDataSource):
"""Incorrect data source implementation whose run returns a data value
but no slot was specified for it."""
def run(self, model, parameters):
return [DataValue()]
def slots(self, model):
return (), (
Slot(),
)
class TwoInputsThreeOutputsDataSource(BaseDataSource):
"""Incorrect data source implementation whose run returns a data value
but no slot was specified for it."""
def run(self, model, parameters):
return [DataValue(value=1), DataValue(value=2), DataValue(value=3)]
def slots(self, model):
return (
(Slot(), Slot()),
(Slot(), Slot(), Slot())
)
class TestCoreEvaluationDriver(unittest.TestCase):
def setUp(self):
self.factory_registry_plugin = ProbeFactoryRegistryPlugin()
......@@ -144,49 +79,63 @@ class TestCoreEvaluationDriver(unittest.TestCase):
driver.application_started()
def test_error_for_missing_ds_output_names(self):
factory = self.factory_registry_plugin.data_source_factories[0]
with mock.patch.object(factory.__class__,
"create_data_source") as create_ds:
create_ds.return_value = OneValueDataSource(factory)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_data_source' does not match"
" the number of user-defined names"):
driver.application_started()
data_source_factories = \
self.factory_registry_plugin.data_source_factories
def run(self, *args, **kwargs):
return [DataValue()]
data_source_factories[0] = ProbeDataSourceFactory(
None,
run_function=run,
output_slots_size=1)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_data_source' does not match"
" the number of user-defined names"):
driver.application_started()
def test_error_for_incorrect_kpic_output_slots(self):
factory = self.factory_registry_plugin.kpi_calculator_factories[0]
with mock.patch.object(factory.__class__,
"create_kpi_calculator") as create_kpic:
create_kpic.return_value = BrokenOneValueKPICalculator(factory)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_kpi_calculator' does not match"
" the number of output slots"):
driver.application_started()
kpi_calculator_factories = \
self.factory_registry_plugin.kpi_calculator_factories
def run(self, *args, **kwargs):
return [DataValue()]
kpi_calculator_factories[0] = ProbeKPICalculatorFactory(
None,
run_function=run)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_kpi_calculator' does not match"
" the number of output slots"):
driver.application_started()
def test_error_for_missing_kpic_output_names(self):
factory = self.factory_registry_plugin.kpi_calculator_factories[0]
with mock.patch.object(factory.__class__,
"create_kpi_calculator") as create_kpic:
create_kpic.return_value = OneValueKPICalculator(factory)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_kpi_calculator' does not match"
" the number of user-defined names"):
driver.application_started()
kpi_calculator_factories = \
self.factory_registry_plugin.kpi_calculator_factories
def run(self, *args, **kwargs):
return [DataValue()]
kpi_calculator_factories[0] = ProbeKPICalculatorFactory(
None,
run_function=run,
output_slots_size=1)
driver = CoreEvaluationDriver(
application=self.mock_application,
)
with self.assertRaisesRegexp(
RuntimeError,
"The number of data values \(1 values\)"
" returned by 'test_kpi_calculator' does not match"
" the number of user-defined names"):
driver.application_started()
def test_bind_data_values(self):
data_values = [
......@@ -233,7 +182,6 @@ class TestCoreEvaluationDriver(unittest.TestCase):
_bind_data_values(data_values, slot_map, slots)
def test_compute_layer_results(self):
data_values = [
DataValue(name="foo"),
DataValue(name="bar"),
......@@ -241,11 +189,14 @@ class TestCoreEvaluationDriver(unittest.TestCase):
DataValue(name="quux")
]
mock_ds_factory = mock.Mock(spec=BaseDataSourceFactory)
mock_ds_factory.name = "mock factory"
mock_ds_factory.create_data_source.return_value = \
TwoInputsThreeOutputsDataSource(mock_ds_factory)
evaluator_model = NullDataSourceModel(factory=mock_ds_factory)
def run(self, *args, **kwargs):
return [DataValue(value=1), DataValue(value=2), DataValue(value=3)]
ds_factory = ProbeDataSourceFactory(
None,
input_slots_size=2,
output_slots_size=3,
run_function=run)
evaluator_model = ds_factory.create_model()
evaluator_model.input_slot_maps = [
InputSlotMap(name="foo"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment