Skip to content
Snippets Groups Projects
Commit 63678bdb authored by Stefano Borini's avatar Stefano Borini
Browse files

Introduced factory based creation of parameters

parent 64254e31
No related branches found
No related tags found
1 merge request!45Added support for MCO named parameters.
......@@ -7,6 +7,8 @@ from .bundle_registry_plugin import (
)
from .io.workflow_reader import WorkflowReader
from .workspecs.workflow import Workflow
from .mco.parameters.parameter_factory_registry import ParameterFactoryRegistry
from .mco.parameters.core_mco_parameters import all_core_factories
class BaseCoreDriver(Plugin):
......@@ -16,12 +18,21 @@ class BaseCoreDriver(Plugin):
bundle_registry = Instance(BundleRegistryPlugin)
parameter_factory_registry = Instance(ParameterFactoryRegistry)
#: Deserialized content of the workflow file.
workflow = Instance(Workflow)
def _bundle_registry_default(self):
return self.application.get_plugin(BUNDLE_REGISTRY_PLUGIN_ID)
def _parameter_factory_registry_default(self):
registry = ParameterFactoryRegistry()
for f in all_core_factories():
self.register(f)
return registry
def _workflow_default(self):
reader = WorkflowReader(self.bundle_registry)
with open(self.application.workflow_filepath) as f:
......
......@@ -3,9 +3,10 @@ import logging
from traits.api import HasStrictTraits, Instance
from ..workspecs.mco_parameters import RangedMCOParameter
from ..workspecs.workflow import Workflow
from ..mco.parameters.parameter_factory_registry import (
ParameterFactoryRegistry)
from ..bundle_registry_plugin import BundleRegistryPlugin
from ..workspecs.workflow import Workflow
SUPPORTED_FILE_VERSIONS = ["1"]
......@@ -26,8 +27,13 @@ class WorkflowReader(HasStrictTraits):
#: The bundle registry. The reader needs it to create the
#: bundle-specific model objects.
bundle_registry = Instance(BundleRegistryPlugin)
mco_parameter_registry = Instance(ParameterFactoryRegistry)
def __init__(self, bundle_registry, *args, **kwargs):
def __init__(self,
bundle_registry,
mco_parameter_registry,
*args,
**kwargs):
"""Initializes the reader.
Parameters
......@@ -37,6 +43,7 @@ class WorkflowReader(HasStrictTraits):
for a bundle identified by a given id.
"""
self.bundle_registry = bundle_registry
self.mco_parameter_registry = mco_parameter_registry
super(WorkflowReader, self).__init__(*args, **kwargs)
......@@ -172,5 +179,14 @@ class WorkflowReader(HasStrictTraits):
return kpi_calculators
def _extract_mco_parameters(self, parameters_data):
return [RangedMCOParameter(**d) for d in parameters_data]
registry = self.mco_parameter_registry
parameters = []
for p in parameters_data:
id = p["id"]
factory = registry.get_factory_by_id(id)
model = factory.create_model(p["model_data"])
parameters.append(model)
return parameters
......@@ -27,6 +27,14 @@ class WorkflowWriter(HasStrictTraits):
"id": workflow.multi_criteria_optimizer.bundle.id,
"model_data": workflow.multi_criteria_optimizer.__getstate__()
}
wf_data["multi_criteria_optimizer"]["parameters"] = []
for param in workflow.multi_criteria_optimizer.parameters:
wf_data["multi_criteria_optimizer"]["parameters"].append(
{
"id": param.factory.id,
"model_data": param.__getstate__()
}
)
kpic_data = []
for kpic in workflow.kpi_calculators:
kpic_data.append({
......
from traits.api import ABCHasStrictTraits, Instance, List
from ..workspecs.mco_parameters import MCOParameter
from .parameters.base_mco_parameter import BaseMCOParameter
from .i_multi_criteria_optimizer_bundle import IMultiCriteriaOptimizerBundle
......@@ -18,7 +18,7 @@ class BaseMCOModel(ABCHasStrictTraits):
visible=False,
transient=True)
parameters = List(MCOParameter)
parameters = List(BaseMCOParameter)
def __init__(self, bundle, *args, **kwargs):
self.bundle = bundle
......
from traits.api import HasStrictTraits, String, Type, Instance
class BaseMCOParameterFactory(HasStrictTraits):
id = String()
name = String("Undefined parameter")
description = String("Undefined parameter")
model_class = Type('BaseMCOParameter')
def create_model(self, data_values=None):
if data_values is None:
data_values = {}
return self.model_class(factory=self, **data_values)
class BaseMCOParameter(HasStrictTraits):
factory = Instance(BaseMCOParameterFactory)
value_name = String()
value_type = String()
from traits.api import Float
from ...ids import mco_parameter_id
from .base_mco_parameter import BaseMCOParameter, BaseMCOParameterFactory
class RangedMCOParameter(BaseMCOParameter):
initial_value = Float()
upper_bound = Float()
lower_bound = Float()
class RangedMCOParameterFactory(BaseMCOParameterFactory):
id = mco_parameter_id("enthought", "ranged")
model_class = RangedMCOParameter
name = "Range"
description = "A ranged parameter in floating point values."
def all_core_factories():
import inspect
return [c for c in inspect.getmodule(all_core_factories).__dict__.values()
if inspect.isclass(c) and issubclass(c, BaseMCOParameterFactory)]
from traits.api import HasStrictTraits, Dict
class ParameterFactoryRegistry(HasStrictTraits):
factories = Dict()
def get_factory_by_id(self, id):
return self.factories[id]
def register(self, factory):
self.factories[factory.id] = factory
import unittest
from force_bdss.mco.parameters import core_mco_parameters
from force_bdss.mco.parameters.base_mco_parameter import \
BaseMCOParameterFactory
class TestCoreMCOParameters(unittest.TestCase):
def test_all_classes(self):
factories = core_mco_parameters.all_core_factories()
self.assertNotEqual(len(factories), 0)
for f in factories:
self.assertTrue(issubclass(f, BaseMCOParameterFactory))
from traits.api import HasStrictTraits, String, Float
class MCOParameter(HasStrictTraits):
pass
class RangedMCOParameter(MCOParameter):
name = String()
value_type = String()
initial_value = Float()
upper_bound = Float()
lower_bound = Float()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment