Skip to content
Snippets Groups Projects
Commit 1ee0cd0e authored by Stefano Borini's avatar Stefano Borini
Browse files

Moved parameters into mco

parent ea2e6583
No related branches found
No related tags found
1 merge request!60Moved parameters into MCO
This commit is part of merge request !60. Comments created here will be created in the context of that merge request.
Showing with 53 additions and 29 deletions
...@@ -6,10 +6,9 @@ from .bundle_registry_plugin import ( ...@@ -6,10 +6,9 @@ from .bundle_registry_plugin import (
BUNDLE_REGISTRY_PLUGIN_ID BUNDLE_REGISTRY_PLUGIN_ID
) )
from .io.workflow_reader import WorkflowReader from .io.workflow_reader import WorkflowReader
from .workspecs.workflow import Workflow
from .mco.parameters.mco_parameter_factory_registry import ( from .mco.parameters.mco_parameter_factory_registry import (
MCOParameterFactoryRegistry) MCOParameterFactoryRegistry)
from .mco.parameters.core_mco_parameters import all_core_factories from .workspecs.workflow import Workflow
class BaseCoreDriver(Plugin): class BaseCoreDriver(Plugin):
...@@ -31,8 +30,10 @@ class BaseCoreDriver(Plugin): ...@@ -31,8 +30,10 @@ class BaseCoreDriver(Plugin):
def _parameter_factory_registry_default(self): def _parameter_factory_registry_default(self):
registry = MCOParameterFactoryRegistry() registry = MCOParameterFactoryRegistry()
for f in all_core_factories():
registry.register(f) for mco_bundle in self.bundle_registry.mco_bundles:
for factory in mco_bundle.parameter_factories():
registry.register(factory)
return registry return registry
......
...@@ -107,3 +107,13 @@ class BundleRegistryPlugin(Plugin): ...@@ -107,3 +107,13 @@ class BundleRegistryPlugin(Plugin):
raise ValueError("Requested MCO {} but don't know how " raise ValueError("Requested MCO {} but don't know how "
"to find it.".format(id)) "to find it.".format(id))
def mco_parameters_by_id(self, mco_id, parameter_id):
mco_bundle = self.mco_bundle_by_id(mco_id)
for factory in mco_bundle.parameter_factories():
if factory.id == parameter_id:
return factory
raise ValueError("Requested MCO parameter {}:{} but don't know"
" how to find it.".format(mco_id, parameter_id))
from traits.api import String from traits.api import String
from force_bdss.api import bundle_id, BaseMCOBundle from force_bdss.api import bundle_id, BaseMCOBundle
from force_bdss.core_plugins.dummy.dummy_dakota.parameters import \
RangedMCOParameterFactory
from .dakota_communicator import DummyDakotaCommunicator from .dakota_communicator import DummyDakotaCommunicator
from .dakota_model import DummyDakotaModel from .dakota_model import DummyDakotaModel
...@@ -21,3 +23,8 @@ class DummyDakotaBundle(BaseMCOBundle): ...@@ -21,3 +23,8 @@ class DummyDakotaBundle(BaseMCOBundle):
def create_communicator(self): def create_communicator(self):
return DummyDakotaCommunicator(self) return DummyDakotaCommunicator(self)
def parameter_factories(self):
return [
RangedMCOParameterFactory(self)
]
from traits.api import Float from traits.api import Float
from ...ids import mco_parameter_id from force_bdss.ids import mco_parameter_id
from .base_mco_parameter import BaseMCOParameter from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter
from force_bdss.mco.parameters.base_mco_parameter_factory import \ from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory BaseMCOParameterFactory
...@@ -20,14 +20,3 @@ class RangedMCOParameterFactory(BaseMCOParameterFactory): ...@@ -20,14 +20,3 @@ class RangedMCOParameterFactory(BaseMCOParameterFactory):
model_class = RangedMCOParameter model_class = RangedMCOParameter
name = "Range" name = "Range"
description = "A ranged parameter in floating point values." description = "A ranged parameter in floating point values."
def all_core_factories():
"""Produces a list of all factories contained in this module."""
import inspect
return [c()
for c in inspect.getmodule(all_core_factories).__dict__.values()
if inspect.isclass(c) and
c is not BaseMCOParameterFactory and
issubclass(c, BaseMCOParameterFactory)]
...@@ -13,7 +13,7 @@ from force_bdss.data_sources.data_source_parameters import DataSourceParameters ...@@ -13,7 +13,7 @@ from force_bdss.data_sources.data_source_parameters import DataSourceParameters
from force_bdss.mco.parameters.base_mco_parameter_factory import \ from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory BaseMCOParameterFactory
from force_bdss.mco.parameters.core_mco_parameters import RangedMCOParameter from force_bdss.core_plugins.dummy.dummy_dakota.parameters import RangedMCOParameter
class TestDakotaCommunicator(unittest.TestCase): class TestDakotaCommunicator(unittest.TestCase):
......
import unittest import unittest
from force_bdss.core_plugins.dummy.dummy_dakota.parameters import RangedMCOParameter, \
RangedMCOParameterFactory
from force_bdss.core_plugins.dummy.dummy_dakota.dakota_model import \ from force_bdss.core_plugins.dummy.dummy_dakota.dakota_model import \
DummyDakotaModel DummyDakotaModel
from force_bdss.mco.base_mco_bundle import BaseMCOBundle from force_bdss.mco.base_mco_bundle import BaseMCOBundle
from force_bdss.mco.parameters.core_mco_parameters import RangedMCOParameter, \
RangedMCOParameterFactory
try: try:
import mock import mock
......
...@@ -128,6 +128,7 @@ class WorkflowReader(HasStrictTraits): ...@@ -128,6 +128,7 @@ class WorkflowReader(HasStrictTraits):
mco_bundle = registry.mco_bundle_by_id(mco_id) mco_bundle = registry.mco_bundle_by_id(mco_id)
model_data = wf_data["mco"]["model_data"] model_data = wf_data["mco"]["model_data"]
model_data["parameters"] = self._extract_mco_parameters( model_data["parameters"] = self._extract_mco_parameters(
mco_id,
model_data["parameters"]) model_data["parameters"])
model = mco_bundle.create_model( model = mco_bundle.create_model(
wf_data["mco"]["model_data"]) wf_data["mco"]["model_data"])
...@@ -182,7 +183,7 @@ class WorkflowReader(HasStrictTraits): ...@@ -182,7 +183,7 @@ class WorkflowReader(HasStrictTraits):
return kpi_calculators return kpi_calculators
def _extract_mco_parameters(self, parameters_data): def _extract_mco_parameters(self, mco_id, parameters_data):
"""Extracts the MCO parameters from the data as dictionary. """Extracts the MCO parameters from the data as dictionary.
Parameters Parameters
...@@ -200,7 +201,7 @@ class WorkflowReader(HasStrictTraits): ...@@ -200,7 +201,7 @@ class WorkflowReader(HasStrictTraits):
for p in parameters_data: for p in parameters_data:
id = p["id"] id = p["id"]
factory = registry.get_factory_by_id(id) factory = registry.get_factory(mco_id, id)
model = factory.create_model(p["model_data"]) model = factory.create_model(p["model_data"])
parameters.append(model) parameters.append(model)
......
...@@ -68,3 +68,12 @@ class BaseMCOBundle(ABCHasStrictTraits): ...@@ -68,3 +68,12 @@ class BaseMCOBundle(ABCHasStrictTraits):
BaseMCOCommunicator BaseMCOCommunicator
An instance of the communicator An instance of the communicator
""" """
@abc.abstractmethod
def parameter_factories(self):
"""Returns the parameter factories supported by this MCO
Returns
-------
List of BaseMCOParameterFactory
"""
from traits.has_traits import HasStrictTraits from traits.api import HasStrictTraits, String, Type, Instance
from traits.trait_types import String, Type
from ..base_mco_bundle import BaseMCOBundle
class BaseMCOParameterFactory(HasStrictTraits): class BaseMCOParameterFactory(HasStrictTraits):
...@@ -8,6 +9,8 @@ class BaseMCOParameterFactory(HasStrictTraits): ...@@ -8,6 +9,8 @@ class BaseMCOParameterFactory(HasStrictTraits):
Must be reimplemented for the specific parameter.""" Must be reimplemented for the specific parameter."""
bundle = Instance(BaseMCOBundle)
#: A unique string identifying the parameter #: A unique string identifying the parameter
id = String() id = String()
...@@ -20,6 +23,10 @@ class BaseMCOParameterFactory(HasStrictTraits): ...@@ -20,6 +23,10 @@ class BaseMCOParameterFactory(HasStrictTraits):
# The model class to instantiate when create_model is called. # The model class to instantiate when create_model is called.
model_class = Type('BaseMCOParameter') model_class = Type('BaseMCOParameter')
def __init__(self, bundle):
self.bundle = bundle
super(BaseMCOParameterFactory, self).__init__()
def create_model(self, data_values=None): def create_model(self, data_values=None):
"""Creates the instance of the model class and returns it. """Creates the instance of the model class and returns it.
""" """
......
...@@ -10,12 +10,12 @@ class MCOParameterFactoryRegistry(HasStrictTraits): ...@@ -10,12 +10,12 @@ class MCOParameterFactoryRegistry(HasStrictTraits):
# Temp: this will become an extension point. # Temp: this will become an extension point.
factories = Dict(String, BaseMCOParameterFactory) factories = Dict(String, BaseMCOParameterFactory)
def get_factory_by_id(self, id): def get_factory_by_id(self, bundle_id, parameter_factory_id):
"""Finds the factory by its id, so that we can obtain it as from """Finds the factory by its id, so that we can obtain it as from
the id in the model file. the id in the model file.
""" """
return self.factories[id] return self.factories[(bundle_id, parameter_factory_id)]
def register(self, factory): def register(self, factory):
"""Registers a new factory""" """Registers a new factory"""
self.factories[factory.id] = factory self.factories[(factory.bundle.id, factory.id)] = factory
import unittest import unittest
from force_bdss.mco.parameters import core_mco_parameters from force_bdss.core_plugins.dummy.dummy_dakota import parameters
from force_bdss.mco.parameters.base_mco_parameter_factory import \ from force_bdss.mco.parameters.base_mco_parameter_factory import \
BaseMCOParameterFactory BaseMCOParameterFactory
class TestCoreMCOParameters(unittest.TestCase): class TestCoreMCOParameters(unittest.TestCase):
def test_all_factories(self): def test_all_factories(self):
factories = core_mco_parameters.all_core_factories() factories = parameters.all_core_factories()
self.assertEqual(len(factories), 1) self.assertEqual(len(factories), 1)
for f in factories: for f in factories:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment