Skip to content
Snippets Groups Projects
Unverified Commit cd3e16c4 authored by Stefano Borini's avatar Stefano Borini Committed by GitHub
Browse files

Merge pull request #114 from force-h2020/automatic-creation-of-objects

Introduces fast declaration style for factory
parents 766163d2 bf1c4a42
No related branches found
No related tags found
No related merge requests found
Showing
with 454 additions and 41 deletions
from traits.api import ABCHasStrictTraits, Instance
import abc import abc
from traits.api import ABCHasStrictTraits, Instance
from ..data_sources.i_data_source_factory import IDataSourceFactory from ..data_sources.i_data_source_factory import IDataSourceFactory
......
import abc import logging
from traits.api import ABCHasStrictTraits, provides, String, Instance from traits.api import ABCHasStrictTraits, provides, String, Instance, Type
from envisage.plugin import Plugin from envisage.plugin import Plugin
from .i_data_source_factory import IDataSourceFactory from force_bdss.data_sources.base_data_source import BaseDataSource
from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel
from force_bdss.data_sources.i_data_source_factory import IDataSourceFactory
log = logging.getLogger(__name__)
@provides(IDataSourceFactory) @provides(IDataSourceFactory)
...@@ -19,14 +23,22 @@ class BaseDataSourceFactory(ABCHasStrictTraits): ...@@ -19,14 +23,22 @@ class BaseDataSourceFactory(ABCHasStrictTraits):
#: A human readable name of the factory. Spaces allowed #: A human readable name of the factory. Spaces allowed
name = String() name = String()
#: The data source to be instantiated. Define this to your DataSource
data_source_class = Type(BaseDataSource)
#: The model associated to the data source.
#: Define this to your DataSourceModel
model_class = Type(BaseDataSourceModel)
#: Reference to the plugin that carries this factory #: Reference to the plugin that carries this factory
#: This is automatically set by the system. you should not define it
#: in your subclass.
plugin = Instance(Plugin) plugin = Instance(Plugin)
def __init__(self, plugin, *args, **kwargs): def __init__(self, plugin, *args, **kwargs):
self.plugin = plugin self.plugin = plugin
super(BaseDataSourceFactory, self).__init__(*args, **kwargs) super(BaseDataSourceFactory, self).__init__(*args, **kwargs)
@abc.abstractmethod
def create_data_source(self): def create_data_source(self):
"""Factory method. """Factory method.
Must return the factory-specific BaseDataSource instance. Must return the factory-specific BaseDataSource instance.
...@@ -36,8 +48,15 @@ class BaseDataSourceFactory(ABCHasStrictTraits): ...@@ -36,8 +48,15 @@ class BaseDataSourceFactory(ABCHasStrictTraits):
BaseDataSource BaseDataSource
The specific instance of the generated DataSource The specific instance of the generated DataSource
""" """
if self.data_source_class is None:
msg = ("data_source_class cannot be None in {}. Either define "
"data_source_class or reimplement create_data_source on "
"your factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.data_source_class(self)
@abc.abstractmethod
def create_model(self, model_data=None): def create_model(self, model_data=None):
"""Factory method. """Factory method.
Creates the model object (or network of model objects) of the KPI Creates the model object (or network of model objects) of the KPI
...@@ -55,3 +74,14 @@ class BaseDataSourceFactory(ABCHasStrictTraits): ...@@ -55,3 +74,14 @@ class BaseDataSourceFactory(ABCHasStrictTraits):
BaseDataSourceModel BaseDataSourceModel
The model The model
""" """
if model_data is None:
model_data = {}
if self.model_class is None:
msg = ("model_class cannot be None in {}. Either define "
"model_class or reimplement create_model on your "
"factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.model_class(self, **model_data)
from envisage.api import Plugin from envisage.api import Plugin
from traits.api import Interface, String, Instance from traits.api import Interface, String, Instance, Type
class IDataSourceFactory(Interface): class IDataSourceFactory(Interface):
...@@ -12,6 +12,14 @@ class IDataSourceFactory(Interface): ...@@ -12,6 +12,14 @@ class IDataSourceFactory(Interface):
name = String() name = String()
data_source_class = Type(
"force_bdss.data_sources.base_data_source.BaseDataSource"
)
model_class = Type(
"force_bdss.data_sources.base_data_source_model.BaseDataSourceModel"
)
plugin = Instance(Plugin) plugin = Instance(Plugin)
def create_data_source(self): def create_data_source(self):
......
import unittest import unittest
from force_bdss.data_sources.tests.test_base_data_source import DummyDataSource
from force_bdss.data_sources.tests.test_base_data_source_model import \
DummyDataSourceModel
try: try:
import mock import mock
except ImportError: except ImportError:
from unittest import mock from unittest import mock
import testfixtures
from envisage.plugin import Plugin from envisage.plugin import Plugin
from force_bdss.data_sources.base_data_source_factory import \ from force_bdss.data_sources.base_data_source_factory import \
BaseDataSourceFactory BaseDataSourceFactory
...@@ -21,8 +28,35 @@ class DummyDataSourceFactory(BaseDataSourceFactory): ...@@ -21,8 +28,35 @@ class DummyDataSourceFactory(BaseDataSourceFactory):
pass pass
class DummyDataSourceFactoryFast(BaseDataSourceFactory):
id = "foo"
name = "bar"
model_class = DummyDataSourceModel
data_source_class = DummyDataSource
class TestBaseDataSourceFactory(unittest.TestCase): class TestBaseDataSourceFactory(unittest.TestCase):
def test_initialization(self): def test_initialization(self):
factory = DummyDataSourceFactory(mock.Mock(spec=Plugin)) factory = DummyDataSourceFactory(mock.Mock(spec=Plugin))
self.assertEqual(factory.id, 'foo') self.assertEqual(factory.id, 'foo')
self.assertEqual(factory.name, 'bar') self.assertEqual(factory.name, 'bar')
def test_fast_specification(self):
factory = DummyDataSourceFactoryFast(mock.Mock(spec=Plugin))
self.assertIsInstance(factory.create_data_source(), DummyDataSource)
self.assertIsInstance(factory.create_model(), DummyDataSourceModel)
def test_fast_specification_errors(self):
factory = DummyDataSourceFactoryFast(mock.Mock(spec=Plugin))
factory.model_class = None
factory.data_source_class = None
with testfixtures.LogCapture():
with self.assertRaises(RuntimeError):
factory.create_data_source()
with self.assertRaises(RuntimeError):
factory.create_model()
import abc import logging
from envisage.plugin import Plugin from envisage.plugin import Plugin
from traits.api import ABCHasStrictTraits, provides, String, Instance from traits.api import ABCHasStrictTraits, provides, String, Instance, Type
from force_bdss.kpi.base_kpi_calculator import BaseKPICalculator
from force_bdss.kpi.base_kpi_calculator_model import BaseKPICalculatorModel
from .i_kpi_calculator_factory import IKPICalculatorFactory from .i_kpi_calculator_factory import IKPICalculatorFactory
log = logging.getLogger(__name__)
@provides(IKPICalculatorFactory) @provides(IKPICalculatorFactory)
class BaseKPICalculatorFactory(ABCHasStrictTraits): class BaseKPICalculatorFactory(ABCHasStrictTraits):
"""Base class for the Key Performance Indicator calculator factories. """Base class for the Key Performance Indicator calculator factories.
...@@ -20,6 +25,13 @@ class BaseKPICalculatorFactory(ABCHasStrictTraits): ...@@ -20,6 +25,13 @@ class BaseKPICalculatorFactory(ABCHasStrictTraits):
#: A UI friendly name for the factory. Can contain spaces. #: A UI friendly name for the factory. Can contain spaces.
name = String() name = String()
#: The KPI calculator to be instantiated. Define this to your KPICalculator
kpi_calculator_class = Type(BaseKPICalculator)
#: The model associated to the KPI calculator.
#: Define this to your KPICalculatorModel
model_class = Type(BaseKPICalculatorModel)
#: A reference to the plugin that holds this factory. #: A reference to the plugin that holds this factory.
plugin = Instance(Plugin) plugin = Instance(Plugin)
...@@ -34,7 +46,6 @@ class BaseKPICalculatorFactory(ABCHasStrictTraits): ...@@ -34,7 +46,6 @@ class BaseKPICalculatorFactory(ABCHasStrictTraits):
self.plugin = plugin self.plugin = plugin
super(BaseKPICalculatorFactory, self).__init__(*args, **kwargs) super(BaseKPICalculatorFactory, self).__init__(*args, **kwargs)
@abc.abstractmethod
def create_kpi_calculator(self): def create_kpi_calculator(self):
"""Factory method. """Factory method.
Creates and returns an instance of a KPI Calculator, associated Creates and returns an instance of a KPI Calculator, associated
...@@ -45,8 +56,15 @@ class BaseKPICalculatorFactory(ABCHasStrictTraits): ...@@ -45,8 +56,15 @@ class BaseKPICalculatorFactory(ABCHasStrictTraits):
BaseKPICalculator BaseKPICalculator
The specific instance of the generated KPICalculator The specific instance of the generated KPICalculator
""" """
if self.kpi_calculator_class is None:
msg = ("kpi_calculator_class cannot be None in {}. Either define "
"kpi_calculator_class or reimplement create_kpi_calculator "
"on your factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.kpi_calculator_class(self)
@abc.abstractmethod
def create_model(self, model_data=None): def create_model(self, model_data=None):
"""Factory method. """Factory method.
Creates the model object (or network of model objects) of the KPI Creates the model object (or network of model objects) of the KPI
...@@ -64,3 +82,14 @@ class BaseKPICalculatorFactory(ABCHasStrictTraits): ...@@ -64,3 +82,14 @@ class BaseKPICalculatorFactory(ABCHasStrictTraits):
BaseKPICalculatorModel BaseKPICalculatorModel
The model The model
""" """
if model_data is None:
model_data = {}
if self.model_class is None:
msg = ("model_class cannot be None in {}. Either define "
"model_class or reimplement create_model on your "
"factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.model_class(self, **model_data)
from traits.api import Interface, String, Instance from traits.api import Interface, String, Instance, Type
from envisage.plugin import Plugin from envisage.plugin import Plugin
...@@ -12,6 +12,14 @@ class IKPICalculatorFactory(Interface): ...@@ -12,6 +12,14 @@ class IKPICalculatorFactory(Interface):
name = String() name = String()
kpi_calculator_class = Type(
"force_bdss.kpi.base_kpi_calculator.BaseKPICalculator"
)
model_class = Type(
"force_bdss.kpi.base_kpi_calculator_model.BaseKPICalculatorModel"
)
plugin = Instance(Plugin) plugin = Instance(Plugin)
def create_kpi_calculator(self): def create_kpi_calculator(self):
......
import unittest import unittest
import testfixtures
from envisage.plugin import Plugin from envisage.plugin import Plugin
from force_bdss.kpi.tests.test_base_kpi_calculator import DummyKPICalculator
from force_bdss.kpi.tests.test_base_kpi_calculator_model import \
DummyKPICalculatorModel
try: try:
import mock import mock
except ImportError: except ImportError:
...@@ -22,8 +27,38 @@ class DummyKPICalculatorFactory(BaseKPICalculatorFactory): ...@@ -22,8 +27,38 @@ class DummyKPICalculatorFactory(BaseKPICalculatorFactory):
pass pass
class DummyKPICalculatorFactoryFast(BaseKPICalculatorFactory):
id = "foo"
name = "bar"
kpi_calculator_class = DummyKPICalculator
model_class = DummyKPICalculatorModel
class TestBaseKPICalculatorFactory(unittest.TestCase): class TestBaseKPICalculatorFactory(unittest.TestCase):
def test_initialization(self): def test_initialization(self):
factory = DummyKPICalculatorFactory(mock.Mock(spec=Plugin)) factory = DummyKPICalculatorFactory(mock.Mock(spec=Plugin))
self.assertEqual(factory.id, 'foo') self.assertEqual(factory.id, 'foo')
self.assertEqual(factory.name, 'bar') self.assertEqual(factory.name, 'bar')
def test_fast_definition(self):
factory = DummyKPICalculatorFactoryFast(mock.Mock(spec=Plugin))
self.assertIsInstance(factory.create_kpi_calculator(),
DummyKPICalculator)
self.assertIsInstance(factory.create_model(),
DummyKPICalculatorModel)
def test_fast_definition_errors(self):
factory = DummyKPICalculatorFactoryFast(mock.Mock(spec=Plugin))
factory.kpi_calculator_class = None
factory.model_class = None
with testfixtures.LogCapture():
with self.assertRaises(RuntimeError):
factory.create_kpi_calculator()
with self.assertRaises(RuntimeError):
factory.create_model()
import abc import logging
from traits.api import ABCHasStrictTraits, String, provides, Instance, Type
from traits.api import ABCHasStrictTraits, String, provides, Instance
from envisage.plugin import Plugin from envisage.plugin import Plugin
from force_bdss.mco.base_mco import BaseMCO
from force_bdss.mco.base_mco_communicator import BaseMCOCommunicator
from force_bdss.mco.base_mco_model import BaseMCOModel
from .i_mco_factory import IMCOFactory from .i_mco_factory import IMCOFactory
log = logging.getLogger(__name__)
@provides(IMCOFactory) @provides(IMCOFactory)
class BaseMCOFactory(ABCHasStrictTraits): class BaseMCOFactory(ABCHasStrictTraits):
...@@ -19,6 +23,15 @@ class BaseMCOFactory(ABCHasStrictTraits): ...@@ -19,6 +23,15 @@ class BaseMCOFactory(ABCHasStrictTraits):
#: A user friendly name of the factory. Spaces allowed. #: A user friendly name of the factory. Spaces allowed.
name = String() name = String()
#: The optimizer class to instantiate. Define this to your MCO class.
optimizer_class = Type(BaseMCO)
#: The model associated to the MCO. Define this to your MCO model class.
model_class = Type(BaseMCOModel)
#: The communicator associated to the MCO. Define this to your MCO comm.
communicator_class = Type(BaseMCOCommunicator)
#: A reference to the Plugin that holds this factory. #: A reference to the Plugin that holds this factory.
plugin = Instance(Plugin) plugin = Instance(Plugin)
...@@ -26,7 +39,6 @@ class BaseMCOFactory(ABCHasStrictTraits): ...@@ -26,7 +39,6 @@ class BaseMCOFactory(ABCHasStrictTraits):
self.plugin = plugin self.plugin = plugin
super(BaseMCOFactory, self).__init__(*args, **kwargs) super(BaseMCOFactory, self).__init__(*args, **kwargs)
@abc.abstractmethod
def create_optimizer(self): def create_optimizer(self):
"""Factory method. """Factory method.
Creates the optimizer with the given application Creates the optimizer with the given application
...@@ -34,11 +46,18 @@ class BaseMCOFactory(ABCHasStrictTraits): ...@@ -34,11 +46,18 @@ class BaseMCOFactory(ABCHasStrictTraits):
Returns Returns
------- -------
BaseMCOOptimizer BaseMCO
The optimizer The optimizer
""" """
if self.optimizer_class is None:
msg = ("optimizer_class cannot be None in {}. Either define "
"optimizer_class or reimplement create_optimizer on "
"your factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.optimizer_class(self)
@abc.abstractmethod
def create_model(self, model_data=None): def create_model(self, model_data=None):
"""Factory method. """Factory method.
Creates the model object (or network of model objects) of the MCO. Creates the model object (or network of model objects) of the MCO.
...@@ -57,8 +76,18 @@ class BaseMCOFactory(ABCHasStrictTraits): ...@@ -57,8 +76,18 @@ class BaseMCOFactory(ABCHasStrictTraits):
BaseMCOModel BaseMCOModel
The MCOModel The MCOModel
""" """
if model_data is None:
model_data = {}
if self.model_class is None:
msg = ("model_class cannot be None in {}. Either define "
"model_class or reimplement create_model on your "
"factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.model_class(self, **model_data)
@abc.abstractmethod
def create_communicator(self): def create_communicator(self):
"""Factory method. Returns the communicator class that allows """Factory method. Returns the communicator class that allows
exchange between the MCO and the evaluator code. exchange between the MCO and the evaluator code.
...@@ -68,8 +97,15 @@ class BaseMCOFactory(ABCHasStrictTraits): ...@@ -68,8 +97,15 @@ class BaseMCOFactory(ABCHasStrictTraits):
BaseMCOCommunicator BaseMCOCommunicator
An instance of the communicator An instance of the communicator
""" """
if self.communicator_class is None:
msg = ("communicator_class cannot be None in {}. Either define "
"communicator_class or reimplement create_communicator on "
"your factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.communicator_class(self)
@abc.abstractmethod
def parameter_factories(self): def parameter_factories(self):
"""Returns the parameter factories supported by this MCO """Returns the parameter factories supported by this MCO
......
from traits.api import Interface, String, Instance from traits.api import Interface, String, Instance, Type
from envisage.plugin import Plugin from envisage.plugin import Plugin
...@@ -12,6 +12,18 @@ class IMCOFactory(Interface): ...@@ -12,6 +12,18 @@ class IMCOFactory(Interface):
name = String() name = String()
optimizer_class = Type(
"force_bdss.mco.base_mco.BaseMCO"
)
model_class = Type(
"force_bdss.mco.base_mco_communicator.BaseMCOCommunicator"
)
communicator_class = Type(
"force_bdss.mco.base_mco_model.BaseMCOModel"
)
plugin = Instance(Plugin) plugin = Instance(Plugin)
def create_optimizer(self): def create_optimizer(self):
......
from traits.api import HasStrictTraits, String, Type, Instance from traits.api import HasStrictTraits, String, Type, Instance
from ..base_mco_factory import BaseMCOFactory
class BaseMCOParameterFactory(HasStrictTraits): class BaseMCOParameterFactory(HasStrictTraits):
"""Factory that produces the model instance of a given BASEMCOParameter """Factory that produces the model instance of a given BASEMCOParameter
...@@ -13,7 +11,7 @@ class BaseMCOParameterFactory(HasStrictTraits): ...@@ -13,7 +11,7 @@ class BaseMCOParameterFactory(HasStrictTraits):
""" """
#: A reference to the MCO factory this parameter factory lives in. #: A reference to the MCO factory this parameter factory lives in.
mco_factory = Instance(BaseMCOFactory) mco_factory = Instance('force_bdss.mco.base_mco_factory.BaseMCOFactory')
#: A unique string identifying the parameter #: A unique string identifying the parameter
id = String() id = String()
...@@ -25,7 +23,9 @@ class BaseMCOParameterFactory(HasStrictTraits): ...@@ -25,7 +23,9 @@ class BaseMCOParameterFactory(HasStrictTraits):
description = String("Undefined parameter") description = String("Undefined parameter")
# The model class to instantiate when create_model is called. # The model class to instantiate when create_model is called.
model_class = Type('BaseMCOParameter') model_class = Type(
"force_bdss.mco.parameters.base_mco_parameter.BaseMCOParameter"
)
def __init__(self, mco_factory, *args, **kwargs): def __init__(self, mco_factory, *args, **kwargs):
self.mco_factory = mco_factory self.mco_factory = mco_factory
......
import unittest import unittest
from envisage.plugin import Plugin
from force_bdss.mco.base_mco_factory import BaseMCOFactory from force_bdss.mco.base_mco_factory import BaseMCOFactory
try: try:
...@@ -25,9 +27,14 @@ class DummyMCOParameterFactory(BaseMCOParameterFactory): ...@@ -25,9 +27,14 @@ class DummyMCOParameterFactory(BaseMCOParameterFactory):
model_class = DummyMCOParameter model_class = DummyMCOParameter
class DummyMCOFactory(BaseMCOFactory):
pass
class TestBaseMCOParameterFactory(unittest.TestCase): class TestBaseMCOParameterFactory(unittest.TestCase):
def test_initialization(self): def test_initialization(self):
factory = DummyMCOParameterFactory(mock.Mock(spec=BaseMCOFactory)) factory = DummyMCOParameterFactory(
mco_factory=BaseMCOFactory(plugin=mock.Mock(spec=Plugin)))
model = factory.create_model({"x": 42}) model = factory.create_model({"x": 42})
self.assertIsInstance(model, DummyMCOParameter) self.assertIsInstance(model, DummyMCOParameter)
self.assertEqual(model.x, 42) self.assertEqual(model.x, 42)
......
import unittest import unittest
import testfixtures
from force_bdss.mco.base_mco_model import BaseMCOModel
from force_bdss.mco.tests.test_base_mco import DummyMCO
from force_bdss.mco.tests.test_base_mco_communicator import \
DummyMCOCommunicator
try: try:
import mock import mock
except ImportError: except ImportError:
...@@ -28,8 +35,49 @@ class DummyMCOFactory(BaseMCOFactory): ...@@ -28,8 +35,49 @@ class DummyMCOFactory(BaseMCOFactory):
return [] return []
class DummyMCOModel(BaseMCOModel):
pass
class DummyMCOFactoryFast(BaseMCOFactory):
id = "foo"
name = "bar"
optimizer_class = DummyMCO
model_class = DummyMCOModel
communicator_class = DummyMCOCommunicator
class TestBaseMCOFactory(unittest.TestCase): class TestBaseMCOFactory(unittest.TestCase):
def test_initialization(self): def test_initialization(self):
factory = DummyMCOFactory(mock.Mock(spec=Plugin)) factory = DummyMCOFactory(mock.Mock(spec=Plugin))
self.assertEqual(factory.id, 'foo') self.assertEqual(factory.id, 'foo')
self.assertEqual(factory.name, 'bar') self.assertEqual(factory.name, 'bar')
def test_fast_definition(self):
factory = DummyMCOFactoryFast(mock.Mock(spec=Plugin))
self.assertIsInstance(factory.create_optimizer(),
DummyMCO)
self.assertIsInstance(factory.create_communicator(),
DummyMCOCommunicator)
self.assertIsInstance(factory.create_model(),
DummyMCOModel)
def test_fast_definition_errors(self):
factory = DummyMCOFactoryFast(mock.Mock(spec=Plugin))
factory.optimizer_class = None
factory.model_class = None
factory.communicator_class = None
with testfixtures.LogCapture():
with self.assertRaises(RuntimeError):
factory.create_optimizer()
with self.assertRaises(RuntimeError):
factory.create_communicator()
with self.assertRaises(RuntimeError):
factory.create_model()
import abc import logging
from traits.api import ABCHasStrictTraits, Instance, String, provides, Type
from traits.api import ABCHasStrictTraits, Instance, String, provides
from envisage.plugin import Plugin from envisage.plugin import Plugin
from force_bdss.notification_listeners.base_notification_listener import \
BaseNotificationListener
from force_bdss.notification_listeners.base_notification_listener_model \
import \
BaseNotificationListenerModel
from .i_notification_listener_factory import INotificationListenerFactory from .i_notification_listener_factory import INotificationListenerFactory
log = logging.getLogger(__name__)
@provides(INotificationListenerFactory) @provides(INotificationListenerFactory)
class BaseNotificationListenerFactory(ABCHasStrictTraits): class BaseNotificationListenerFactory(ABCHasStrictTraits):
...@@ -18,6 +24,14 @@ class BaseNotificationListenerFactory(ABCHasStrictTraits): ...@@ -18,6 +24,14 @@ class BaseNotificationListenerFactory(ABCHasStrictTraits):
#: Name of the factory. User friendly for UI #: Name of the factory. User friendly for UI
name = String() name = String()
#: The listener class that must be instantiated. Define this to your
#: listener class.
listener_class = Type(BaseNotificationListener)
#: The associated model to the listener. Define this to your
#: listener model class.
model_class = Type(BaseNotificationListenerModel)
#: A reference to the containing plugin #: A reference to the containing plugin
plugin = Instance(Plugin) plugin = Instance(Plugin)
...@@ -32,13 +46,19 @@ class BaseNotificationListenerFactory(ABCHasStrictTraits): ...@@ -32,13 +46,19 @@ class BaseNotificationListenerFactory(ABCHasStrictTraits):
self.plugin = plugin self.plugin = plugin
super(BaseNotificationListenerFactory, self).__init__(*args, **kwargs) super(BaseNotificationListenerFactory, self).__init__(*args, **kwargs)
@abc.abstractmethod
def create_listener(self): def create_listener(self):
""" """
Creates an instance of the listener. Creates an instance of the listener.
""" """
if self.listener_class is None:
msg = ("listener_class cannot be None in {}. Either define "
"listener_class or reimplement create_listener on "
"your factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.listener_class(self)
@abc.abstractmethod
def create_model(self, model_data=None): def create_model(self, model_data=None):
""" """
Creates an instance of the model. Creates an instance of the model.
...@@ -48,3 +68,14 @@ class BaseNotificationListenerFactory(ABCHasStrictTraits): ...@@ -48,3 +68,14 @@ class BaseNotificationListenerFactory(ABCHasStrictTraits):
model_data: dict model_data: dict
Data to use to fill the model. Data to use to fill the model.
""" """
if model_data is None:
model_data = {}
if self.model_class is None:
msg = ("model_class cannot be None in {}. Either define "
"model_class or reimplement create_model on your "
"factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.model_class(self, **model_data)
from traits.api import Interface, String, Instance from traits.api import Interface, String, Instance, Type
from envisage.plugin import Plugin from envisage.plugin import Plugin
...@@ -12,6 +12,16 @@ class INotificationListenerFactory(Interface): ...@@ -12,6 +12,16 @@ class INotificationListenerFactory(Interface):
name = String() name = String()
listener_class = Type(
"force_bdss.notification_listeners"
".base_notification_listener.BaseNotificationListener"
)
model_class = Type(
"force_bdss.notification_listeners"
".base_notification_listener_model.BaseNotificationListenerModel"
)
plugin = Instance(Plugin) plugin = Instance(Plugin)
def create_listener(self): def create_listener(self):
......
import unittest
import testfixtures
from envisage.plugin import Plugin
try:
import mock
except ImportError:
from unittest import mock
from force_bdss.notification_listeners.base_notification_listener import \
BaseNotificationListener
from force_bdss.notification_listeners.base_notification_listener_factory \
import \
BaseNotificationListenerFactory
from force_bdss.notification_listeners.base_notification_listener_model \
import \
BaseNotificationListenerModel
class DummyNotificationListener(BaseNotificationListener):
def deliver(self, event):
pass
class DummyNotificationListenerModel(BaseNotificationListenerModel):
pass
class DummyNotificationListenerFactory(BaseNotificationListenerFactory):
id = "foo"
name = "bar"
def create_listener(self):
return DummyNotificationListener(self)
def create_model(self, model_data=None):
return DummyNotificationListenerModel(self)
class DummyNotificationListenerFactoryFast(BaseNotificationListenerFactory):
id = "foo"
name = "bar"
listener_class = DummyNotificationListener
model_class = DummyNotificationListenerModel
class TestBaseNotificationListenerFactory(unittest.TestCase):
def test_initialization(self):
factory = DummyNotificationListenerFactory(mock.Mock(spec=Plugin))
self.assertEqual(factory.id, 'foo')
self.assertEqual(factory.name, 'bar')
def test_fast_definition(self):
factory = DummyNotificationListenerFactoryFast(mock.Mock(spec=Plugin))
self.assertIsInstance(factory.create_listener(),
DummyNotificationListener)
self.assertIsInstance(factory.create_model(),
DummyNotificationListenerModel)
def test_fast_definition_errors(self):
factory = DummyNotificationListenerFactoryFast(mock.Mock(spec=Plugin))
factory.listener_class = None
factory.model_class = None
with testfixtures.LogCapture():
with self.assertRaises(RuntimeError):
factory.create_model()
with self.assertRaises(RuntimeError):
factory.create_listener()
import abc import logging
from traits.api import ABCHasStrictTraits, Instance, String, provides, Type
from traits.api import ABCHasStrictTraits, Instance, String, provides
from envisage.plugin import Plugin from envisage.plugin import Plugin
from force_bdss.ui_hooks.base_ui_hooks_manager import BaseUIHooksManager
from .i_ui_hooks_factory import IUIHooksFactory from .i_ui_hooks_factory import IUIHooksFactory
log = logging.getLogger(__name__)
@provides(IUIHooksFactory) @provides(IUIHooksFactory)
class BaseUIHooksFactory(ABCHasStrictTraits): class BaseUIHooksFactory(ABCHasStrictTraits):
...@@ -18,6 +20,10 @@ class BaseUIHooksFactory(ABCHasStrictTraits): ...@@ -18,6 +20,10 @@ class BaseUIHooksFactory(ABCHasStrictTraits):
#: Name of the factory. User friendly for UI #: Name of the factory. User friendly for UI
name = String() name = String()
#: The UI Hooks manager class to instantiate. Define this to your
#: base hook managers.
ui_hooks_manager_class = Type(BaseUIHooksManager)
#: A reference to the containing plugin #: A reference to the containing plugin
plugin = Instance(Plugin) plugin = Instance(Plugin)
...@@ -32,7 +38,6 @@ class BaseUIHooksFactory(ABCHasStrictTraits): ...@@ -32,7 +38,6 @@ class BaseUIHooksFactory(ABCHasStrictTraits):
self.plugin = plugin self.plugin = plugin
super(BaseUIHooksFactory, self).__init__(*args, **kwargs) super(BaseUIHooksFactory, self).__init__(*args, **kwargs)
@abc.abstractmethod
def create_ui_hooks_manager(self): def create_ui_hooks_manager(self):
"""Creates an instance of the hook manager. """Creates an instance of the hook manager.
The hooks manager contains a set of methods that are applicable in The hooks manager contains a set of methods that are applicable in
...@@ -42,3 +47,12 @@ class BaseUIHooksFactory(ABCHasStrictTraits): ...@@ -42,3 +47,12 @@ class BaseUIHooksFactory(ABCHasStrictTraits):
------- -------
BaseUIHooksManager BaseUIHooksManager
""" """
if self.ui_hooks_manager_class is None:
msg = ("ui_hooks_manager_class cannot be None in {}. Either "
"define ui_hooks_manager_class or reimplement "
"create_ui_hooks_manager on "
"your factory class.".format(self.__class__.__name__))
log.error(msg)
raise RuntimeError(msg)
return self.ui_hooks_manager_class(self)
from traits.api import Interface, String, Instance from traits.api import Interface, String, Instance, Type
from envisage.plugin import Plugin from envisage.plugin import Plugin
...@@ -12,6 +12,10 @@ class IUIHooksFactory(Interface): ...@@ -12,6 +12,10 @@ class IUIHooksFactory(Interface):
name = String() name = String()
ui_hooks_manager_class = Type(
"force_bdss.ui_hooks.base_ui_hooks_manager.BaseUIHooksManager"
)
plugin = Instance(Plugin) plugin = Instance(Plugin)
def create_hook_manager(self): def create_hook_manager(self):
......
import unittest import unittest
import testfixtures
from force_bdss.ui_hooks.tests.test_base_ui_hooks_manager import \
DummyUIHooksManager
try: try:
import mock import mock
except ImportError: except ImportError:
...@@ -9,13 +14,34 @@ from envisage.api import Plugin ...@@ -9,13 +14,34 @@ from envisage.api import Plugin
from ..base_ui_hooks_factory import BaseUIHooksFactory from ..base_ui_hooks_factory import BaseUIHooksFactory
class NullUIHooksFactory(BaseUIHooksFactory): class DummyUIHooksFactory(BaseUIHooksFactory):
def create_ui_hooks_manager(self): def create_ui_hooks_manager(self):
return None return DummyUIHooksManager(self)
class DummyUIHooksFactoryFast(BaseUIHooksFactory):
ui_hooks_manager_class = DummyUIHooksManager
class TestBaseUIHooksFactory(unittest.TestCase): class TestBaseUIHooksFactory(unittest.TestCase):
def test_initialize(self): def test_initialize(self):
mock_plugin = mock.Mock(spec=Plugin) mock_plugin = mock.Mock(spec=Plugin)
factory = NullUIHooksFactory(plugin=mock_plugin) factory = DummyUIHooksFactory(plugin=mock_plugin)
self.assertEqual(factory.plugin, mock_plugin) self.assertEqual(factory.plugin, mock_plugin)
def test_fast_definition(self):
mock_plugin = mock.Mock(spec=Plugin)
factory = DummyUIHooksFactoryFast(plugin=mock_plugin)
self.assertIsInstance(
factory.create_ui_hooks_manager(),
DummyUIHooksManager)
def test_fast_definition_errors(self):
mock_plugin = mock.Mock(spec=Plugin)
factory = DummyUIHooksFactoryFast(plugin=mock_plugin)
factory.ui_hooks_manager_class = None
with testfixtures.LogCapture():
with self.assertRaises(RuntimeError):
factory.create_ui_hooks_manager()
...@@ -8,9 +8,13 @@ except ImportError: ...@@ -8,9 +8,13 @@ except ImportError:
from unittest import mock from unittest import mock
class DummyUIHooksManager(BaseUIHooksManager):
pass
class TestBaseUIHooksManager(unittest.TestCase): class TestBaseUIHooksManager(unittest.TestCase):
def test_initialization(self): def test_initialization(self):
mock_factory = mock.Mock(spec=BaseUIHooksFactory) mock_factory = mock.Mock(spec=BaseUIHooksFactory)
mgr = BaseUIHooksManager(mock_factory) mgr = DummyUIHooksManager(mock_factory)
self.assertEqual(mgr.factory, mock_factory) self.assertEqual(mgr.factory, mock_factory)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment