diff --git a/doc/source/api/force_bdss.core_plugins.dummy.dummy_dakota.rst b/doc/source/api/force_bdss.core_plugins.dummy.dummy_dakota.rst index 61c462f5e1b6dcf940904290845e913f67c73fbf..0d90d4a58cfd87477922ea9af947b4c9fca5c450 100644 --- a/doc/source/api/force_bdss.core_plugins.dummy.dummy_dakota.rst +++ b/doc/source/api/force_bdss.core_plugins.dummy.dummy_dakota.rst @@ -43,6 +43,14 @@ force_bdss.core_plugins.dummy.dummy_dakota.dakota_optimizer module :undoc-members: :show-inheritance: +force_bdss.core_plugins.dummy.dummy_dakota.parameters module +------------------------------------------------------------ + +.. automodule:: force_bdss.core_plugins.dummy.dummy_dakota.parameters + :members: + :undoc-members: + :show-inheritance: + Module contents --------------- diff --git a/doc/source/api/force_bdss.mco.parameters.rst b/doc/source/api/force_bdss.mco.parameters.rst index f289f438bb64123f0dff84ac4d30be8e342d72db..bf6d87c96ed05a7800adc6f8f1040450d9d9c448 100644 --- a/doc/source/api/force_bdss.mco.parameters.rst +++ b/doc/source/api/force_bdss.mco.parameters.rst @@ -27,22 +27,6 @@ force_bdss.mco.parameters.base_mco_parameter_factory module :undoc-members: :show-inheritance: -force_bdss.mco.parameters.core_mco_parameters module ----------------------------------------------------- - -.. automodule:: force_bdss.mco.parameters.core_mco_parameters - :members: - :undoc-members: - :show-inheritance: - -force_bdss.mco.parameters.mco_parameter_factory_registry module ---------------------------------------------------------------- - -.. automodule:: force_bdss.mco.parameters.mco_parameter_factory_registry - :members: - :undoc-members: - :show-inheritance: - Module contents --------------- diff --git a/doc/source/api/force_bdss.mco.parameters.tests.rst b/doc/source/api/force_bdss.mco.parameters.tests.rst index 694cb158a833bfa1c3860bf12166bd71ab93788a..def0b6195bb19b101d98a0870341e9778f292557 100644 --- a/doc/source/api/force_bdss.mco.parameters.tests.rst +++ b/doc/source/api/force_bdss.mco.parameters.tests.rst @@ -20,22 +20,6 @@ force_bdss.mco.parameters.tests.test_base_mco_parameter_factory module :undoc-members: :show-inheritance: -force_bdss.mco.parameters.tests.test_core_mco_parameters module ---------------------------------------------------------------- - -.. automodule:: force_bdss.mco.parameters.tests.test_core_mco_parameters - :members: - :undoc-members: - :show-inheritance: - -force_bdss.mco.parameters.tests.test_parameter_factory_registry module ----------------------------------------------------------------------- - -.. automodule:: force_bdss.mco.parameters.tests.test_parameter_factory_registry - :members: - :undoc-members: - :show-inheritance: - Module contents --------------- diff --git a/force_bdss/base_core_driver.py b/force_bdss/base_core_driver.py index 6b43b4cec5a09370f51edab719cf9e0cd5254f0d..45b1f2265f1811adf1c1351f79fc0069c6430db4 100644 --- a/force_bdss/base_core_driver.py +++ b/force_bdss/base_core_driver.py @@ -7,9 +7,6 @@ from .bundle_registry_plugin import ( ) from .io.workflow_reader import WorkflowReader from .workspecs.workflow import Workflow -from .mco.parameters.mco_parameter_factory_registry import ( - MCOParameterFactoryRegistry) -from .mco.parameters.core_mco_parameters import all_core_factories class BaseCoreDriver(Plugin): @@ -20,24 +17,13 @@ class BaseCoreDriver(Plugin): #: The registry of the bundles. bundle_registry = Instance(BundleRegistryPlugin) - #: The registry of the MCO parameters - parameter_factory_registry = Instance(MCOParameterFactoryRegistry) - #: Deserialized content of the workflow file. workflow = Instance(Workflow) def _bundle_registry_default(self): return self.application.get_plugin(BUNDLE_REGISTRY_PLUGIN_ID) - def _parameter_factory_registry_default(self): - registry = MCOParameterFactoryRegistry() - for f in all_core_factories(): - registry.register(f) - - return registry - def _workflow_default(self): - reader = WorkflowReader(self.bundle_registry, - self.parameter_factory_registry) + reader = WorkflowReader(self.bundle_registry) with open(self.application.workflow_filepath) as f: return reader.read(f) diff --git a/force_bdss/bundle_registry_plugin.py b/force_bdss/bundle_registry_plugin.py index 24bc5ddcb442de1bbfcbbe9b191beaf1c3d01853..44e7d4808f6e58429dc1efa566b90d3fd0afa90e 100644 --- a/force_bdss/bundle_registry_plugin.py +++ b/force_bdss/bundle_registry_plugin.py @@ -55,13 +55,13 @@ class BundleRegistryPlugin(Plugin): Raises ------ - ValueError: if the entry is not found. + KeyError: if the entry is not found. """ for ds in self.data_source_bundles: if ds.id == id: return ds - raise ValueError( + raise KeyError( "Requested data source {} but don't know how " "to find it.".format(id)) @@ -77,13 +77,13 @@ class BundleRegistryPlugin(Plugin): Raises ------ - ValueError: if the entry is not found. + KeyError: if the entry is not found. """ for kpic in self.kpi_calculator_bundles: if kpic.id == id: return kpic - raise ValueError( + raise KeyError( "Requested kpi calculator {} but don't know how " "to find it.".format(id)) @@ -99,11 +99,40 @@ class BundleRegistryPlugin(Plugin): Raises ------ - ValueError: if the entry is not found. + KeyError: if the entry is not found. """ for mco in self.mco_bundles: if mco.id == id: return mco - raise ValueError("Requested MCO {} but don't know how " - "to find it.".format(id)) + raise KeyError("Requested MCO {} but don't know how " + "to find it.".format(id)) + + def mco_parameter_factory_by_id(self, mco_id, parameter_id): + """Retrieves the MCO parameter factory for a given MCO id and + parameter id. + + Parameters + ---------- + mco_id: str + The MCO identifier string + parameter_id: str + the parameter identifier string + + Returns + ------- + An instance of BaseMCOParameterFactory. + + Raises + ------ + KeyError: + if the entry is not found + """ + mco_bundle = self.mco_bundle_by_id(mco_id) + + for factory in mco_bundle.parameter_factories(): + if factory.id == parameter_id: + return factory + + raise KeyError("Requested MCO parameter {}:{} but don't know" + " how to find it.".format(mco_id, parameter_id)) diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py b/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py index 925a127fa43b9c43f713373268a5e7831cf7abe6..189c27c1d2ecfc350995bd556483bbf5cb7269a2 100644 --- a/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py +++ b/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py @@ -1,5 +1,7 @@ from traits.api import String from force_bdss.api import bundle_id, BaseMCOBundle +from force_bdss.core_plugins.dummy.dummy_dakota.parameters import \ + RangedMCOParameterFactory from .dakota_communicator import DummyDakotaCommunicator from .dakota_model import DummyDakotaModel @@ -21,3 +23,8 @@ class DummyDakotaBundle(BaseMCOBundle): def create_communicator(self): return DummyDakotaCommunicator(self) + + def parameter_factories(self): + return [ + RangedMCOParameterFactory(self) + ] diff --git a/force_bdss/mco/parameters/core_mco_parameters.py b/force_bdss/core_plugins/dummy/dummy_dakota/parameters.py similarity index 54% rename from force_bdss/mco/parameters/core_mco_parameters.py rename to force_bdss/core_plugins/dummy/dummy_dakota/parameters.py index b9efa4e4f0b3921507420d2f64e16eddbce8da65..d15c6c6f26390b9ad1a4c7941f1e52a5ac7ba8df 100644 --- a/force_bdss/mco/parameters/core_mco_parameters.py +++ b/force_bdss/core_plugins/dummy/dummy_dakota/parameters.py @@ -1,7 +1,7 @@ from traits.api import Float -from ...ids import mco_parameter_id -from .base_mco_parameter import BaseMCOParameter +from force_bdss.ids import mco_parameter_id +from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter from force_bdss.mco.parameters.base_mco_parameter_factory import \ BaseMCOParameterFactory @@ -16,18 +16,7 @@ class RangedMCOParameter(BaseMCOParameter): class RangedMCOParameterFactory(BaseMCOParameterFactory): """The factory of the above model""" - id = mco_parameter_id("enthought", "ranged") + id = mco_parameter_id("enthought", "dummy_dakota", "ranged") model_class = RangedMCOParameter name = "Range" description = "A ranged parameter in floating point values." - - -def all_core_factories(): - """Produces a list of all factories contained in this module.""" - import inspect - - return [c() - for c in inspect.getmodule(all_core_factories).__dict__.values() - if inspect.isclass(c) and - c is not BaseMCOParameterFactory and - issubclass(c, BaseMCOParameterFactory)] diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_bundle.py b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_bundle.py index b455c6d08d6e21659b0cb2346a41823b505f8c74..673fdd05f3fbf5da8fb436484d334cdac483e497 100644 --- a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_bundle.py +++ b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_bundle.py @@ -41,3 +41,7 @@ class TestDakotaBundle(unittest.TestCase): bundle = DummyDakotaBundle(self.plugin) ds = bundle.create_optimizer() self.assertIsInstance(ds, DummyDakotaOptimizer) + + def test_parameter_factories(self): + bundle = DummyDakotaBundle(self.plugin) + self.assertNotEqual(len(bundle.parameter_factories()), 0) diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py index e7bc0ebfaa80052e1043f5f6f9185859aff6856d..6e74d744f66e52cbb510bacde7a3879f0d200818 100644 --- a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py +++ b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py @@ -13,7 +13,8 @@ from force_bdss.data_sources.data_source_parameters import DataSourceParameters from force_bdss.mco.parameters.base_mco_parameter_factory import \ BaseMCOParameterFactory -from force_bdss.mco.parameters.core_mco_parameters import RangedMCOParameter +from force_bdss.core_plugins.dummy.dummy_dakota.parameters import \ + RangedMCOParameter class TestDakotaCommunicator(unittest.TestCase): diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py index b0b9f85ecb5909a4cc72cbc0e2ece04cc8ccc54e..b8be17e7f34a1444162c0e2d873f30dcb34ad321 100644 --- a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py +++ b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py @@ -1,10 +1,13 @@ import unittest -from force_bdss.core_plugins.dummy.dummy_dakota.dakota_model import \ +from force_bdss.core_plugins.dummy.dummy_dakota.parameters import ( + RangedMCOParameter, + RangedMCOParameterFactory +) +from force_bdss.core_plugins.dummy.dummy_dakota.dakota_model import ( DummyDakotaModel +) from force_bdss.mco.base_mco_bundle import BaseMCOBundle -from force_bdss.mco.parameters.core_mco_parameters import RangedMCOParameter, \ - RangedMCOParameterFactory try: import mock diff --git a/force_bdss/ids.py b/force_bdss/ids.py index 2af5c0b5541a3132eb9010067cc9194e4b5fd165..4e5c4144566323d7e93e9f5b65592358bbf0ab16 100644 --- a/force_bdss/ids.py +++ b/force_bdss/ids.py @@ -29,23 +29,27 @@ def bundle_id(producer, identifier): ------- str: an identifier to be used in the bundle. """ - return _string_id("bundle", producer, identifier) + return _string_id(producer, "bundle", identifier) -def mco_parameter_id(producer, identifier): +def mco_parameter_id(producer, mco_identifier, parameter_identifier): """Creates an ID for an MCO parameter, so that it can be identified uniquely.""" - return _string_id("mco_parameter", producer, identifier) + return _string_id(producer, + "bundle", + mco_identifier, + "parameter", + parameter_identifier) def plugin_id(producer, identifier): """Creates an ID for the plugins. These must be defined, otherwise the envisage system will complain (but not break) """ - return _string_id("plugin", producer, identifier) + return _string_id(producer, "plugin", identifier) -def _string_id(entity_namespace, producer, identifier): +def _string_id(*args): """Creates an id for a generic entity. Parameters @@ -68,7 +72,8 @@ def _string_id(entity_namespace, producer, identifier): " " not in entry and len(entry) != 0) - if not all(map(is_valid, [entity_namespace, producer, identifier])): - raise ValueError("Invalid parameters specified.") + if not all(map(is_valid, args)): + raise ValueError("One or more of the specified parameters was " + "invalid: {}".format(str(args))) - return "force.bdss.{}.{}.{}".format(entity_namespace, producer, identifier) + return ".".join(["force", "bdss"]+list(args)) diff --git a/force_bdss/io/tests/test_workflow_reader.py b/force_bdss/io/tests/test_workflow_reader.py index 58518c5b4212655e406cc5315af38effc44ece4b..9119702953768f065117f69df99ac039981f03e4 100644 --- a/force_bdss/io/tests/test_workflow_reader.py +++ b/force_bdss/io/tests/test_workflow_reader.py @@ -6,8 +6,6 @@ from force_bdss.bundle_registry_plugin import BundleRegistryPlugin from force_bdss.io.workflow_reader import ( WorkflowReader, InvalidVersionException, InvalidFileException) -from force_bdss.mco.parameters.mco_parameter_factory_registry import \ - MCOParameterFactoryRegistry try: import mock @@ -18,12 +16,8 @@ except ImportError: class TestWorkflowReader(unittest.TestCase): def setUp(self): self.mock_bundle_registry = mock.Mock(spec=BundleRegistryPlugin) - self.mock_mco_parameter_registry = mock.Mock( - spec=MCOParameterFactoryRegistry) - self.wfreader = WorkflowReader( - self.mock_bundle_registry, - self.mock_mco_parameter_registry) + self.wfreader = WorkflowReader(self.mock_bundle_registry) def test_initialization(self): self.assertEqual(self.wfreader.bundle_registry, diff --git a/force_bdss/io/tests/test_workflow_writer.py b/force_bdss/io/tests/test_workflow_writer.py index af32cd541ac54d833be1fa0486badf1cf4151357..df70880bf6cd64b44e8c274319eb1459ecc32b54 100644 --- a/force_bdss/io/tests/test_workflow_writer.py +++ b/force_bdss/io/tests/test_workflow_writer.py @@ -7,8 +7,6 @@ from force_bdss.io.workflow_reader import WorkflowReader from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter from force_bdss.mco.parameters.base_mco_parameter_factory import \ BaseMCOParameterFactory -from force_bdss.mco.parameters.mco_parameter_factory_registry import \ - MCOParameterFactoryRegistry try: import mock @@ -37,9 +35,6 @@ class TestWorkflowWriter(unittest.TestCase): self.mock_registry.mco_bundle_by_id = mock.Mock( return_value=mock_mco_bundle) - self.mock_mco_parameter_registry = mock.Mock( - spec=MCOParameterFactoryRegistry) - def test_write(self): wfwriter = WorkflowWriter() fp = StringIO() @@ -58,8 +53,7 @@ class TestWorkflowWriter(unittest.TestCase): wf = self._create_mock_workflow() wfwriter.write(wf, fp) fp.seek(0) - wfreader = WorkflowReader(self.mock_registry, - self.mock_mco_parameter_registry) + wfreader = WorkflowReader(self.mock_registry) wf_result = wfreader.read(fp) self.assertEqual(wf_result.mco.bundle.id, wf.mco.bundle.id) @@ -74,7 +68,7 @@ class TestWorkflowWriter(unittest.TestCase): BaseMCOParameter( factory=mock.Mock( spec=BaseMCOParameterFactory, - id=mco_parameter_id("enthought", "mock") + id=mco_parameter_id("enthought", "mock", "mock") ) ) ] diff --git a/force_bdss/io/workflow_reader.py b/force_bdss/io/workflow_reader.py index e12e13f207c836b0ff07fc0067f8f43d97779e12..44677cf92b767fab18496bae2b972061ed3dd8fd 100644 --- a/force_bdss/io/workflow_reader.py +++ b/force_bdss/io/workflow_reader.py @@ -3,8 +3,6 @@ import logging from traits.api import HasStrictTraits, Instance -from ..mco.parameters.mco_parameter_factory_registry import ( - MCOParameterFactoryRegistry) from ..bundle_registry_plugin import BundleRegistryPlugin from ..workspecs.workflow import Workflow @@ -28,13 +26,8 @@ class WorkflowReader(HasStrictTraits): #: bundle-specific model objects. bundle_registry = Instance(BundleRegistryPlugin) - #: The registry for the MCO parameters. At the moment this - #: is not extensible via plugins as the one above. - mco_parameter_registry = Instance(MCOParameterFactoryRegistry) - def __init__(self, bundle_registry, - mco_parameter_registry, *args, **kwargs): """Initializes the reader. @@ -46,7 +39,6 @@ class WorkflowReader(HasStrictTraits): for a bundle identified by a given id. """ self.bundle_registry = bundle_registry - self.mco_parameter_registry = mco_parameter_registry super(WorkflowReader, self).__init__(*args, **kwargs) @@ -128,6 +120,7 @@ class WorkflowReader(HasStrictTraits): mco_bundle = registry.mco_bundle_by_id(mco_id) model_data = wf_data["mco"]["model_data"] model_data["parameters"] = self._extract_mco_parameters( + mco_id, model_data["parameters"]) model = mco_bundle.create_model( wf_data["mco"]["model_data"]) @@ -182,7 +175,7 @@ class WorkflowReader(HasStrictTraits): return kpi_calculators - def _extract_mco_parameters(self, parameters_data): + def _extract_mco_parameters(self, mco_id, parameters_data): """Extracts the MCO parameters from the data as dictionary. Parameters @@ -194,13 +187,13 @@ class WorkflowReader(HasStrictTraits): ------- List of instances of a subclass of BaseMCOParameter """ - registry = self.mco_parameter_registry + registry = self.bundle_registry parameters = [] for p in parameters_data: id = p["id"] - factory = registry.get_factory_by_id(id) + factory = registry.mco_parameter_factory_by_id(mco_id, id) model = factory.create_model(p["model_data"]) parameters.append(model) diff --git a/force_bdss/mco/base_mco_bundle.py b/force_bdss/mco/base_mco_bundle.py index 5fe9fdb648e9c738d92273f10d3f6fb0a08652ca..c8e0d6d730918507ac25c7faeedf37b17ce7df03 100644 --- a/force_bdss/mco/base_mco_bundle.py +++ b/force_bdss/mco/base_mco_bundle.py @@ -68,3 +68,12 @@ class BaseMCOBundle(ABCHasStrictTraits): BaseMCOCommunicator An instance of the communicator """ + + @abc.abstractmethod + def parameter_factories(self): + """Returns the parameter factories supported by this MCO + + Returns + ------- + List of BaseMCOParameterFactory + """ diff --git a/force_bdss/mco/parameters/base_mco_parameter_factory.py b/force_bdss/mco/parameters/base_mco_parameter_factory.py index 61241739e6564816e3eab15ddbe6e5726463267b..efaa70c9ef13884a486e29b9683139f7664566cc 100644 --- a/force_bdss/mco/parameters/base_mco_parameter_factory.py +++ b/force_bdss/mco/parameters/base_mco_parameter_factory.py @@ -1,12 +1,19 @@ -from traits.has_traits import HasStrictTraits -from traits.trait_types import String, Type +from traits.api import HasStrictTraits, String, Type, Instance + +from ..base_mco_bundle import BaseMCOBundle class BaseMCOParameterFactory(HasStrictTraits): """Factory that produces the model instance of a given BASEMCOParameter instance. - Must be reimplemented for the specific parameter.""" + Must be reimplemented for the specific parameter. The generic create_model + is generally enough, and the only entity to define is model_class with + the appropriate class of the parameter. + """ + + #: A reference to the bundle this parameter factory lives in. + bundle = Instance(BaseMCOBundle) #: A unique string identifying the parameter id = String() @@ -20,8 +27,24 @@ class BaseMCOParameterFactory(HasStrictTraits): # The model class to instantiate when create_model is called. model_class = Type('BaseMCOParameter') + def __init__(self, bundle): + self.bundle = bundle + super(BaseMCOParameterFactory, self).__init__() + def create_model(self, data_values=None): """Creates the instance of the model class and returns it. + You should not reimplement this, as the default is generally ok. + Instead, just define model_class with the appropriate Parameter class. + + Parameters + ---------- + data_values: dict or None + The dictionary of values for this parameter. If None, a default + object will be returned. + + Returns + ------- + instance of model_class. """ if data_values is None: data_values = {} diff --git a/force_bdss/mco/parameters/mco_parameter_factory_registry.py b/force_bdss/mco/parameters/mco_parameter_factory_registry.py deleted file mode 100644 index 709f7f9ecb0a39d787bf2f40d794ae373ff6728d..0000000000000000000000000000000000000000 --- a/force_bdss/mco/parameters/mco_parameter_factory_registry.py +++ /dev/null @@ -1,21 +0,0 @@ -from traits.api import HasStrictTraits, Dict, String - -from force_bdss.mco.parameters.base_mco_parameter_factory import \ - BaseMCOParameterFactory - - -class MCOParameterFactoryRegistry(HasStrictTraits): - """Registry to keep the parameter factories and lookup them. - """ - # Temp: this will become an extension point. - factories = Dict(String, BaseMCOParameterFactory) - - def get_factory_by_id(self, id): - """Finds the factory by its id, so that we can obtain it as from - the id in the model file. - """ - return self.factories[id] - - def register(self, factory): - """Registers a new factory""" - self.factories[factory.id] = factory diff --git a/force_bdss/mco/parameters/tests/test_base_mco_parameter_factory.py b/force_bdss/mco/parameters/tests/test_base_mco_parameter_factory.py index 3203aa718a8321e29b8d8a053d96f726886a1547..45d586defe8500c5c8b09101e496de8084f1d09b 100644 --- a/force_bdss/mco/parameters/tests/test_base_mco_parameter_factory.py +++ b/force_bdss/mco/parameters/tests/test_base_mco_parameter_factory.py @@ -1,4 +1,11 @@ import unittest + +from force_bdss.mco.base_mco_bundle import BaseMCOBundle + +try: + import mock +except ImportError: + from unittest import mock from traits.api import Int @@ -20,7 +27,11 @@ class DummyMCOParameterFactory(BaseMCOParameterFactory): class TestBaseMCOParameterFactory(unittest.TestCase): def test_initialization(self): - factory = DummyMCOParameterFactory() + factory = DummyMCOParameterFactory(mock.Mock(spec=BaseMCOBundle)) model = factory.create_model({"x": 42}) self.assertIsInstance(model, DummyMCOParameter) self.assertEqual(model.x, 42) + + model = factory.create_model() + self.assertIsInstance(model, DummyMCOParameter) + self.assertEqual(model.x, 0) diff --git a/force_bdss/mco/parameters/tests/test_core_mco_parameters.py b/force_bdss/mco/parameters/tests/test_core_mco_parameters.py deleted file mode 100644 index af6f4c96ab0b6536e44fd2d5465a7dd3934eb396..0000000000000000000000000000000000000000 --- a/force_bdss/mco/parameters/tests/test_core_mco_parameters.py +++ /dev/null @@ -1,14 +0,0 @@ -import unittest - -from force_bdss.mco.parameters import core_mco_parameters -from force_bdss.mco.parameters.base_mco_parameter_factory import \ - BaseMCOParameterFactory - - -class TestCoreMCOParameters(unittest.TestCase): - def test_all_factories(self): - factories = core_mco_parameters.all_core_factories() - self.assertEqual(len(factories), 1) - - for f in factories: - self.assertIsInstance(f, BaseMCOParameterFactory) diff --git a/force_bdss/mco/parameters/tests/test_parameter_factory_registry.py b/force_bdss/mco/parameters/tests/test_parameter_factory_registry.py deleted file mode 100644 index 2adbfc567ba55eb9fcfa30cc63882a1d076d2223..0000000000000000000000000000000000000000 --- a/force_bdss/mco/parameters/tests/test_parameter_factory_registry.py +++ /dev/null @@ -1,10 +0,0 @@ -import unittest - -from force_bdss.mco.parameters.mco_parameter_factory_registry import \ - MCOParameterFactoryRegistry - - -class TestParameterFactoryRegistry(unittest.TestCase): - def test_registry_init(self): - reg = MCOParameterFactoryRegistry() - self.assertEqual(reg.factories, {}) diff --git a/force_bdss/mco/tests/test_base_mco_bundle.py b/force_bdss/mco/tests/test_base_mco_bundle.py index fc173521607f0bc18ded75949a0dce8a48472616..d5dfddce43b13b444e85ba9ba790654ab781b5a6 100644 --- a/force_bdss/mco/tests/test_base_mco_bundle.py +++ b/force_bdss/mco/tests/test_base_mco_bundle.py @@ -24,6 +24,9 @@ class DummyMCOBundle(BaseMCOBundle): def create_communicator(self): pass + def parameter_factories(self): + return [] + class TestBaseMCOBundle(unittest.TestCase): def test_initialization(self): diff --git a/force_bdss/tests/fixtures/test_csv.json b/force_bdss/tests/fixtures/test_csv.json index 669ea67c685db36d30d7bbfab362340ba66cebfc..726de30843d158ec3b1aeb4afcc96e92b9ecaa5d 100644 --- a/force_bdss/tests/fixtures/test_csv.json +++ b/force_bdss/tests/fixtures/test_csv.json @@ -2,11 +2,11 @@ "version": "1", "workflow": { "mco": { - "id": "force.bdss.bundle.enthought.dummy_dakota", + "id": "force.bdss.enthought.bundle.dummy_dakota", "model_data": { "parameters" : [ { - "id": "force.bdss.mco_parameter.enthought.ranged", + "id": "force.bdss.enthought.bundle.dummy_dakota.parameter.ranged", "model_data": { "initial_value": 3, "lower_bound": 0, @@ -18,7 +18,7 @@ }, "data_sources": [ { - "id": "force.bdss.bundle.enthought.csv_extractor", + "id": "force.bdss.enthought.bundle.csv_extractor", "model_data": { "filename": "foo.csv", "row": 3, @@ -27,7 +27,7 @@ } }, { - "id": "force.bdss.bundle.enthought.csv_extractor", + "id": "force.bdss.enthought.bundle.csv_extractor", "model_data": { "filename": "foo.csv", "row": 3, @@ -38,7 +38,7 @@ ], "kpi_calculators": [ { - "id": "force.bdss.bundle.enthought.kpi_adder", + "id": "force.bdss.enthought.bundle.kpi_adder", "model_data": { "cuba_type_in": "PRESSURE", "cuba_type_out": "TOTAL_PRESSURE" diff --git a/force_bdss/tests/test_bundle_registry_plugin.py b/force_bdss/tests/test_bundle_registry_plugin.py index b1ebd018d37a02a059caa713a6bfa2d82398955c..3050087502e47b1424db5ab2ba19788c57c9f176 100644 --- a/force_bdss/tests/test_bundle_registry_plugin.py +++ b/force_bdss/tests/test_bundle_registry_plugin.py @@ -2,7 +2,9 @@ import unittest from force_bdss.base_extension_plugin import ( BaseExtensionPlugin) -from force_bdss.ids import bundle_id +from force_bdss.ids import bundle_id, mco_parameter_id +from force_bdss.mco.parameters.base_mco_parameter_factory import \ + BaseMCOParameterFactory try: import mock @@ -33,8 +35,17 @@ class TestBundleRegistry(unittest.TestCase): class MySuperPlugin(BaseExtensionPlugin): def _mco_bundles_default(self): - return [mock.Mock(spec=IMCOBundle, - id=bundle_id("enthought", "mco1"))] + return [ + mock.Mock( + spec=IMCOBundle, + id=bundle_id("enthought", "mco1"), + parameter_factories=mock.Mock(return_value=[ + mock.Mock( + spec=BaseMCOParameterFactory, + id=mco_parameter_id("enthought", "mco1", "ranged") + ) + ]), + )] def _data_source_bundles_default(self): return [mock.Mock(spec=IDataSourceBundle, @@ -64,8 +75,10 @@ class TestBundleRegistryWithContent(unittest.TestCase): self.assertEqual(len(self.plugin.kpi_calculator_bundles), 3) def test_lookup(self): - id = bundle_id("enthought", "mco1") - self.assertEqual(self.plugin.mco_bundle_by_id(id).id, id) + mco_id = bundle_id("enthought", "mco1") + parameter_id = mco_parameter_id("enthought", "mco1", "ranged") + self.assertEqual(self.plugin.mco_bundle_by_id(mco_id).id, mco_id) + self.plugin.mco_parameter_factory_by_id(mco_id, parameter_id) for entry in ["ds1", "ds2"]: id = bundle_id("enthought", entry) @@ -76,6 +89,31 @@ class TestBundleRegistryWithContent(unittest.TestCase): self.assertEqual(self.plugin.kpi_calculator_bundle_by_id(id).id, id) + with self.assertRaises(KeyError): + self.plugin.mco_bundle_by_id( + bundle_id("enthought", "foo")) + + with self.assertRaises(KeyError): + self.plugin.mco_parameter_factory_by_id( + mco_id, + mco_parameter_id("enthought", "mco1", "foo") + ) + + with self.assertRaises(KeyError): + self.plugin.data_source_bundle_by_id( + bundle_id("enthought", "foo") + ) + + with self.assertRaises(KeyError): + self.plugin.data_source_bundle_by_id( + bundle_id("enthought", "foo") + ) + + with self.assertRaises(KeyError): + self.plugin.kpi_calculator_bundle_by_id( + bundle_id("enthought", "foo") + ) + if __name__ == '__main__': unittest.main() diff --git a/force_bdss/tests/test_core_evaluation_driver.py b/force_bdss/tests/test_core_evaluation_driver.py index 9ececcf5c23c404b5fc7c75fe058fac05244de37..5be1849eba262b3460f68ad49f8e284d9bd61e77 100644 --- a/force_bdss/tests/test_core_evaluation_driver.py +++ b/force_bdss/tests/test_core_evaluation_driver.py @@ -1,10 +1,11 @@ import unittest - +from traits.api import Float from force_bdss.bundle_registry_plugin import BundleRegistryPlugin from force_bdss.data_sources.base_data_source import BaseDataSource from force_bdss.data_sources.base_data_source_bundle import \ BaseDataSourceBundle from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel +from force_bdss.ids import mco_parameter_id, bundle_id from force_bdss.kpi.base_kpi_calculator import BaseKPICalculator from force_bdss.kpi.base_kpi_calculator_bundle import BaseKPICalculatorBundle from force_bdss.kpi.base_kpi_calculator_model import BaseKPICalculatorModel @@ -12,6 +13,9 @@ from force_bdss.mco.base_mco import BaseMCO from force_bdss.mco.base_mco_bundle import BaseMCOBundle from force_bdss.mco.base_mco_communicator import BaseMCOCommunicator from force_bdss.mco.base_mco_model import BaseMCOModel +from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter +from force_bdss.mco.parameters.base_mco_parameter_factory import \ + BaseMCOParameterFactory from force_bdss.tests import fixtures try: @@ -33,6 +37,17 @@ class NullMCO(BaseMCO): pass +class NullParameter(BaseMCOParameter): + initial_value = Float() + lower_bound = Float() + upper_bound = Float() + + +class NullParameterFactory(BaseMCOParameterFactory): + id = mco_parameter_id("enthought", "dummy_dakota", "ranged") + model_class = NullParameter + + class NullMCOCommunicator(BaseMCOCommunicator): def send_to_mco(self, model, kpi_results): pass @@ -42,6 +57,8 @@ class NullMCOCommunicator(BaseMCOCommunicator): class NullMCOBundle(BaseMCOBundle): + id = bundle_id("enthought", "dummy_dakota") + def create_model(self, model_data=None): return NullMCOModel(self) @@ -51,6 +68,9 @@ class NullMCOBundle(BaseMCOBundle): def create_optimizer(self): return NullMCO(self) + def parameter_factories(self): + return [NullParameterFactory(self)] + class NullKPICalculatorModel(BaseKPICalculatorModel): pass @@ -88,6 +108,8 @@ class NullDataSourceBundle(BaseDataSourceBundle): def mock_bundle_registry_plugin(): bundle_registry_plugin = mock.Mock(spec=BundleRegistryPlugin) + bundle_registry_plugin.mco_bundles = [ + NullMCOBundle(bundle_registry_plugin)] bundle_registry_plugin.mco_bundle_by_id = mock.Mock( return_value=NullMCOBundle(bundle_registry_plugin)) bundle_registry_plugin.kpi_calculator_bundle_by_id = mock.Mock( diff --git a/force_bdss/tests/test_ids.py b/force_bdss/tests/test_ids.py index 1ebff9a8adfa40748165c75b4f467e6c4f2adb98..e8dbbb74a2fa4b2fc927786ab908590f24d54495 100644 --- a/force_bdss/tests/test_ids.py +++ b/force_bdss/tests/test_ids.py @@ -6,7 +6,7 @@ from force_bdss.ids import bundle_id, plugin_id class TestIdGenerators(unittest.TestCase): def test_bundle_id(self): self.assertEqual(bundle_id("foo", "bar"), - "force.bdss.bundle.foo.bar") + "force.bdss.foo.bundle.bar") for bad_entry in ["", None, " ", "foo bar"]: with self.assertRaises(ValueError): @@ -15,7 +15,7 @@ class TestIdGenerators(unittest.TestCase): bundle_id("foo", bad_entry) def test_plugin_id(self): - self.assertEqual(plugin_id("foo", "bar"), "force.bdss.plugin.foo.bar") + self.assertEqual(plugin_id("foo", "bar"), "force.bdss.foo.plugin.bar") for bad_entry in ["", None, " ", "foo bar"]: with self.assertRaises(ValueError):