Skip to content
Snippets Groups Projects
Commit 7ebcd1f9 authored by Stefano Borini's avatar Stefano Borini
Browse files

upgraded base mco factory

parent d7e80315
No related branches found
No related tags found
1 merge request!130Safer plugin import - 2
...@@ -25,16 +25,16 @@ class BaseMCOFactory(ABCHasStrictTraits): ...@@ -25,16 +25,16 @@ class BaseMCOFactory(ABCHasStrictTraits):
name = Str() name = Str()
#: The optimizer class to instantiate. Define this to your MCO class. #: The optimizer class to instantiate. Define this to your MCO class.
optimizer_class = Type(BaseMCO) optimizer_class = Type(BaseMCO, allow_none=False)
#: The model associated to the MCO. Define this to your MCO model class. #: The model associated to the MCO. Define this to your MCO model class.
model_class = Type(BaseMCOModel) model_class = Type(BaseMCOModel, allow_none=False)
#: The communicator associated to the MCO. Define this to your MCO comm. #: The communicator associated to the MCO. Define this to your MCO comm.
communicator_class = Type(BaseMCOCommunicator) communicator_class = Type(BaseMCOCommunicator, allow_none=False)
#: A reference to the Plugin that holds this factory. #: A reference to the Plugin that holds this factory.
plugin = Instance(Plugin) plugin = Instance(Plugin, allow_none=False)
def __init__(self, plugin, *args, **kwargs): def __init__(self, plugin, *args, **kwargs):
self.plugin = plugin self.plugin = plugin
...@@ -45,7 +45,17 @@ class BaseMCOFactory(ABCHasStrictTraits): ...@@ -45,7 +45,17 @@ class BaseMCOFactory(ABCHasStrictTraits):
self.model_class = self.get_model_class() self.model_class = self.get_model_class()
self.communicator_class = self.get_communicator_class() self.communicator_class = self.get_communicator_class()
identifier = self.get_identifier() identifier = self.get_identifier()
self.id = factory_id(self.plugin.id, identifier) try:
id = factory_id(self.plugin.id, identifier)
except ValueError:
raise ValueError(
"Invalid identifier {} returned by "
"{}.get_identifier()".format(
identifier,
self.__class__.__name__
)
)
self.id = id
def get_optimizer_class(self): def get_optimizer_class(self):
raise NotImplementedError( raise NotImplementedError(
......
import unittest import unittest
import testfixtures from traits.trait_errors import TraitError
from force_bdss.mco.base_mco_model import BaseMCOModel from force_bdss.mco.base_mco_model import BaseMCOModel
from force_bdss.mco.tests.test_base_mco import DummyMCO from force_bdss.mco.tests.test_base_mco import DummyMCO
...@@ -18,18 +18,20 @@ from force_bdss.mco.base_mco_factory import BaseMCOFactory ...@@ -18,18 +18,20 @@ from force_bdss.mco.base_mco_factory import BaseMCOFactory
class DummyMCOFactory(BaseMCOFactory): class DummyMCOFactory(BaseMCOFactory):
id = "foo" def get_identifier(self):
return "foo"
name = "bar" def get_name(self):
return "bar"
def create_optimizer(self): def get_model_class(self):
pass return DummyMCOModel
def create_model(self, model_data=None): def get_communicator_class(self):
pass return DummyMCOCommunicator
def create_communicator(self): def get_optimizer_class(self):
pass return DummyMCO
def parameter_factories(self): def parameter_factories(self):
return [] return []
...@@ -39,26 +41,14 @@ class DummyMCOModel(BaseMCOModel): ...@@ -39,26 +41,14 @@ class DummyMCOModel(BaseMCOModel):
pass pass
class DummyMCOFactoryFast(BaseMCOFactory):
id = "foo"
name = "bar"
optimizer_class = DummyMCO
model_class = DummyMCOModel
communicator_class = DummyMCOCommunicator
class TestBaseMCOFactory(unittest.TestCase): class TestBaseMCOFactory(unittest.TestCase):
def setUp(self):
self.plugin = mock.Mock(spec=Plugin, id="pid")
def test_initialization(self): def test_initialization(self):
factory = DummyMCOFactory(mock.Mock(spec=Plugin)) factory = DummyMCOFactory(self.plugin)
self.assertEqual(factory.id, 'foo') self.assertEqual(factory.id, 'pid.factory.foo')
self.assertEqual(factory.name, 'bar') self.assertEqual(factory.name, 'bar')
def test_fast_definition(self):
factory = DummyMCOFactoryFast(mock.Mock(spec=Plugin))
self.assertIsInstance(factory.create_optimizer(), self.assertIsInstance(factory.create_optimizer(),
DummyMCO) DummyMCO)
self.assertIsInstance(factory.create_communicator(), self.assertIsInstance(factory.create_communicator(),
...@@ -66,18 +56,42 @@ class TestBaseMCOFactory(unittest.TestCase): ...@@ -66,18 +56,42 @@ class TestBaseMCOFactory(unittest.TestCase):
self.assertIsInstance(factory.create_model(), self.assertIsInstance(factory.create_model(),
DummyMCOModel) DummyMCOModel)
def test_fast_definition_errors(self): def test_broken_get_identifier(self):
factory = DummyMCOFactoryFast(mock.Mock(spec=Plugin)) class Broken(DummyMCOFactory):
factory.optimizer_class = None def get_identifier(self):
factory.model_class = None return None
factory.communicator_class = None
with self.assertRaises(ValueError):
Broken(self.plugin)
def test_broken_get_name(self):
class Broken(DummyMCOFactory):
def get_name(self):
return None
with self.assertRaises(TraitError):
Broken(self.plugin)
def test_broken_get_model_class(self):
class Broken(DummyMCOFactory):
def get_model_class(self):
return None
with self.assertRaises(TraitError):
Broken(self.plugin)
def test_broken_get_optimiser_class(self):
class Broken(DummyMCOFactory):
def get_optimizer_class(self):
return None
with testfixtures.LogCapture(): with self.assertRaises(TraitError):
with self.assertRaises(RuntimeError): Broken(self.plugin)
factory.create_optimizer()
with self.assertRaises(RuntimeError): def test_broken_get_communicator_class(self):
factory.create_communicator() class Broken(DummyMCOFactory):
def get_communicator_class(self):
return None
with self.assertRaises(RuntimeError): with self.assertRaises(TraitError):
factory.create_model() Broken(self.plugin)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment