From 63678bdb4ab0d28aa619176899f80cc1c7fddd09 Mon Sep 17 00:00:00 2001 From: Stefano Borini <sborini@enthought.com> Date: Mon, 24 Jul 2017 12:55:13 +0100 Subject: [PATCH] Introduced factory based creation of parameters --- force_bdss/base_core_driver.py | 11 +++++++++ force_bdss/io/workflow_reader.py | 24 +++++++++++++++---- force_bdss/io/workflow_writer.py | 8 +++++++ force_bdss/mco/base_mco_model.py | 4 ++-- force_bdss/mco/parameters/__init__.py | 0 .../mco/parameters/base_mco_parameter.py | 20 ++++++++++++++++ .../mco/parameters/core_mco_parameters.py | 24 +++++++++++++++++++ .../parameters/parameter_factory_registry.py | 11 +++++++++ force_bdss/mco/parameters/tests/__init__.py | 0 .../tests/test_core_mco_parameters.py | 14 +++++++++++ force_bdss/workspecs/mco_parameters.py | 13 ---------- 11 files changed, 110 insertions(+), 19 deletions(-) create mode 100644 force_bdss/mco/parameters/__init__.py create mode 100644 force_bdss/mco/parameters/base_mco_parameter.py create mode 100644 force_bdss/mco/parameters/core_mco_parameters.py create mode 100644 force_bdss/mco/parameters/parameter_factory_registry.py create mode 100644 force_bdss/mco/parameters/tests/__init__.py create mode 100644 force_bdss/mco/parameters/tests/test_core_mco_parameters.py delete mode 100644 force_bdss/workspecs/mco_parameters.py diff --git a/force_bdss/base_core_driver.py b/force_bdss/base_core_driver.py index ca884cc..af082f8 100644 --- a/force_bdss/base_core_driver.py +++ b/force_bdss/base_core_driver.py @@ -7,6 +7,8 @@ from .bundle_registry_plugin import ( ) from .io.workflow_reader import WorkflowReader from .workspecs.workflow import Workflow +from .mco.parameters.parameter_factory_registry import ParameterFactoryRegistry +from .mco.parameters.core_mco_parameters import all_core_factories class BaseCoreDriver(Plugin): @@ -16,12 +18,21 @@ class BaseCoreDriver(Plugin): bundle_registry = Instance(BundleRegistryPlugin) + parameter_factory_registry = Instance(ParameterFactoryRegistry) + #: Deserialized content of the workflow file. workflow = Instance(Workflow) def _bundle_registry_default(self): return self.application.get_plugin(BUNDLE_REGISTRY_PLUGIN_ID) + def _parameter_factory_registry_default(self): + registry = ParameterFactoryRegistry() + for f in all_core_factories(): + self.register(f) + + return registry + def _workflow_default(self): reader = WorkflowReader(self.bundle_registry) with open(self.application.workflow_filepath) as f: diff --git a/force_bdss/io/workflow_reader.py b/force_bdss/io/workflow_reader.py index 89c8066..c210c27 100644 --- a/force_bdss/io/workflow_reader.py +++ b/force_bdss/io/workflow_reader.py @@ -3,9 +3,10 @@ import logging from traits.api import HasStrictTraits, Instance -from ..workspecs.mco_parameters import RangedMCOParameter -from ..workspecs.workflow import Workflow +from ..mco.parameters.parameter_factory_registry import ( + ParameterFactoryRegistry) from ..bundle_registry_plugin import BundleRegistryPlugin +from ..workspecs.workflow import Workflow SUPPORTED_FILE_VERSIONS = ["1"] @@ -26,8 +27,13 @@ class WorkflowReader(HasStrictTraits): #: The bundle registry. The reader needs it to create the #: bundle-specific model objects. bundle_registry = Instance(BundleRegistryPlugin) + mco_parameter_registry = Instance(ParameterFactoryRegistry) - def __init__(self, bundle_registry, *args, **kwargs): + def __init__(self, + bundle_registry, + mco_parameter_registry, + *args, + **kwargs): """Initializes the reader. Parameters @@ -37,6 +43,7 @@ class WorkflowReader(HasStrictTraits): for a bundle identified by a given id. """ self.bundle_registry = bundle_registry + self.mco_parameter_registry = mco_parameter_registry super(WorkflowReader, self).__init__(*args, **kwargs) @@ -172,5 +179,14 @@ class WorkflowReader(HasStrictTraits): return kpi_calculators def _extract_mco_parameters(self, parameters_data): - return [RangedMCOParameter(**d) for d in parameters_data] + registry = self.mco_parameter_registry + + parameters = [] + + for p in parameters_data: + id = p["id"] + factory = registry.get_factory_by_id(id) + model = factory.create_model(p["model_data"]) + parameters.append(model) + return parameters diff --git a/force_bdss/io/workflow_writer.py b/force_bdss/io/workflow_writer.py index a83d54a..29855c8 100644 --- a/force_bdss/io/workflow_writer.py +++ b/force_bdss/io/workflow_writer.py @@ -27,6 +27,14 @@ class WorkflowWriter(HasStrictTraits): "id": workflow.multi_criteria_optimizer.bundle.id, "model_data": workflow.multi_criteria_optimizer.__getstate__() } + wf_data["multi_criteria_optimizer"]["parameters"] = [] + for param in workflow.multi_criteria_optimizer.parameters: + wf_data["multi_criteria_optimizer"]["parameters"].append( + { + "id": param.factory.id, + "model_data": param.__getstate__() + } + ) kpic_data = [] for kpic in workflow.kpi_calculators: kpic_data.append({ diff --git a/force_bdss/mco/base_mco_model.py b/force_bdss/mco/base_mco_model.py index acdebfb..8f3a56f 100644 --- a/force_bdss/mco/base_mco_model.py +++ b/force_bdss/mco/base_mco_model.py @@ -1,6 +1,6 @@ from traits.api import ABCHasStrictTraits, Instance, List -from ..workspecs.mco_parameters import MCOParameter +from .parameters.base_mco_parameter import BaseMCOParameter from .i_multi_criteria_optimizer_bundle import IMultiCriteriaOptimizerBundle @@ -18,7 +18,7 @@ class BaseMCOModel(ABCHasStrictTraits): visible=False, transient=True) - parameters = List(MCOParameter) + parameters = List(BaseMCOParameter) def __init__(self, bundle, *args, **kwargs): self.bundle = bundle diff --git a/force_bdss/mco/parameters/__init__.py b/force_bdss/mco/parameters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/force_bdss/mco/parameters/base_mco_parameter.py b/force_bdss/mco/parameters/base_mco_parameter.py new file mode 100644 index 0000000..de14213 --- /dev/null +++ b/force_bdss/mco/parameters/base_mco_parameter.py @@ -0,0 +1,20 @@ +from traits.api import HasStrictTraits, String, Type, Instance + + +class BaseMCOParameterFactory(HasStrictTraits): + id = String() + name = String("Undefined parameter") + description = String("Undefined parameter") + model_class = Type('BaseMCOParameter') + + def create_model(self, data_values=None): + if data_values is None: + data_values = {} + + return self.model_class(factory=self, **data_values) + + +class BaseMCOParameter(HasStrictTraits): + factory = Instance(BaseMCOParameterFactory) + value_name = String() + value_type = String() diff --git a/force_bdss/mco/parameters/core_mco_parameters.py b/force_bdss/mco/parameters/core_mco_parameters.py new file mode 100644 index 0000000..a1b6d50 --- /dev/null +++ b/force_bdss/mco/parameters/core_mco_parameters.py @@ -0,0 +1,24 @@ +from traits.api import Float + +from ...ids import mco_parameter_id +from .base_mco_parameter import BaseMCOParameter, BaseMCOParameterFactory + + +class RangedMCOParameter(BaseMCOParameter): + initial_value = Float() + upper_bound = Float() + lower_bound = Float() + + +class RangedMCOParameterFactory(BaseMCOParameterFactory): + id = mco_parameter_id("enthought", "ranged") + model_class = RangedMCOParameter + name = "Range" + description = "A ranged parameter in floating point values." + + +def all_core_factories(): + import inspect + + return [c for c in inspect.getmodule(all_core_factories).__dict__.values() + if inspect.isclass(c) and issubclass(c, BaseMCOParameterFactory)] diff --git a/force_bdss/mco/parameters/parameter_factory_registry.py b/force_bdss/mco/parameters/parameter_factory_registry.py new file mode 100644 index 0000000..915cd0c --- /dev/null +++ b/force_bdss/mco/parameters/parameter_factory_registry.py @@ -0,0 +1,11 @@ +from traits.api import HasStrictTraits, Dict + + +class ParameterFactoryRegistry(HasStrictTraits): + factories = Dict() + + def get_factory_by_id(self, id): + return self.factories[id] + + def register(self, factory): + self.factories[factory.id] = factory diff --git a/force_bdss/mco/parameters/tests/__init__.py b/force_bdss/mco/parameters/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/force_bdss/mco/parameters/tests/test_core_mco_parameters.py b/force_bdss/mco/parameters/tests/test_core_mco_parameters.py new file mode 100644 index 0000000..80fd17b --- /dev/null +++ b/force_bdss/mco/parameters/tests/test_core_mco_parameters.py @@ -0,0 +1,14 @@ +import unittest + +from force_bdss.mco.parameters import core_mco_parameters +from force_bdss.mco.parameters.base_mco_parameter import \ + BaseMCOParameterFactory + + +class TestCoreMCOParameters(unittest.TestCase): + def test_all_classes(self): + factories = core_mco_parameters.all_core_factories() + self.assertNotEqual(len(factories), 0) + + for f in factories: + self.assertTrue(issubclass(f, BaseMCOParameterFactory)) diff --git a/force_bdss/workspecs/mco_parameters.py b/force_bdss/workspecs/mco_parameters.py deleted file mode 100644 index 5020117..0000000 --- a/force_bdss/workspecs/mco_parameters.py +++ /dev/null @@ -1,13 +0,0 @@ -from traits.api import HasStrictTraits, String, Float - - -class MCOParameter(HasStrictTraits): - pass - - -class RangedMCOParameter(MCOParameter): - name = String() - value_type = String() - initial_value = Float() - upper_bound = Float() - lower_bound = Float() -- GitLab