From 1ee0cd0e357d68b966b262394e4ec814443380bd Mon Sep 17 00:00:00 2001
From: Stefano Borini <sborini@enthought.com>
Date: Fri, 28 Jul 2017 17:07:45 +0100
Subject: [PATCH] Moved parameters into mco

---
 force_bdss/base_core_driver.py                    |  9 +++++----
 force_bdss/bundle_registry_plugin.py              | 10 ++++++++++
 .../dummy/dummy_dakota/dakota_bundle.py           |  7 +++++++
 .../dummy/dummy_dakota/parameters.py}             | 15 ++-------------
 .../tests/test_dakota_communicator.py             |  2 +-
 .../dummy_dakota/tests/test_dakota_optimizer.py   |  4 ++--
 force_bdss/io/workflow_reader.py                  |  5 +++--
 force_bdss/mco/base_mco_bundle.py                 |  9 +++++++++
 .../mco/parameters/base_mco_parameter_factory.py  | 11 +++++++++--
 .../parameters/mco_parameter_factory_registry.py  |  6 +++---
 .../parameters/tests/test_core_mco_parameters.py  |  4 ++--
 11 files changed, 53 insertions(+), 29 deletions(-)
 rename force_bdss/{mco/parameters/core_mco_parameters.py => core_plugins/dummy/dummy_dakota/parameters.py} (58%)

diff --git a/force_bdss/base_core_driver.py b/force_bdss/base_core_driver.py
index 6b43b4c..e17d586 100644
--- a/force_bdss/base_core_driver.py
+++ b/force_bdss/base_core_driver.py
@@ -6,10 +6,9 @@ from .bundle_registry_plugin import (
     BUNDLE_REGISTRY_PLUGIN_ID
 )
 from .io.workflow_reader import WorkflowReader
-from .workspecs.workflow import Workflow
 from .mco.parameters.mco_parameter_factory_registry import (
     MCOParameterFactoryRegistry)
-from .mco.parameters.core_mco_parameters import all_core_factories
+from .workspecs.workflow import Workflow
 
 
 class BaseCoreDriver(Plugin):
@@ -31,8 +30,10 @@ class BaseCoreDriver(Plugin):
 
     def _parameter_factory_registry_default(self):
         registry = MCOParameterFactoryRegistry()
-        for f in all_core_factories():
-            registry.register(f)
+
+        for mco_bundle in self.bundle_registry.mco_bundles:
+            for factory in mco_bundle.parameter_factories():
+                registry.register(factory)
 
         return registry
 
diff --git a/force_bdss/bundle_registry_plugin.py b/force_bdss/bundle_registry_plugin.py
index 24bc5dd..aa22110 100644
--- a/force_bdss/bundle_registry_plugin.py
+++ b/force_bdss/bundle_registry_plugin.py
@@ -107,3 +107,13 @@ class BundleRegistryPlugin(Plugin):
 
         raise ValueError("Requested MCO {} but don't know how "
                          "to find it.".format(id))
+
+    def mco_parameters_by_id(self, mco_id, parameter_id):
+        mco_bundle = self.mco_bundle_by_id(mco_id)
+
+        for factory in mco_bundle.parameter_factories():
+            if factory.id == parameter_id:
+                return factory
+
+        raise ValueError("Requested MCO parameter {}:{} but don't know"
+                         " how to find it.".format(mco_id, parameter_id))
diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py b/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py
index 925a127..189c27c 100644
--- a/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py
+++ b/force_bdss/core_plugins/dummy/dummy_dakota/dakota_bundle.py
@@ -1,5 +1,7 @@
 from traits.api import String
 from force_bdss.api import bundle_id, BaseMCOBundle
+from force_bdss.core_plugins.dummy.dummy_dakota.parameters import \
+    RangedMCOParameterFactory
 
 from .dakota_communicator import DummyDakotaCommunicator
 from .dakota_model import DummyDakotaModel
@@ -21,3 +23,8 @@ class DummyDakotaBundle(BaseMCOBundle):
 
     def create_communicator(self):
         return DummyDakotaCommunicator(self)
+
+    def parameter_factories(self):
+        return [
+            RangedMCOParameterFactory(self)
+        ]
diff --git a/force_bdss/mco/parameters/core_mco_parameters.py b/force_bdss/core_plugins/dummy/dummy_dakota/parameters.py
similarity index 58%
rename from force_bdss/mco/parameters/core_mco_parameters.py
rename to force_bdss/core_plugins/dummy/dummy_dakota/parameters.py
index b9efa4e..6585312 100644
--- a/force_bdss/mco/parameters/core_mco_parameters.py
+++ b/force_bdss/core_plugins/dummy/dummy_dakota/parameters.py
@@ -1,7 +1,7 @@
 from traits.api import Float
 
-from ...ids import mco_parameter_id
-from .base_mco_parameter import BaseMCOParameter
+from force_bdss.ids import mco_parameter_id
+from force_bdss.mco.parameters.base_mco_parameter import BaseMCOParameter
 from force_bdss.mco.parameters.base_mco_parameter_factory import \
     BaseMCOParameterFactory
 
@@ -20,14 +20,3 @@ class RangedMCOParameterFactory(BaseMCOParameterFactory):
     model_class = RangedMCOParameter
     name = "Range"
     description = "A ranged parameter in floating point values."
-
-
-def all_core_factories():
-    """Produces a list of all factories contained in this module."""
-    import inspect
-
-    return [c()
-            for c in inspect.getmodule(all_core_factories).__dict__.values()
-            if inspect.isclass(c) and
-            c is not BaseMCOParameterFactory and
-            issubclass(c, BaseMCOParameterFactory)]
diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py
index e7bc0eb..97cf108 100644
--- a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py
+++ b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_communicator.py
@@ -13,7 +13,7 @@ from force_bdss.data_sources.data_source_parameters import DataSourceParameters
 
 from force_bdss.mco.parameters.base_mco_parameter_factory import \
     BaseMCOParameterFactory
-from force_bdss.mco.parameters.core_mco_parameters import RangedMCOParameter
+from force_bdss.core_plugins.dummy.dummy_dakota.parameters import RangedMCOParameter
 
 
 class TestDakotaCommunicator(unittest.TestCase):
diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py
index b0b9f85..40ce8cc 100644
--- a/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py
+++ b/force_bdss/core_plugins/dummy/dummy_dakota/tests/test_dakota_optimizer.py
@@ -1,10 +1,10 @@
 import unittest
 
+from force_bdss.core_plugins.dummy.dummy_dakota.parameters import RangedMCOParameter, \
+    RangedMCOParameterFactory
 from force_bdss.core_plugins.dummy.dummy_dakota.dakota_model import \
     DummyDakotaModel
 from force_bdss.mco.base_mco_bundle import BaseMCOBundle
-from force_bdss.mco.parameters.core_mco_parameters import RangedMCOParameter, \
-    RangedMCOParameterFactory
 
 try:
     import mock
diff --git a/force_bdss/io/workflow_reader.py b/force_bdss/io/workflow_reader.py
index e12e13f..e13f6ac 100644
--- a/force_bdss/io/workflow_reader.py
+++ b/force_bdss/io/workflow_reader.py
@@ -128,6 +128,7 @@ class WorkflowReader(HasStrictTraits):
         mco_bundle = registry.mco_bundle_by_id(mco_id)
         model_data = wf_data["mco"]["model_data"]
         model_data["parameters"] = self._extract_mco_parameters(
+            mco_id,
             model_data["parameters"])
         model = mco_bundle.create_model(
             wf_data["mco"]["model_data"])
@@ -182,7 +183,7 @@ class WorkflowReader(HasStrictTraits):
 
         return kpi_calculators
 
-    def _extract_mco_parameters(self, parameters_data):
+    def _extract_mco_parameters(self, mco_id, parameters_data):
         """Extracts the MCO parameters from the data as dictionary.
 
         Parameters
@@ -200,7 +201,7 @@ class WorkflowReader(HasStrictTraits):
 
         for p in parameters_data:
             id = p["id"]
-            factory = registry.get_factory_by_id(id)
+            factory = registry.get_factory(mco_id, id)
             model = factory.create_model(p["model_data"])
             parameters.append(model)
 
diff --git a/force_bdss/mco/base_mco_bundle.py b/force_bdss/mco/base_mco_bundle.py
index 5fe9fdb..c8e0d6d 100644
--- a/force_bdss/mco/base_mco_bundle.py
+++ b/force_bdss/mco/base_mco_bundle.py
@@ -68,3 +68,12 @@ class BaseMCOBundle(ABCHasStrictTraits):
         BaseMCOCommunicator
             An instance of the communicator
         """
+
+    @abc.abstractmethod
+    def parameter_factories(self):
+        """Returns the parameter factories supported by this MCO
+
+        Returns
+        -------
+        List of BaseMCOParameterFactory
+        """
diff --git a/force_bdss/mco/parameters/base_mco_parameter_factory.py b/force_bdss/mco/parameters/base_mco_parameter_factory.py
index 6124173..f74e325 100644
--- a/force_bdss/mco/parameters/base_mco_parameter_factory.py
+++ b/force_bdss/mco/parameters/base_mco_parameter_factory.py
@@ -1,5 +1,6 @@
-from traits.has_traits import HasStrictTraits
-from traits.trait_types import String, Type
+from traits.api import HasStrictTraits, String, Type, Instance
+
+from ..base_mco_bundle import BaseMCOBundle
 
 
 class BaseMCOParameterFactory(HasStrictTraits):
@@ -8,6 +9,8 @@ class BaseMCOParameterFactory(HasStrictTraits):
 
     Must be reimplemented for the specific parameter."""
 
+    bundle = Instance(BaseMCOBundle)
+
     #: A unique string identifying the parameter
     id = String()
 
@@ -20,6 +23,10 @@ class BaseMCOParameterFactory(HasStrictTraits):
     # The model class to instantiate when create_model is called.
     model_class = Type('BaseMCOParameter')
 
+    def __init__(self, bundle):
+        self.bundle = bundle
+        super(BaseMCOParameterFactory, self).__init__()
+
     def create_model(self, data_values=None):
         """Creates the instance of the model class and returns it.
         """
diff --git a/force_bdss/mco/parameters/mco_parameter_factory_registry.py b/force_bdss/mco/parameters/mco_parameter_factory_registry.py
index 709f7f9..1920ac0 100644
--- a/force_bdss/mco/parameters/mco_parameter_factory_registry.py
+++ b/force_bdss/mco/parameters/mco_parameter_factory_registry.py
@@ -10,12 +10,12 @@ class MCOParameterFactoryRegistry(HasStrictTraits):
     # Temp: this will become an extension point.
     factories = Dict(String, BaseMCOParameterFactory)
 
-    def get_factory_by_id(self, id):
+    def get_factory_by_id(self, bundle_id, parameter_factory_id):
         """Finds the factory by its id, so that we can obtain it as from
         the id in the model file.
         """
-        return self.factories[id]
+        return self.factories[(bundle_id, parameter_factory_id)]
 
     def register(self, factory):
         """Registers a new factory"""
-        self.factories[factory.id] = factory
+        self.factories[(factory.bundle.id, factory.id)] = factory
diff --git a/force_bdss/mco/parameters/tests/test_core_mco_parameters.py b/force_bdss/mco/parameters/tests/test_core_mco_parameters.py
index af6f4c9..8bff8fe 100644
--- a/force_bdss/mco/parameters/tests/test_core_mco_parameters.py
+++ b/force_bdss/mco/parameters/tests/test_core_mco_parameters.py
@@ -1,13 +1,13 @@
 import unittest
 
-from force_bdss.mco.parameters import core_mco_parameters
+from force_bdss.core_plugins.dummy.dummy_dakota import parameters
 from force_bdss.mco.parameters.base_mco_parameter_factory import \
     BaseMCOParameterFactory
 
 
 class TestCoreMCOParameters(unittest.TestCase):
     def test_all_factories(self):
-        factories = core_mco_parameters.all_core_factories()
+        factories = parameters.all_core_factories()
         self.assertEqual(len(factories), 1)
 
         for f in factories:
-- 
GitLab