Skip to content
Snippets Groups Projects
Commit a708b10c authored by martinRenou's avatar martinRenou
Browse files

Update prob classes

parent 167f7b85
No related branches found
No related tags found
1 merge request!101Create probe classes for tests
......@@ -2,7 +2,7 @@
"version": "1",
"workflow": {
"mco": {
"id": "force.bdss.enthought.factory.null_mco",
"id": "force.bdss.enthought.factory.test_mco",
"model_data": {
"parameters" : [
]
......@@ -10,7 +10,7 @@
},
"data_sources": [
{
"id": "force.bdss.enthought.factory.null_ds",
"id": "force.bdss.enthought.factory.test_ds",
"model_data": {
"input_slot_maps": [
],
......@@ -21,7 +21,7 @@
],
"kpi_calculators": [
{
"id": "force.bdss.enthought.factory.null_kpic",
"id": "force.bdss.enthought.factory.test_kpic",
"model_data": {
"input_slot_maps": [
],
......@@ -32,7 +32,7 @@
],
"notification_listeners": [
{
"id": "force.bdss.enthought.factory.null_nl",
"id": "force.bdss.enthought.factory.test_nl",
"model_data": {
}
}
......
from traits.api import Bool, Function, Str, Int, on_trait_change, Type
from force_bdss.ids import factory_id
from force_bdss.api import (
BaseDataSourceFactory, BaseDataSourceModel, BaseDataSource,
Slot
......@@ -8,12 +9,21 @@ from force_bdss.api import (
from .evaluator_factory import ProbeEvaluatorFactory
def run_func(*args, **kwargs):
return []
class ProbeDataSource(BaseDataSource):
run_function = Function
run_function = Function(default_value=run_func)
run_called = Bool(False)
slots_called = Bool(False)
def __init__(self, factory, run_function=None, *args, **kwargs):
if run_function is None:
self.run_function = run_func
super(ProbeDataSource, self).__init__(self, factory, *args, **kwargs)
def run(self, model, parameters):
self.run_called = True
self.run_function(model, parameters)
......@@ -44,12 +54,14 @@ class ProbeDataSourceModel(BaseDataSourceModel):
class ProbeDataSourceFactory(BaseDataSourceFactory,
ProbeEvaluatorFactory):
id = Str('enthought.test.data_source')
id = Str(factory_id("enthought", "test_ds"))
name = Str('test_data_source')
model_class = Type(ProbeDataSourceModel)
def create_model(self, model_data=None):
if model_data is None:
model_data = {}
return self.model_class(
factory=self,
input_slots_type=self.input_slots_type,
......
from traits.api import Bool, Function, Str, Int, on_trait_change, Type
from force_bdss.ids import factory_id
from force_bdss.api import (
BaseKPICalculatorFactory, BaseKPICalculatorModel, BaseKPICalculator,
Slot
......@@ -8,12 +9,22 @@ from force_bdss.api import (
from .evaluator_factory import ProbeEvaluatorFactory
class ProbeEvaluator(BaseKPICalculator):
run_function = Function
def run_func(*args, **kwargs):
return []
class ProbeKPICalculator(BaseKPICalculator):
run_function = Function()
run_called = Bool(False)
slots_called = Bool(False)
def __init__(self, factory, run_function=None, *args, **kwargs):
if run_function is None:
self.run_function = run_func
super(ProbeKPICalculator, self).__init__(
self, factory, *args, **kwargs)
def run(self, model, parameters):
self.run_called = True
self.run_function(model, parameters)
......@@ -28,6 +39,11 @@ class ProbeEvaluator(BaseKPICalculator):
for _ in range(model.output_slots_size))
)
def _run_function_default(self):
def run_func(*args, **kwargs):
pass
return run_func
class ProbeKPICalculatorModel(BaseKPICalculatorModel):
input_slots_type = Str('PRESSURE')
......@@ -44,12 +60,14 @@ class ProbeKPICalculatorModel(BaseKPICalculatorModel):
class ProbeKPICalculatorFactory(BaseKPICalculatorFactory,
ProbeEvaluatorFactory):
id = Str('enthought.test.kpi_calculator')
id = Str(factory_id("enthought", "test_kpic"))
name = Str('test_kpi_calculator')
model_class = Type(ProbeKPICalculatorModel)
def create_model(self, model_data=None):
if model_data is None:
model_data = {}
return self.model_class(
factory=self,
input_slots_type=self.input_slots_type,
......@@ -60,7 +78,7 @@ class ProbeKPICalculatorFactory(BaseKPICalculatorFactory,
)
def create_kpi_calculator(self):
return ProbeEvaluator(
return ProbeKPICalculator(
factory=self,
run_function=self.run_function,
)
from traits.api import Str, Type, Bool, Int
from traits.api import Str, Type, Bool, Int, Function
from force_bdss.ids import mco_parameter_id, factory_id
from force_bdss.core.data_value import DataValue
from force_bdss.api import (
BaseMCOModel, BaseMCO, BaseMCOFactory,
......@@ -12,19 +13,35 @@ class ProbeMCOModel(BaseMCOModel):
pass
def run_func(*args, **kwargs):
return []
class ProbeMCO(BaseMCO):
run_function = Function()
run_called = Bool(False)
def __init__(self, factory, run_function=None, *args, **kwargs):
if run_function is None:
self.run_function = run_func
super(ProbeMCO, self).__init__(self, factory, *args, **kwargs)
def run(self, model):
self.run_called = True
def _run_function_default(self):
def run_func(*args, **kwargs):
pass
return run_func
class ProbeParameter(BaseMCOParameter):
pass
class RangedParameterFactory(BaseMCOParameterFactory):
id = Str("enthought.test.mco_parameter")
id = Str(mco_parameter_id("enthought", "test_mco", "test"))
model_class = Type(ProbeParameter)
......@@ -46,7 +63,7 @@ class ProbeMCOCommunicator(BaseMCOCommunicator):
class ProbeMCOFactory(BaseMCOFactory):
id = Str("enthought.test.mco")
id = Str(factory_id("enthought", "test_mco"))
model_class = Type(ProbeMCOModel)
......@@ -55,6 +72,8 @@ class ProbeMCOFactory(BaseMCOFactory):
mco_class = Type(ProbeMCO)
def create_model(self, model_data=None):
if model_data is None:
model_data = {}
return self.model_class(
self,
**model_data
......
from traits.api import Bool, Str, Type
from force_bdss.ids import factory_id
from force_bdss.api import (
BaseNotificationListener, BaseNotificationListenerModel,
BaseNotificationListenerFactory)
......@@ -25,7 +26,7 @@ class ProbeNotificationListenerModel(BaseNotificationListenerModel):
class ProbeNotificationListenerFactory(BaseNotificationListenerFactory):
id = Str("enthought.test.notification_listener")
id = Str(factory_id("enthought", "test_nl"))
name = "test_notification_listener"
model_class = Type(ProbeNotificationListenerModel)
......@@ -36,4 +37,6 @@ class ProbeNotificationListenerFactory(BaseNotificationListenerFactory):
return self.listener_class(self)
def create_model(self, model_data=None):
return self.model_class(self, model_data=model_data)
if model_data is None:
model_data = {}
return self.model_class(self, **model_data)
import unittest
from traits.api import Float, List
from traits.api import Float
from force_bdss.tests.probe_classes.factory_registry_plugin import \
ProbeFactoryRegistryPlugin
from force_bdss.core.input_slot_map import InputSlotMap
from force_bdss.factory_registry_plugin import FactoryRegistryPlugin
from force_bdss.core.data_value import DataValue
from force_bdss.core.slot import Slot
from force_bdss.data_sources.base_data_source import BaseDataSource
......@@ -83,7 +85,7 @@ class OneDataValueMCOCommunicator(BaseMCOCommunicator):
class NullMCOFactory(BaseMCOFactory):
id = factory_id("enthought", "null_mco")
id = factory_id("enthought", "test_mco")
def create_model(self, model_data=None):
return NullMCOModel(self, **model_data)
......@@ -221,33 +223,12 @@ class NullNotificationListenerFactory(BaseNotificationListenerFactory):
return NullNotificationListenerModel(self)
class DummyFactoryRegistryPlugin(FactoryRegistryPlugin):
mco_factories = List()
kpi_calculator_factories = List()
data_source_factories = List()
notification_listener_factories = List()
def mock_factory_registry_plugin():
factory_registry_plugin = DummyFactoryRegistryPlugin()
factory_registry_plugin.mco_factories = [
NullMCOFactory(factory_registry_plugin)]
factory_registry_plugin.kpi_calculator_factories = [
NullKPICalculatorFactory(factory_registry_plugin)]
factory_registry_plugin.data_source_factories = [
NullDataSourceFactory(factory_registry_plugin)]
factory_registry_plugin.notification_listener_factories = [
NullNotificationListenerFactory(factory_registry_plugin)
]
return factory_registry_plugin
class TestCoreEvaluationDriver(unittest.TestCase):
def setUp(self):
self.mock_factory_registry_plugin = mock_factory_registry_plugin()
self.factory_registry_plugin = ProbeFactoryRegistryPlugin()
application = mock.Mock(spec=Application)
application.get_plugin = mock.Mock(
return_value=self.mock_factory_registry_plugin
return_value=self.factory_registry_plugin
)
application.workflow_filepath = fixtures.get("test_null.json")
self.mock_application = application
......@@ -259,7 +240,7 @@ class TestCoreEvaluationDriver(unittest.TestCase):
driver.application_started()
def test_error_for_non_matching_mco_parameters(self):
factory = self.mock_factory_registry_plugin.mco_factories[0]
factory = self.factory_registry_plugin.mco_factories[0]
with mock.patch.object(factory.__class__,
"create_communicator") as create_comm:
create_comm.return_value = OneDataValueMCOCommunicator(
......@@ -273,7 +254,7 @@ class TestCoreEvaluationDriver(unittest.TestCase):
driver.application_started()
def test_error_for_incorrect_output_slots(self):
factory = self.mock_factory_registry_plugin.data_source_factories[0]
factory = self.factory_registry_plugin.data_source_factories[0]
with mock.patch.object(factory.__class__,
"create_data_source") as create_ds:
create_ds.return_value = BrokenOneValueDataSource(factory)
......@@ -288,7 +269,7 @@ class TestCoreEvaluationDriver(unittest.TestCase):
driver.application_started()
def test_error_for_missing_ds_output_names(self):
factory = self.mock_factory_registry_plugin.data_source_factories[0]
factory = self.factory_registry_plugin.data_source_factories[0]
with mock.patch.object(factory.__class__,
"create_data_source") as create_ds:
create_ds.return_value = OneValueDataSource(factory)
......@@ -303,7 +284,7 @@ class TestCoreEvaluationDriver(unittest.TestCase):
driver.application_started()
def test_error_for_incorrect_kpic_output_slots(self):
factory = self.mock_factory_registry_plugin.kpi_calculator_factories[0]
factory = self.factory_registry_plugin.kpi_calculator_factories[0]
with mock.patch.object(factory.__class__,
"create_kpi_calculator") as create_kpic:
create_kpic.return_value = BrokenOneValueKPICalculator(factory)
......@@ -318,7 +299,7 @@ class TestCoreEvaluationDriver(unittest.TestCase):
driver.application_started()
def test_error_for_missing_kpic_output_names(self):
factory = self.mock_factory_registry_plugin.kpi_calculator_factories[0]
factory = self.factory_registry_plugin.kpi_calculator_factories[0]
with mock.patch.object(factory.__class__,
"create_kpi_calculator") as create_kpic:
create_kpic.return_value = OneValueKPICalculator(factory)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment