Skip to content
Snippets Groups Projects
Commit 7ebcd1f9 authored by Stefano Borini's avatar Stefano Borini
Browse files

upgraded base mco factory

parent d7e80315
No related branches found
No related tags found
1 merge request!130Safer plugin import - 2
......@@ -25,16 +25,16 @@ class BaseMCOFactory(ABCHasStrictTraits):
name = Str()
#: The optimizer class to instantiate. Define this to your MCO class.
optimizer_class = Type(BaseMCO)
optimizer_class = Type(BaseMCO, allow_none=False)
#: The model associated to the MCO. Define this to your MCO model class.
model_class = Type(BaseMCOModel)
model_class = Type(BaseMCOModel, allow_none=False)
#: The communicator associated to the MCO. Define this to your MCO comm.
communicator_class = Type(BaseMCOCommunicator)
communicator_class = Type(BaseMCOCommunicator, allow_none=False)
#: A reference to the Plugin that holds this factory.
plugin = Instance(Plugin)
plugin = Instance(Plugin, allow_none=False)
def __init__(self, plugin, *args, **kwargs):
self.plugin = plugin
......@@ -45,7 +45,17 @@ class BaseMCOFactory(ABCHasStrictTraits):
self.model_class = self.get_model_class()
self.communicator_class = self.get_communicator_class()
identifier = self.get_identifier()
self.id = factory_id(self.plugin.id, identifier)
try:
id = factory_id(self.plugin.id, identifier)
except ValueError:
raise ValueError(
"Invalid identifier {} returned by "
"{}.get_identifier()".format(
identifier,
self.__class__.__name__
)
)
self.id = id
def get_optimizer_class(self):
raise NotImplementedError(
......
import unittest
import testfixtures
from traits.trait_errors import TraitError
from force_bdss.mco.base_mco_model import BaseMCOModel
from force_bdss.mco.tests.test_base_mco import DummyMCO
......@@ -18,18 +18,20 @@ from force_bdss.mco.base_mco_factory import BaseMCOFactory
class DummyMCOFactory(BaseMCOFactory):
id = "foo"
def get_identifier(self):
return "foo"
name = "bar"
def get_name(self):
return "bar"
def create_optimizer(self):
pass
def get_model_class(self):
return DummyMCOModel
def create_model(self, model_data=None):
pass
def get_communicator_class(self):
return DummyMCOCommunicator
def create_communicator(self):
pass
def get_optimizer_class(self):
return DummyMCO
def parameter_factories(self):
return []
......@@ -39,26 +41,14 @@ class DummyMCOModel(BaseMCOModel):
pass
class DummyMCOFactoryFast(BaseMCOFactory):
id = "foo"
name = "bar"
optimizer_class = DummyMCO
model_class = DummyMCOModel
communicator_class = DummyMCOCommunicator
class TestBaseMCOFactory(unittest.TestCase):
def setUp(self):
self.plugin = mock.Mock(spec=Plugin, id="pid")
def test_initialization(self):
factory = DummyMCOFactory(mock.Mock(spec=Plugin))
self.assertEqual(factory.id, 'foo')
factory = DummyMCOFactory(self.plugin)
self.assertEqual(factory.id, 'pid.factory.foo')
self.assertEqual(factory.name, 'bar')
def test_fast_definition(self):
factory = DummyMCOFactoryFast(mock.Mock(spec=Plugin))
self.assertIsInstance(factory.create_optimizer(),
DummyMCO)
self.assertIsInstance(factory.create_communicator(),
......@@ -66,18 +56,42 @@ class TestBaseMCOFactory(unittest.TestCase):
self.assertIsInstance(factory.create_model(),
DummyMCOModel)
def test_fast_definition_errors(self):
factory = DummyMCOFactoryFast(mock.Mock(spec=Plugin))
factory.optimizer_class = None
factory.model_class = None
factory.communicator_class = None
def test_broken_get_identifier(self):
class Broken(DummyMCOFactory):
def get_identifier(self):
return None
with self.assertRaises(ValueError):
Broken(self.plugin)
def test_broken_get_name(self):
class Broken(DummyMCOFactory):
def get_name(self):
return None
with self.assertRaises(TraitError):
Broken(self.plugin)
def test_broken_get_model_class(self):
class Broken(DummyMCOFactory):
def get_model_class(self):
return None
with self.assertRaises(TraitError):
Broken(self.plugin)
def test_broken_get_optimiser_class(self):
class Broken(DummyMCOFactory):
def get_optimizer_class(self):
return None
with testfixtures.LogCapture():
with self.assertRaises(RuntimeError):
factory.create_optimizer()
with self.assertRaises(TraitError):
Broken(self.plugin)
with self.assertRaises(RuntimeError):
factory.create_communicator()
def test_broken_get_communicator_class(self):
class Broken(DummyMCOFactory):
def get_communicator_class(self):
return None
with self.assertRaises(RuntimeError):
factory.create_model()
with self.assertRaises(TraitError):
Broken(self.plugin)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment