Skip to content
Snippets Groups Projects
Commit 592e7195 authored by Stefano Borini's avatar Stefano Borini
Browse files

Fixed tests

parent 1ee0cd0e
No related branches found
No related tags found
1 merge request!60Moved parameters into MCO
......@@ -201,7 +201,7 @@ class WorkflowReader(HasStrictTraits):
for p in parameters_data:
id = p["id"]
factory = registry.get_factory(mco_id, id)
factory = registry.get_factory_by_id(mco_id, id)
model = factory.create_model(p["model_data"])
parameters.append(model)
......
from traits.api import HasStrictTraits, Dict, String
from traits.api import HasStrictTraits, Dict, Tuple, String
from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory
......@@ -8,7 +8,7 @@ class MCOParameterFactoryRegistry(HasStrictTraits):
"""Registry to keep the parameter factories and lookup them.
"""
# Temp: this will become an extension point.
factories = Dict(String, BaseMCOParameterFactory)
factories = Dict(Tuple(String, String), BaseMCOParameterFactory)
def get_factory_by_id(self, bundle_id, parameter_factory_id):
"""Finds the factory by its id, so that we can obtain it as from
......
import unittest
from force_bdss.mco.base_mco_bundle import BaseMCOBundle
try:
import mock
except ImportError:
from unittest import mock
from traits.api import Int
......@@ -20,7 +27,7 @@ class DummyMCOParameterFactory(BaseMCOParameterFactory):
class TestBaseMCOParameterFactory(unittest.TestCase):
def test_initialization(self):
factory = DummyMCOParameterFactory()
factory = DummyMCOParameterFactory(mock.Mock(spec=BaseMCOBundle))
model = factory.create_model({"x": 42})
self.assertIsInstance(model, DummyMCOParameter)
self.assertEqual(model.x, 42)
import unittest
from force_bdss.core_plugins.dummy.dummy_dakota import parameters
from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory
class TestCoreMCOParameters(unittest.TestCase):
def test_all_factories(self):
factories = parameters.all_core_factories()
self.assertEqual(len(factories), 1)
for f in factories:
self.assertIsInstance(f, BaseMCOParameterFactory)
......@@ -24,6 +24,9 @@ class DummyMCOBundle(BaseMCOBundle):
def create_communicator(self):
pass
def parameter_factories(self):
return []
class TestBaseMCOBundle(unittest.TestCase):
def test_initialization(self):
......
import unittest
from traits.api import Float
from force_bdss.bundle_registry_plugin import BundleRegistryPlugin
from force_bdss.data_sources.base_data_source import BaseDataSource
from force_bdss.data_sources.base_data_source_bundle import \
BaseDataSourceBundle
from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel
from force_bdss.ids import mco_parameter_id, bundle_id
from force_bdss.kpi.base_kpi_calculator import BaseKPICalculator
from force_bdss.kpi.base_kpi_calculator_bundle import BaseKPICalculatorBundle
from force_bdss.kpi.base_kpi_calculator_model import BaseKPICalculatorModel
......@@ -12,6 +13,9 @@ from force_bdss.mco.base_mco import BaseMCO
from force_bdss.mco.base_mco_bundle import BaseMCOBundle
from force_bdss.mco.base_mco_communicator import BaseMCOCommunicator
from force_bdss.mco.base_mco_model import BaseMCOModel
from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter
from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory
from force_bdss.tests import fixtures
try:
......@@ -33,6 +37,17 @@ class NullMCO(BaseMCO):
pass
class NullParameter(BaseMCOParameter):
initial_value = Float()
lower_bound = Float()
upper_bound = Float()
class NullParameterFactory(BaseMCOParameterFactory):
id = mco_parameter_id("enthought", "ranged")
model_class = NullParameter
class NullMCOCommunicator(BaseMCOCommunicator):
def send_to_mco(self, model, kpi_results):
pass
......@@ -42,6 +57,8 @@ class NullMCOCommunicator(BaseMCOCommunicator):
class NullMCOBundle(BaseMCOBundle):
id = bundle_id("enthought", "dummy_dakota")
def create_model(self, model_data=None):
return NullMCOModel(self)
......@@ -51,6 +68,9 @@ class NullMCOBundle(BaseMCOBundle):
def create_optimizer(self):
return NullMCO(self)
def parameter_factories(self):
return [NullParameterFactory(self)]
class NullKPICalculatorModel(BaseKPICalculatorModel):
pass
......@@ -88,6 +108,8 @@ class NullDataSourceBundle(BaseDataSourceBundle):
def mock_bundle_registry_plugin():
bundle_registry_plugin = mock.Mock(spec=BundleRegistryPlugin)
bundle_registry_plugin.mco_bundles = [
NullMCOBundle(bundle_registry_plugin)]
bundle_registry_plugin.mco_bundle_by_id = mock.Mock(
return_value=NullMCOBundle(bundle_registry_plugin))
bundle_registry_plugin.kpi_calculator_bundle_by_id = mock.Mock(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment