Skip to content
Snippets Groups Projects
Commit cc3d0dd8 authored by Stefano Borini's avatar Stefano Borini
Browse files

Merge branch 'master' into deliver-to-ui

parents 2f12ba4e a93e0e33
No related branches found
No related tags found
1 merge request!79Deliver notification info
Showing
with 200 additions and 207 deletions
......@@ -5,21 +5,37 @@ A single Plugin can provide one or more of the pluggable entities
described elsewhere (MCO/KPICalculators/DataSources). Multiple plugins can
be installed to provide a broad range of functionalities.
Plugins must return "Bundles". Each Bundle acts as a Factory, providing
factory methods for one of the above pluggable entities and its associated
classes.
Plugins must return Factories. Each Factory provides factory methods for
one of the above pluggable entities and its associated classes.
To implement a new plugin, you must
- define the entity you want to extend (e.g. ``MyOwnDataSource``) as a derived
class of the appropriate class (e.g. BaseDataSource), and reimplement
the appropriate methods.
- Define the model that this DataSource needs, by extending
class of the appropriate class (e.g. ``BaseDataSource``), and reimplement
the appropriate methods:
- ``run()``: where the actual computation takes place, given the
configuration options specified in the model (which is received as an
argument). It is strongly advised that the ``run()`` method is stateless.
- ``slots()``: must return a 2-tuple of tuples. Each tuple contains instances
of the ``Slot`` class. Slots are the input and output entities of the
data source or KPI calculator. Given that this information depends on the
configuration options, ``slots()`` accepts the model and must return the
appropriate values according to the model options.
- Define the model that this ``DataSource`` needs, by extending
``BaseDataSourceModel`` and adding, with traits, the appropriate data that
are required by your data source to perform its task.
- Define the Bundle, by reimplementing BaseDataSourceBundle and reimplementing
If a trait change in your model influences the input/output slots, you must
make sure that the event ``changes_slots`` is fired as a consequence of
those changes. This will notify the UI that the new slots need to be
recomputed and presented to the user. Failing to do so will have unexpected
consequences.
- Define the Factory, by reimplementing BaseDataSourceFactory and reimplementing
its ``create_*`` methods to return the above entities.
- Define a ``Plugin`` by reimplementing ``BaseExtensionPlugin`` and
reimplementing its initialization defaults methods to return your bundle.
reimplementing its initialization defaults methods to return your factory.
- add the plugin class in the setup.py entry_point, under the namespace
``force.bdss.extensions``
from .base_extension_plugin import BaseExtensionPlugin # noqa
from .ids import bundle_id, plugin_id # noqa
from .ids import factory_id, plugin_id # noqa
from .core.data_value import DataValue # noqa
from .data_sources.base_data_source_model import BaseDataSourceModel # noqa
from .data_sources.base_data_source import BaseDataSource # noqa
from .data_sources.base_data_source_bundle import BaseDataSourceBundle # noqa
from .data_sources.i_data_source_bundle import IDataSourceBundle # noqa
from .data_sources.base_data_source_factory import BaseDataSourceFactory # noqa
from .data_sources.i_data_source_factory import IDataSourceFactory # noqa
from .kpi.base_kpi_calculator import BaseKPICalculator # noqa
from .kpi.base_kpi_calculator_model import BaseKPICalculatorModel # noqa
from .kpi.base_kpi_calculator_bundle import BaseKPICalculatorBundle # noqa
from .kpi.i_kpi_calculator_bundle import IKPICalculatorBundle # noqa
from .kpi.base_kpi_calculator_factory import BaseKPICalculatorFactory # noqa
from .kpi.i_kpi_calculator_factory import IKPICalculatorFactory # noqa
from .mco.base_mco_model import BaseMCOModel # noqa
from .mco.base_mco_communicator import BaseMCOCommunicator # noqa
from .mco.base_mco import BaseMCO # noqa
from .mco.base_mco_bundle import BaseMCOBundle # noqa
from .mco.i_mco_bundle import IMCOBundle # noqa
from .mco.base_mco_factory import BaseMCOFactory # noqa
from .mco.i_mco_factory import IMCOFactory # noqa
from .mco.parameters.base_mco_parameter_factory import BaseMCOParameterFactory # noqa
from .mco.parameters.base_mco_parameter import BaseMCOParameter # noqa
......@@ -2,9 +2,9 @@ from envisage.plugin import Plugin
from traits.trait_types import Instance
from .core.workflow import Workflow
from .bundle_registry_plugin import (
BundleRegistryPlugin,
BUNDLE_REGISTRY_PLUGIN_ID
from .factory_registry_plugin import (
FactoryRegistryPlugin,
FACTORY_REGISTRY_PLUGIN_ID
)
from .io.workflow_reader import WorkflowReader
......@@ -14,16 +14,16 @@ class BaseCoreDriver(Plugin):
or the evaluation.
"""
#: The registry of the bundles.
bundle_registry = Instance(BundleRegistryPlugin)
#: The registry of the factories
factory_registry = Instance(FactoryRegistryPlugin)
#: Deserialized content of the workflow file.
workflow = Instance(Workflow)
def _bundle_registry_default(self):
return self.application.get_plugin(BUNDLE_REGISTRY_PLUGIN_ID)
def _factory_registry_default(self):
return self.application.get_plugin(FACTORY_REGISTRY_PLUGIN_ID)
def _workflow_default(self):
reader = WorkflowReader(self.bundle_registry)
reader = WorkflowReader(self.factory_registry)
with open(self.application.workflow_filepath) as f:
return reader.read(f)
......@@ -2,9 +2,9 @@ from envisage.plugin import Plugin
from traits.trait_types import List
from .ids import ExtensionPointID
from .data_sources.i_data_source_bundle import IDataSourceBundle
from .kpi.i_kpi_calculator_bundle import IKPICalculatorBundle
from .mco.i_mco_bundle import IMCOBundle
from .data_sources.i_data_source_factory import IDataSourceFactory
from .kpi.i_kpi_calculator_factory import IKPICalculatorFactory
from .mco.i_mco_factory import IMCOFactory
class BaseExtensionPlugin(Plugin):
......@@ -17,30 +17,30 @@ class BaseExtensionPlugin(Plugin):
specific trait you want to populate. For example::
class MyPlugin(BaseExtensionPlugin):
def _data_source_bundles_default(self):
return [MyDataSourceBundle1(),
MyDataSourceBundle2()]
def _data_source_factories_default(self):
return [MyDataSourceFactory1(),
MyDataSourceFactory2()]
"""
#: A list of available Multi Criteria Optimizers this plugin exports.
mco_bundles = List(
IMCOBundle,
contributes_to=ExtensionPointID.MCO_BUNDLES
mco_factories = List(
IMCOFactory,
contributes_to=ExtensionPointID.MCO_FACTORIES
)
#: A list of the available Data Sources this plugin exports.
data_source_bundles = List(
IDataSourceBundle,
contributes_to=ExtensionPointID.DATA_SOURCE_BUNDLES
data_source_factories = List(
IDataSourceFactory,
contributes_to=ExtensionPointID.DATA_SOURCE_FACTORIES
)
#: A list of the available KPI calculators this plugin exports.
kpi_calculator_bundles = List(
IKPICalculatorBundle,
contributes_to=ExtensionPointID.KPI_CALCULATOR_BUNDLES
kpi_calculator_factories = List(
IKPICalculatorFactory,
contributes_to=ExtensionPointID.KPI_CALCULATOR_FACTORIES
)
notification_listener_bundles = List(
INotificationListenerBundle,
contributes_to=ExtensionPointID.NOTIFICATION_LISTENER_BUNDLES
notifier_factory = List(
INotifierFactory,
contributes_to=ExtensionPointID.NOTIFIER_FACTORIES
)
......@@ -8,7 +8,7 @@ from envisage.api import Application
from envisage.core_plugin import CorePlugin
from traits.api import Unicode, Bool
from .bundle_registry_plugin import BundleRegistryPlugin
from .factory_registry_plugin import FactoryRegistryPlugin
from .core_evaluation_driver import CoreEvaluationDriver
from .core_mco_driver import CoreMCODriver
......@@ -30,7 +30,7 @@ class BDSSApplication(Application):
self.evaluate = evaluate
self.workflow_filepath = workflow_filepath
plugins = [CorePlugin(), BundleRegistryPlugin()]
plugins = [CorePlugin(), FactoryRegistryPlugin()]
if self.evaluate:
plugins.append(CoreEvaluationDriver())
......
......@@ -7,14 +7,14 @@ from force_bdss.mco.base_mco_model import BaseMCOModel
class Workflow(HasStrictTraits):
"""Model object that represents the Workflow as a whole"""
#: Contains the bundle-specific MCO Model object.
#: Contains the factory-specific MCO Model object.
#: Can be None if no MCO has been specified yet.
mco = Instance(BaseMCOModel, allow_none=True)
#: Contains the bundle-specific DataSource Model objects.
#: Contains the factory-specific DataSource Model objects.
#: The list can be empty
data_sources = List(BaseDataSourceModel)
#: Contains the bundle-specific KPI Calculator Model objects.
#: Contains the factory-specific KPI Calculator Model objects.
#: The list can be empty
kpi_calculators = List(BaseKPICalculatorModel)
......@@ -30,37 +30,63 @@ class CoreEvaluationDriver(BaseCoreDriver):
sys.exit(1)
mco_model = workflow.mco
mco_bundle = mco_model.bundle
mco_communicator = mco_bundle.create_communicator()
mco_factory = mco_model.factory
mco_communicator = mco_factory.create_communicator()
mco_data_values = self._get_data_values_from_mco(mco_model,
mco_communicator)
ds_results = self._compute_ds_results(
ds_results = self._compute_layer_results(
mco_data_values,
workflow)
workflow.data_sources,
"create_data_source"
)
kpi_results = self._compute_kpi_results(
kpi_results = self._compute_layer_results(
ds_results + mco_data_values,
workflow)
workflow.kpi_calculators,
"create_kpi_calculator"
)
mco_communicator.send_to_mco(mco_model, kpi_results)
def _compute_ds_results(self, environment_data_values, workflow):
def _compute_layer_results(self,
environment_data_values,
evaluator_models,
creator_method_name
):
"""Helper routine.
Performs the evaluation of the DataSources, passing the current
environment data values (the MCO data)
Performs the evaluation of a single layer.
At the moment we have a single layer of DataSources followed
by a single layer of KPI calculators.
Parameters
----------
environment_data_values: list
A list of data values to submit to the evaluators.
evaluator_models: list
A list of the models for all the evaluators (data source
or kpi calculator)
creator_method_name: str
A string of the creator method for the evaluator on the
factory (e.g. create_kpi_calculator)
NOTE: The above parameter is going to go away as soon as we move
to unlimited layers and remove the distinction between data sources
and KPI calculators.
"""
ds_results = []
results = []
for ds_model in workflow.data_sources:
ds_bundle = ds_model.bundle
data_source = ds_bundle.create_data_source()
for model in evaluator_models:
factory = model.factory
evaluator = getattr(factory, creator_method_name)()
# Get the slots for this data source. These must be matched to
# the appropriate values in the environment data values.
# Matching is by position.
in_slots, out_slots = data_source.slots(ds_model)
in_slots, out_slots = evaluator.slots(model)
# Binding performs the extraction of the specified data values
# satisfying the above input slots from the environment data values
......@@ -71,36 +97,36 @@ class CoreEvaluationDriver(BaseCoreDriver):
# needed by the input slots.
passed_data_values = self._bind_data_values(
environment_data_values,
ds_model.input_slot_maps,
model.input_slot_maps,
in_slots)
# execute data source, passing only relevant data values.
logging.info("Evaluating for Data Source {}".format(
ds_bundle.name))
res = data_source.run(ds_model, passed_data_values)
factory.name))
res = evaluator.run(model, passed_data_values)
if len(res) != len(out_slots):
error_txt = (
"The number of data values ({} values) returned"
" by the DataSource '{}' does not match the number"
" by '{}' does not match the number"
" of output slots it specifies ({} values)."
" This is likely a DataSource plugin error.").format(
len(res), ds_bundle.name, len(out_slots)
" This is likely a plugin error.").format(
len(res), factory.name, len(out_slots)
)
logging.error(error_txt)
raise RuntimeError(error_txt)
if len(res) != len(ds_model.output_slot_names):
if len(res) != len(model.output_slot_names):
error_txt = (
"The number of data values ({} values) returned"
" by the DataSource '{}' does not match the number"
" by '{}' does not match the number"
" of user-defined names specified ({} values)."
" This is either a DataSource plugin error or a file"
" This is either a plugin error or a file"
" error.").format(
len(res),
ds_bundle.name,
len(ds_model.output_slot_names)
factory.name,
len(model.output_slot_names)
)
logging.error(error_txt)
......@@ -108,67 +134,14 @@ class CoreEvaluationDriver(BaseCoreDriver):
# At this point, the returned data values are unnamed.
# Add the names as specified by the user.
for dv, output_slot_name in zip(res, ds_model.output_slot_names):
for dv, output_slot_name in zip(res, model.output_slot_names):
dv.name = output_slot_name
ds_results.extend(res)
results.extend(res)
# Finally, return all the computed data values from all data sources,
# Finally, return all the computed data values from all evaluators,
# properly named.
return ds_results
def _compute_kpi_results(self, environment_data_values, workflow):
"""Perform evaluation of all KPI calculators.
environment_data_values contains all data values provided from
the MCO and data sources.
"""
kpi_results = []
for kpic_model in workflow.kpi_calculators:
kpic_bundle = kpic_model.bundle
kpi_calculator = kpic_bundle.create_kpi_calculator()
in_slots, out_slots = kpi_calculator.slots(kpic_model)
passed_data_values = self._bind_data_values(
environment_data_values,
kpic_model.input_slot_maps,
in_slots)
logging.info("Evaluating for KPICalculator {}".format(
kpic_bundle.name))
res = kpi_calculator.run(kpic_model, passed_data_values)
if len(res) != len(out_slots):
error_txt = (
"The number of data values ({} values) returned by"
" the KPICalculator '{}' does not match the"
" number of output slots ({} values). This is"
" likely a KPICalculator plugin error."
).format(len(res), kpic_bundle.name, len(out_slots))
logging.error(error_txt)
raise RuntimeError(error_txt)
if len(res) != len(kpic_model.output_slot_names):
error_txt = (
"The number of data values ({} values) returned by"
" the KPICalculator '{}' does not match the"
" number of user-defined names specified ({} values)."
" This is either an input file error or a plugin"
" error."
).format(len(res), kpic_bundle.name,
len(kpic_model.output_slot_names))
logging.error(error_txt)
raise RuntimeError(error_txt)
for kpi, output_slot_name in zip(
res, kpic_model.output_slot_names):
kpi.name = output_slot_name
kpi_results.extend(res)
return kpi_results
return results
def _get_data_values_from_mco(self, model, communicator):
"""Helper method.
......
......@@ -23,7 +23,7 @@ class CoreMCODriver(BaseCoreDriver):
mco = Instance(BaseMCO, allow_none=True)
listeners = Instance(BaseNotificationListener)
listeners = Instance(BaseNotifier)
@on_trait_change("application:started")
def application_started(self):
......@@ -34,7 +34,6 @@ class CoreMCODriver(BaseCoreDriver):
sys.exit(1)
mco_model = workflow.mco
mco_bundle = mco_model.bundle
self.mco = mco_bundle.create_optimizer()
self.mco.run(mco_model)
mco_factory = mco_model.factory
mco = mco_factory.create_optimizer()
mco.run(mco_model)
from traits.api import String
from force_bdss.api import bundle_id, BaseDataSourceBundle
from force_bdss.api import factory_id, BaseDataSourceFactory
from .csv_extractor_model import CSVExtractorModel
from .csv_extractor_data_source import CSVExtractorDataSource
class CSVExtractorBundle(BaseDataSourceBundle):
id = String(bundle_id("enthought", "csv_extractor"))
class CSVExtractorFactory(BaseDataSourceFactory):
id = String(factory_id("enthought", "csv_extractor"))
name = String("CSV Extractor")
......
from traits.api import Int, String
from traits.api import Int, String, on_trait_change
from force_bdss.api import BaseDataSourceModel
......@@ -8,3 +8,7 @@ class CSVExtractorModel(BaseDataSourceModel):
row = Int()
column = Int()
cuba_type = String()
@on_trait_change("cuba_type")
def _notify_changes_slots(self):
self.changes_slots = True
......@@ -6,8 +6,8 @@ from force_bdss.core_plugins.dummy.csv_extractor.csv_extractor_data_source \
import CSVExtractorDataSource
from force_bdss.core_plugins.dummy.csv_extractor.csv_extractor_model import \
CSVExtractorModel
from force_bdss.data_sources.base_data_source_bundle import \
BaseDataSourceBundle
from force_bdss.data_sources.base_data_source_factory import \
BaseDataSourceFactory
from force_bdss.tests import fixtures
try:
......@@ -18,15 +18,15 @@ except ImportError:
class TestCSVExtractorDataSource(unittest.TestCase):
def setUp(self):
self.bundle = mock.Mock(spec=BaseDataSourceBundle)
self.factory = mock.Mock(spec=BaseDataSourceFactory)
def test_initialization(self):
ds = CSVExtractorDataSource(self.bundle)
self.assertEqual(ds.bundle, self.bundle)
ds = CSVExtractorDataSource(self.factory)
self.assertEqual(ds.factory, self.factory)
def test_run(self):
ds = CSVExtractorDataSource(self.bundle)
model = CSVExtractorModel(self.bundle)
ds = CSVExtractorDataSource(self.factory)
model = CSVExtractorModel(self.factory)
model.filename = fixtures.get("foo.csv")
model.row = 3
model.column = 5
......@@ -38,8 +38,8 @@ class TestCSVExtractorDataSource(unittest.TestCase):
self.assertEqual(result[0].value, 42)
def test_run_with_exception(self):
ds = CSVExtractorDataSource(self.bundle)
model = CSVExtractorModel(self.bundle)
ds = CSVExtractorDataSource(self.factory)
model = CSVExtractorModel(self.factory)
model.filename = fixtures.get("foo.csv")
mock_params = []
model.row = 30
......@@ -53,8 +53,8 @@ class TestCSVExtractorDataSource(unittest.TestCase):
ds.run(model, mock_params)
def test_slots(self):
ds = CSVExtractorDataSource(self.bundle)
model = CSVExtractorModel(self.bundle)
ds = CSVExtractorDataSource(self.factory)
model = CSVExtractorModel(self.factory)
slots = ds.slots(model)
self.assertEqual(len(slots), 2)
self.assertEqual(len(slots[0]), 0)
......
import unittest
from force_bdss.core_plugins.dummy.tests.data_source_bundle_test_mixin import \
DataSourceBundleTestMixin
from force_bdss.core_plugins.dummy.csv_extractor.csv_extractor_bundle import \
CSVExtractorBundle
from force_bdss.core_plugins.dummy.tests.data_source_factory_test_mixin \
import DataSourceFactoryTestMixin
from force_bdss.core_plugins.dummy.csv_extractor.csv_extractor_factory import \
CSVExtractorFactory
from force_bdss.core_plugins.dummy.csv_extractor.csv_extractor_data_source \
import CSVExtractorDataSource
from force_bdss.core_plugins.dummy.csv_extractor.csv_extractor_model import \
CSVExtractorModel
class TestCSVExtractorBundle(DataSourceBundleTestMixin, unittest.TestCase):
class TestCSVExtractorFactory(DataSourceFactoryTestMixin, unittest.TestCase):
@property
def bundle_class(self):
return CSVExtractorBundle
def factory_class(self):
return CSVExtractorFactory
@property
def model_class(self):
......
from traits.api import String
from force_bdss.api import bundle_id, BaseMCOBundle
from force_bdss.api import factory_id, BaseMCOFactory
from force_bdss.core_plugins.dummy.dummy_dakota.parameters import \
RangedMCOParameterFactory
......@@ -8,8 +8,8 @@ from .dakota_model import DummyDakotaModel
from .dakota_optimizer import DummyDakotaOptimizer
class DummyDakotaBundle(BaseMCOBundle):
id = String(bundle_id("enthought", "dummy_dakota"))
class DummyDakotaFactory(BaseMCOFactory):
id = String(factory_id("enthought", "dummy_dakota"))
name = "Dummy Dakota"
......
......@@ -28,7 +28,7 @@ class DummyDakotaOptimizer(BaseMCO):
value_iterator = itertools.product(*values)
application = self.bundle.plugin.application
application = self.factory.plugin.application
for value in value_iterator:
ps = subprocess.Popen(
......
......@@ -9,8 +9,8 @@ except ImportError:
from envisage.plugin import Plugin
from force_bdss.core_plugins.dummy.dummy_dakota.dakota_bundle import (
DummyDakotaBundle)
from force_bdss.core_plugins.dummy.dummy_dakota.dakota_factory import (
DummyDakotaFactory)
from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory
......@@ -20,13 +20,13 @@ from force_bdss.core_plugins.dummy.dummy_dakota.parameters import \
class TestDakotaCommunicator(unittest.TestCase):
def test_receive_from_mco(self):
bundle = DummyDakotaBundle(mock.Mock(spec=Plugin))
factory = DummyDakotaFactory(mock.Mock(spec=Plugin))
mock_parameter_factory = mock.Mock(spec=BaseMCOParameterFactory)
model = bundle.create_model()
model = factory.create_model()
model.parameters = [
RangedMCOParameter(mock_parameter_factory)
]
comm = bundle.create_communicator()
comm = factory.create_communicator()
with mock.patch("sys.stdin") as stdin:
stdin.read.return_value = "1"
......@@ -38,9 +38,9 @@ class TestDakotaCommunicator(unittest.TestCase):
self.assertEqual(data[0].type, "")
def test_send_to_mco(self):
bundle = DummyDakotaBundle(mock.Mock(spec=Plugin))
model = bundle.create_model()
comm = bundle.create_communicator()
factory = DummyDakotaFactory(mock.Mock(spec=Plugin))
model = factory.create_model()
comm = factory.create_communicator()
with mock.patch("sys.stdout") as stdout:
dv = DataValue(value=100)
......
......@@ -2,8 +2,8 @@ import unittest
from envisage.plugin import Plugin
from force_bdss.core_plugins.dummy.dummy_dakota.dakota_bundle import \
DummyDakotaBundle
from force_bdss.core_plugins.dummy.dummy_dakota.dakota_factory import \
DummyDakotaFactory
from force_bdss.core_plugins.dummy.dummy_dakota.dakota_model import \
DummyDakotaModel
from force_bdss.core_plugins.dummy.dummy_dakota.dakota_optimizer import \
......@@ -15,33 +15,33 @@ except ImportError:
from unittest import mock
class TestDakotaBundle(unittest.TestCase):
class TestDakotaFactory(unittest.TestCase):
def setUp(self):
self.plugin = mock.Mock(spec=Plugin)
def test_initialization(self):
bundle = DummyDakotaBundle(self.plugin)
self.assertIn("dummy_dakota", bundle.id)
self.assertEqual(bundle.plugin, self.plugin)
factory = DummyDakotaFactory(self.plugin)
self.assertIn("dummy_dakota", factory.id)
self.assertEqual(factory.plugin, self.plugin)
def test_create_model(self):
bundle = DummyDakotaBundle(self.plugin)
model = bundle.create_model({})
factory = DummyDakotaFactory(self.plugin)
model = factory.create_model({})
self.assertIsInstance(model, DummyDakotaModel)
model = bundle.create_model()
model = factory.create_model()
self.assertIsInstance(model, DummyDakotaModel)
def test_create_mco(self):
bundle = DummyDakotaBundle(self.plugin)
ds = bundle.create_optimizer()
factory = DummyDakotaFactory(self.plugin)
ds = factory.create_optimizer()
self.assertIsInstance(ds, DummyDakotaOptimizer)
def test_create_communicator(self):
bundle = DummyDakotaBundle(self.plugin)
ds = bundle.create_optimizer()
factory = DummyDakotaFactory(self.plugin)
ds = factory.create_optimizer()
self.assertIsInstance(ds, DummyDakotaOptimizer)
def test_parameter_factories(self):
bundle = DummyDakotaBundle(self.plugin)
self.assertNotEqual(len(bundle.parameter_factories()), 0)
factory = DummyDakotaFactory(self.plugin)
self.assertNotEqual(len(factory.parameter_factories()), 0)
......@@ -7,7 +7,7 @@ from force_bdss.core_plugins.dummy.dummy_dakota.parameters import (
from force_bdss.core_plugins.dummy.dummy_dakota.dakota_model import (
DummyDakotaModel
)
from force_bdss.mco.base_mco_bundle import BaseMCOBundle
from force_bdss.mco.base_mco_factory import BaseMCOFactory
try:
import mock
......@@ -20,18 +20,18 @@ from force_bdss.core_plugins.dummy.dummy_dakota.dakota_optimizer import \
class TestDakotaOptimizer(unittest.TestCase):
def setUp(self):
self.bundle = mock.Mock(spec=BaseMCOBundle)
self.bundle.plugin = mock.Mock()
self.bundle.plugin.application = mock.Mock()
self.bundle.plugin.application.workflow_filepath = "whatever"
self.factory = mock.Mock(spec=BaseMCOFactory)
self.factory.plugin = mock.Mock()
self.factory.plugin.application = mock.Mock()
self.factory.plugin.application.workflow_filepath = "whatever"
def test_initialization(self):
opt = DummyDakotaOptimizer(self.bundle)
self.assertEqual(opt.bundle, self.bundle)
opt = DummyDakotaOptimizer(self.factory)
self.assertEqual(opt.factory, self.factory)
def test_run(self):
opt = DummyDakotaOptimizer(self.bundle)
model = DummyDakotaModel(self.bundle)
opt = DummyDakotaOptimizer(self.factory)
model = DummyDakotaModel(self.factory)
model.parameters = [
RangedMCOParameter(
mock.Mock(spec=RangedMCOParameterFactory),
......
from force_bdss.api import BaseDataSourceBundle, bundle_id
from force_bdss.api import BaseDataSourceFactory, factory_id
from .dummy_data_source_model import DummyDataSourceModel
from .dummy_data_source import DummyDataSource
class DummyDataSourceBundle(BaseDataSourceBundle):
id = bundle_id("enthought", "dummy_data_source")
class DummyDataSourceFactory(BaseDataSourceFactory):
id = factory_id("enthought", "dummy_data_source")
def create_model(self, model_data=None):
if model_data is None:
......
......@@ -5,8 +5,8 @@ from force_bdss.core_plugins.dummy.dummy_data_source.dummy_data_source import \
from force_bdss.core_plugins.dummy.dummy_data_source.dummy_data_source_model\
import \
DummyDataSourceModel
from force_bdss.data_sources.base_data_source_bundle import \
BaseDataSourceBundle
from force_bdss.data_sources.base_data_source_factory import \
BaseDataSourceFactory
try:
import mock
......@@ -16,14 +16,14 @@ except ImportError:
class TestDummyDataSource(unittest.TestCase):
def setUp(self):
self.bundle = mock.Mock(spec=BaseDataSourceBundle)
self.factory = mock.Mock(spec=BaseDataSourceFactory)
def test_initialization(self):
ds = DummyDataSource(self.bundle)
self.assertEqual(ds.bundle, self.bundle)
ds = DummyDataSource(self.factory)
self.assertEqual(ds.factory, self.factory)
def test_slots(self):
ds = DummyDataSource(self.bundle)
model = DummyDataSourceModel(self.bundle)
ds = DummyDataSource(self.factory)
model = DummyDataSourceModel(self.factory)
slots = ds.slots(model)
self.assertEqual(slots, ((), ()))
......@@ -3,17 +3,18 @@ import unittest
from force_bdss.core_plugins.dummy.dummy_data_source.dummy_data_source import \
DummyDataSource
from force_bdss.core_plugins.dummy.dummy_data_source\
.dummy_data_source_bundle import DummyDataSourceBundle
.dummy_data_source_factory import DummyDataSourceFactory
from force_bdss.core_plugins.dummy.dummy_data_source.dummy_data_source_model\
import DummyDataSourceModel
from force_bdss.core_plugins.dummy.tests.data_source_bundle_test_mixin import \
DataSourceBundleTestMixin
from force_bdss.core_plugins.dummy.tests.data_source_factory_test_mixin \
import DataSourceFactoryTestMixin
class TestDummyDataSourceBundle(DataSourceBundleTestMixin, unittest.TestCase):
class TestDummyDataSourceFactory(
DataSourceFactoryTestMixin, unittest.TestCase):
@property
def bundle_class(self):
return DummyDataSourceBundle
def factory_class(self):
return DummyDataSourceFactory
@property
def model_class(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment