diff --git a/force_bdss/base_core_driver.py b/force_bdss/base_core_driver.py index ca884ccac2d6462b581f1db81b5b86adbd777b07..af082f8b2881e9994eb4e1721328c08a107727d5 100644 --- a/force_bdss/base_core_driver.py +++ b/force_bdss/base_core_driver.py @@ -7,6 +7,8 @@ from .bundle_registry_plugin import ( ) from .io.workflow_reader import WorkflowReader from .workspecs.workflow import Workflow +from .mco.parameters.parameter_factory_registry import ParameterFactoryRegistry +from .mco.parameters.core_mco_parameters import all_core_factories class BaseCoreDriver(Plugin): @@ -16,12 +18,21 @@ class BaseCoreDriver(Plugin): bundle_registry = Instance(BundleRegistryPlugin) + parameter_factory_registry = Instance(ParameterFactoryRegistry) + #: Deserialized content of the workflow file. workflow = Instance(Workflow) def _bundle_registry_default(self): return self.application.get_plugin(BUNDLE_REGISTRY_PLUGIN_ID) + def _parameter_factory_registry_default(self): + registry = ParameterFactoryRegistry() + for f in all_core_factories(): + self.register(f) + + return registry + def _workflow_default(self): reader = WorkflowReader(self.bundle_registry) with open(self.application.workflow_filepath) as f: diff --git a/force_bdss/io/workflow_reader.py b/force_bdss/io/workflow_reader.py index 89c806629e9cf574f97234451695f57d52c3a7cd..c210c27a565c5a0dcb208b80560925a21b12cc46 100644 --- a/force_bdss/io/workflow_reader.py +++ b/force_bdss/io/workflow_reader.py @@ -3,9 +3,10 @@ import logging from traits.api import HasStrictTraits, Instance -from ..workspecs.mco_parameters import RangedMCOParameter -from ..workspecs.workflow import Workflow +from ..mco.parameters.parameter_factory_registry import ( + ParameterFactoryRegistry) from ..bundle_registry_plugin import BundleRegistryPlugin +from ..workspecs.workflow import Workflow SUPPORTED_FILE_VERSIONS = ["1"] @@ -26,8 +27,13 @@ class WorkflowReader(HasStrictTraits): #: The bundle registry. The reader needs it to create the #: bundle-specific model objects. bundle_registry = Instance(BundleRegistryPlugin) + mco_parameter_registry = Instance(ParameterFactoryRegistry) - def __init__(self, bundle_registry, *args, **kwargs): + def __init__(self, + bundle_registry, + mco_parameter_registry, + *args, + **kwargs): """Initializes the reader. Parameters @@ -37,6 +43,7 @@ class WorkflowReader(HasStrictTraits): for a bundle identified by a given id. """ self.bundle_registry = bundle_registry + self.mco_parameter_registry = mco_parameter_registry super(WorkflowReader, self).__init__(*args, **kwargs) @@ -172,5 +179,14 @@ class WorkflowReader(HasStrictTraits): return kpi_calculators def _extract_mco_parameters(self, parameters_data): - return [RangedMCOParameter(**d) for d in parameters_data] + registry = self.mco_parameter_registry + + parameters = [] + + for p in parameters_data: + id = p["id"] + factory = registry.get_factory_by_id(id) + model = factory.create_model(p["model_data"]) + parameters.append(model) + return parameters diff --git a/force_bdss/io/workflow_writer.py b/force_bdss/io/workflow_writer.py index a83d54aab3c649a521a140d49774a89995275b8f..29855c8d9c954bf6fd24382339f2e90f16c78653 100644 --- a/force_bdss/io/workflow_writer.py +++ b/force_bdss/io/workflow_writer.py @@ -27,6 +27,14 @@ class WorkflowWriter(HasStrictTraits): "id": workflow.multi_criteria_optimizer.bundle.id, "model_data": workflow.multi_criteria_optimizer.__getstate__() } + wf_data["multi_criteria_optimizer"]["parameters"] = [] + for param in workflow.multi_criteria_optimizer.parameters: + wf_data["multi_criteria_optimizer"]["parameters"].append( + { + "id": param.factory.id, + "model_data": param.__getstate__() + } + ) kpic_data = [] for kpic in workflow.kpi_calculators: kpic_data.append({ diff --git a/force_bdss/mco/base_mco_model.py b/force_bdss/mco/base_mco_model.py index acdebfbb0ce74b0315c2056889153cad21278bd5..8f3a56f0ff7b0d41ae9498a6dd64b25921f98626 100644 --- a/force_bdss/mco/base_mco_model.py +++ b/force_bdss/mco/base_mco_model.py @@ -1,6 +1,6 @@ from traits.api import ABCHasStrictTraits, Instance, List -from ..workspecs.mco_parameters import MCOParameter +from .parameters.base_mco_parameter import BaseMCOParameter from .i_multi_criteria_optimizer_bundle import IMultiCriteriaOptimizerBundle @@ -18,7 +18,7 @@ class BaseMCOModel(ABCHasStrictTraits): visible=False, transient=True) - parameters = List(MCOParameter) + parameters = List(BaseMCOParameter) def __init__(self, bundle, *args, **kwargs): self.bundle = bundle diff --git a/force_bdss/mco/parameters/__init__.py b/force_bdss/mco/parameters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/force_bdss/mco/parameters/base_mco_parameter.py b/force_bdss/mco/parameters/base_mco_parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..de14213819f88c16c3136149a352806a4b89ecbb --- /dev/null +++ b/force_bdss/mco/parameters/base_mco_parameter.py @@ -0,0 +1,20 @@ +from traits.api import HasStrictTraits, String, Type, Instance + + +class BaseMCOParameterFactory(HasStrictTraits): + id = String() + name = String("Undefined parameter") + description = String("Undefined parameter") + model_class = Type('BaseMCOParameter') + + def create_model(self, data_values=None): + if data_values is None: + data_values = {} + + return self.model_class(factory=self, **data_values) + + +class BaseMCOParameter(HasStrictTraits): + factory = Instance(BaseMCOParameterFactory) + value_name = String() + value_type = String() diff --git a/force_bdss/mco/parameters/core_mco_parameters.py b/force_bdss/mco/parameters/core_mco_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b6d508b93649f96da0d129e8f0b4d0f1707376 --- /dev/null +++ b/force_bdss/mco/parameters/core_mco_parameters.py @@ -0,0 +1,24 @@ +from traits.api import Float + +from ...ids import mco_parameter_id +from .base_mco_parameter import BaseMCOParameter, BaseMCOParameterFactory + + +class RangedMCOParameter(BaseMCOParameter): + initial_value = Float() + upper_bound = Float() + lower_bound = Float() + + +class RangedMCOParameterFactory(BaseMCOParameterFactory): + id = mco_parameter_id("enthought", "ranged") + model_class = RangedMCOParameter + name = "Range" + description = "A ranged parameter in floating point values." + + +def all_core_factories(): + import inspect + + return [c for c in inspect.getmodule(all_core_factories).__dict__.values() + if inspect.isclass(c) and issubclass(c, BaseMCOParameterFactory)] diff --git a/force_bdss/mco/parameters/parameter_factory_registry.py b/force_bdss/mco/parameters/parameter_factory_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..915cd0cc8144e764bf9209c0cf1f26f492589028 --- /dev/null +++ b/force_bdss/mco/parameters/parameter_factory_registry.py @@ -0,0 +1,11 @@ +from traits.api import HasStrictTraits, Dict + + +class ParameterFactoryRegistry(HasStrictTraits): + factories = Dict() + + def get_factory_by_id(self, id): + return self.factories[id] + + def register(self, factory): + self.factories[factory.id] = factory diff --git a/force_bdss/mco/parameters/tests/__init__.py b/force_bdss/mco/parameters/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/force_bdss/mco/parameters/tests/test_core_mco_parameters.py b/force_bdss/mco/parameters/tests/test_core_mco_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..80fd17be427e77d3b6a7474c07a70aa03c4edd71 --- /dev/null +++ b/force_bdss/mco/parameters/tests/test_core_mco_parameters.py @@ -0,0 +1,14 @@ +import unittest + +from force_bdss.mco.parameters import core_mco_parameters +from force_bdss.mco.parameters.base_mco_parameter import \ + BaseMCOParameterFactory + + +class TestCoreMCOParameters(unittest.TestCase): + def test_all_classes(self): + factories = core_mco_parameters.all_core_factories() + self.assertNotEqual(len(factories), 0) + + for f in factories: + self.assertTrue(issubclass(f, BaseMCOParameterFactory)) diff --git a/force_bdss/workspecs/mco_parameters.py b/force_bdss/workspecs/mco_parameters.py deleted file mode 100644 index 5020117e19a8bc7d8d45d70fe2b60aa9ef77cdee..0000000000000000000000000000000000000000 --- a/force_bdss/workspecs/mco_parameters.py +++ /dev/null @@ -1,13 +0,0 @@ -from traits.api import HasStrictTraits, String, Float - - -class MCOParameter(HasStrictTraits): - pass - - -class RangedMCOParameter(MCOParameter): - name = String() - value_type = String() - initial_value = Float() - upper_bound = Float() - lower_bound = Float()