From 1ee0cd0e357d68b966b262394e4ec814443380bd Mon Sep 17 00:00:00 2001 From: Stefano Borini <sborini@enthought.com> Date: Fri, 28 Jul 2017 17:07:45 +0100 Subject: [PATCH] Moved parameters into mco --- force_bdss/base_core_driver.py | 9 +++++---- force_bdss/bundle_registry_plugin.py | 10 ++++++++++ .../dummy/dummy_dakota/dakota_bundle.py | 7 +++++++ .../dummy/dummy_dakota/parameters.py} | 15 ++------------- .../tests/test_dakota_communicator.py | 2 +- .../dummy_dakota/tests/test_dakota_optimizer.py | 4 ++-- force_bdss/io/workflow_reader.py | 5 +++-- force_bdss/mco/base_mco_bundle.py | 9 +++++++++ .../mco/parameters/base_mco_parameter_factory.py | 11 +++++++++-- .../parameters/mco_parameter_factory_registry.py | 6 +++--- .../parameters/tests/test_core_mco_parameters.py | 4 ++-- 11 files changed, 53 insertions(+), 29 deletions(-) rename force_bdss/{mco/parameters/core_mco_parameters.py => core_plugins/dummy/dummy_dakota/parameters.py} (58%) diff --git a/force_bdss/base_core_driver.py b/force_bdss/base_core_driver.py index 6b43b4c..e17d586 100644 --- a/force_bdss/base_core_driver.py +++ b/force_bdss/base_core_driver.py @@ -6,10 +6,9 @@ from .bundle_registry_plugin import ( BUNDLE_REGISTRY_PLUGIN_ID ) from .io.workflow_reader import WorkflowReader -from .workspecs.workflow import Workflow from .mco.parameters.mco_parameter_factory_registry import ( MCOParameterFactoryRegistry) -from .mco.parameters.core_mco_parameters import all_core_factories +from .workspecs.workflow import Workflow class BaseCoreDriver(Plugin): @@ -31,8 +30,10 @@ class BaseCoreDriver(Plugin): def _parameter_factory_registry_default(self): registry = MCOParameterFactoryRegistry() - for f in all_core_factories(): - registry.register(f) + + for mco_bundle in self.bundle_registry.mco_bundles: + for factory in mco_bundle.parameter_factories(): + registry.register(factory) return registry diff --git a/force_bdss/bundle_registry_plugin.py b/force_bdss/bundle_registry_plugin.py index 24bc5dd..aa22110 100644 --- a/force_bdss/bundle_registry_plugin.py +++ b/force_bdss/bundle_registry_plugin.py @@ -107,3 +107,13 @@ class BundleRegistryPlugin(Plugin): raise ValueError("Requested MCO {} but don't know how " "to find it.".format(id)) + + def mco_parameters_by_id(self, mco_id, parameter_id): + mco_bundle = self.mco_bundle_by_id(mco_id) + + for factory in mco_bundle.parameter_factories(): + if factory.id == parameter_id: + return factory + + raise ValueError("Requested MCO parameter {}:{} but don't know" + " how to find it.".format(mco_id, parameter_id)) diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py b/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py index 925a127..189c27c 100644 --- a/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py +++ b/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py @@ -1,5 +1,7 @@ from traits.api import String from force_bdss.api import bundle_id, BaseMCOBundle +from force_bdss.core_plugins.dummy.dummy_dakota.parameters import \ + RangedMCOParameterFactory from .dakota_communicator import DummyDakotaCommunicator from .dakota_model import DummyDakotaModel @@ -21,3 +23,8 @@ class DummyDakotaBundle(BaseMCOBundle): def create_communicator(self): return DummyDakotaCommunicator(self) + + def parameter_factories(self): + return [ + RangedMCOParameterFactory(self) + ] diff --git a/force_bdss/mco/parameters/core_mco_parameters.py b/force_bdss/core_plugins/dummy/dummy_dakota/parameters.py similarity index 58% rename from force_bdss/mco/parameters/core_mco_parameters.py rename to force_bdss/core_plugins/dummy/dummy_dakota/parameters.py index b9efa4e..6585312 100644 --- a/force_bdss/mco/parameters/core_mco_parameters.py +++ b/force_bdss/core_plugins/dummy/dummy_dakota/parameters.py @@ -1,7 +1,7 @@ from traits.api import Float -from ...ids import mco_parameter_id -from .base_mco_parameter import BaseMCOParameter +from force_bdss.ids import mco_parameter_id +from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter from force_bdss.mco.parameters.base_mco_parameter_factory import \ BaseMCOParameterFactory @@ -20,14 +20,3 @@ class RangedMCOParameterFactory(BaseMCOParameterFactory): model_class = RangedMCOParameter name = "Range" description = "A ranged parameter in floating point values." - - -def all_core_factories(): - """Produces a list of all factories contained in this module.""" - import inspect - - return [c() - for c in inspect.getmodule(all_core_factories).__dict__.values() - if inspect.isclass(c) and - c is not BaseMCOParameterFactory and - issubclass(c, BaseMCOParameterFactory)] diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py index e7bc0eb..97cf108 100644 --- a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py +++ b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py @@ -13,7 +13,7 @@ from force_bdss.data_sources.data_source_parameters import DataSourceParameters from force_bdss.mco.parameters.base_mco_parameter_factory import \ BaseMCOParameterFactory -from force_bdss.mco.parameters.core_mco_parameters import RangedMCOParameter +from force_bdss.core_plugins.dummy.dummy_dakota.parameters import RangedMCOParameter class TestDakotaCommunicator(unittest.TestCase): diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py index b0b9f85..40ce8cc 100644 --- a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py +++ b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py @@ -1,10 +1,10 @@ import unittest +from force_bdss.core_plugins.dummy.dummy_dakota.parameters import RangedMCOParameter, \ + RangedMCOParameterFactory from force_bdss.core_plugins.dummy.dummy_dakota.dakota_model import \ DummyDakotaModel from force_bdss.mco.base_mco_bundle import BaseMCOBundle -from force_bdss.mco.parameters.core_mco_parameters import RangedMCOParameter, \ - RangedMCOParameterFactory try: import mock diff --git a/force_bdss/io/workflow_reader.py b/force_bdss/io/workflow_reader.py index e12e13f..e13f6ac 100644 --- a/force_bdss/io/workflow_reader.py +++ b/force_bdss/io/workflow_reader.py @@ -128,6 +128,7 @@ class WorkflowReader(HasStrictTraits): mco_bundle = registry.mco_bundle_by_id(mco_id) model_data = wf_data["mco"]["model_data"] model_data["parameters"] = self._extract_mco_parameters( + mco_id, model_data["parameters"]) model = mco_bundle.create_model( wf_data["mco"]["model_data"]) @@ -182,7 +183,7 @@ class WorkflowReader(HasStrictTraits): return kpi_calculators - def _extract_mco_parameters(self, parameters_data): + def _extract_mco_parameters(self, mco_id, parameters_data): """Extracts the MCO parameters from the data as dictionary. Parameters @@ -200,7 +201,7 @@ class WorkflowReader(HasStrictTraits): for p in parameters_data: id = p["id"] - factory = registry.get_factory_by_id(id) + factory = registry.get_factory(mco_id, id) model = factory.create_model(p["model_data"]) parameters.append(model) diff --git a/force_bdss/mco/base_mco_bundle.py b/force_bdss/mco/base_mco_bundle.py index 5fe9fdb..c8e0d6d 100644 --- a/force_bdss/mco/base_mco_bundle.py +++ b/force_bdss/mco/base_mco_bundle.py @@ -68,3 +68,12 @@ class BaseMCOBundle(ABCHasStrictTraits): BaseMCOCommunicator An instance of the communicator """ + + @abc.abstractmethod + def parameter_factories(self): + """Returns the parameter factories supported by this MCO + + Returns + ------- + List of BaseMCOParameterFactory + """ diff --git a/force_bdss/mco/parameters/base_mco_parameter_factory.py b/force_bdss/mco/parameters/base_mco_parameter_factory.py index 6124173..f74e325 100644 --- a/force_bdss/mco/parameters/base_mco_parameter_factory.py +++ b/force_bdss/mco/parameters/base_mco_parameter_factory.py @@ -1,5 +1,6 @@ -from traits.has_traits import HasStrictTraits -from traits.trait_types import String, Type +from traits.api import HasStrictTraits, String, Type, Instance + +from ..base_mco_bundle import BaseMCOBundle class BaseMCOParameterFactory(HasStrictTraits): @@ -8,6 +9,8 @@ class BaseMCOParameterFactory(HasStrictTraits): Must be reimplemented for the specific parameter.""" + bundle = Instance(BaseMCOBundle) + #: A unique string identifying the parameter id = String() @@ -20,6 +23,10 @@ class BaseMCOParameterFactory(HasStrictTraits): # The model class to instantiate when create_model is called. model_class = Type('BaseMCOParameter') + def __init__(self, bundle): + self.bundle = bundle + super(BaseMCOParameterFactory, self).__init__() + def create_model(self, data_values=None): """Creates the instance of the model class and returns it. """ diff --git a/force_bdss/mco/parameters/mco_parameter_factory_registry.py b/force_bdss/mco/parameters/mco_parameter_factory_registry.py index 709f7f9..1920ac0 100644 --- a/force_bdss/mco/parameters/mco_parameter_factory_registry.py +++ b/force_bdss/mco/parameters/mco_parameter_factory_registry.py @@ -10,12 +10,12 @@ class MCOParameterFactoryRegistry(HasStrictTraits): # Temp: this will become an extension point. factories = Dict(String, BaseMCOParameterFactory) - def get_factory_by_id(self, id): + def get_factory_by_id(self, bundle_id, parameter_factory_id): """Finds the factory by its id, so that we can obtain it as from the id in the model file. """ - return self.factories[id] + return self.factories[(bundle_id, parameter_factory_id)] def register(self, factory): """Registers a new factory""" - self.factories[factory.id] = factory + self.factories[(factory.bundle.id, factory.id)] = factory diff --git a/force_bdss/mco/parameters/tests/test_core_mco_parameters.py b/force_bdss/mco/parameters/tests/test_core_mco_parameters.py index af6f4c9..8bff8fe 100644 --- a/force_bdss/mco/parameters/tests/test_core_mco_parameters.py +++ b/force_bdss/mco/parameters/tests/test_core_mco_parameters.py @@ -1,13 +1,13 @@ import unittest -from force_bdss.mco.parameters import core_mco_parameters +from force_bdss.core_plugins.dummy.dummy_dakota import parameters from force_bdss.mco.parameters.base_mco_parameter_factory import \ BaseMCOParameterFactory class TestCoreMCOParameters(unittest.TestCase): def test_all_factories(self): - factories = core_mco_parameters.all_core_factories() + factories = parameters.all_core_factories() self.assertEqual(len(factories), 1) for f in factories: -- GitLab