import logging
import traceback

from envisage.plugin import Plugin
from traits.api import List, Unicode, Bool, Type, Either, Instance

from force_bdss.data_sources.base_data_source_factory import \
    BaseDataSourceFactory
from force_bdss.mco.base_mco_factory import BaseMCOFactory
from force_bdss.notification_listeners.base_notification_listener_factory \
    import \
    BaseNotificationListenerFactory
from force_bdss.ui_hooks.base_ui_hooks_factory import BaseUIHooksFactory
from .notification_listeners.i_notification_listener_factory import \
    INotificationListenerFactory
from .ids import ExtensionPointID, plugin_id
from .data_sources.i_data_source_factory import IDataSourceFactory
from .mco.i_mco_factory import IMCOFactory
from .ui_hooks.i_ui_hooks_factory import IUIHooksFactory


class BaseExtensionPlugin(Plugin):
    """Base class for extension plugins, that is, plugins that are
    provided by external contributors.

    It provides a set of slots to be populated that end up contributing
    to the application extension points. To use the class, simply inherit it
    in your plugin, and reimplement the methods as from example::

        class MyPlugin(BaseExtensionPlugin):
            def get_producer(self):
                return "enthought"

            def get_identifier(self):
                return "myplugin"

            def get_factory_classes(self):
                return [
                    MyDataSourceFactory1,
                    MyDataSourceFactory2,
                    MyMCOFactory
                ]

    Errors raised while building the plugin id or while obtaining and
    instantiating the factory classes are not propagated: they are logged,
    ``broken`` is set to True, the error is recorded in ``error_msg`` and
    ``error_tb``, and no factories are exported.
    """
    #: Reports if the plugin loaded its factories successfully or not.
    broken = Bool(False)

    #: The message of the error that broke the plugin, if any.
    error_msg = Unicode()

    #: The formatted traceback of the error that broke the plugin, if any.
    error_tb = Unicode()

    #: A list of all the factory classes to export.
    factory_classes = List(
        Either(Type(BaseDataSourceFactory),
               Type(BaseMCOFactory),
               Type(BaseNotificationListenerFactory),
               Type(BaseUIHooksFactory))
    )

    #: A list of available Multi Criteria Optimizers this plugin exports.
    mco_factories = List(
        IMCOFactory,
        contributes_to=ExtensionPointID.MCO_FACTORIES
    )

    #: A list of the available Data Sources this plugin exports.
    data_source_factories = List(
        IDataSourceFactory,
        contributes_to=ExtensionPointID.DATA_SOURCE_FACTORIES
    )

    #: A list of the available notification listeners this plugin exports
    notification_listener_factories = List(
        INotificationListenerFactory,
        contributes_to=ExtensionPointID.NOTIFICATION_LISTENER_FACTORIES
    )

    #: A list of the available ui hooks this plugin exports
    ui_hooks_factories = List(
        IUIHooksFactory,
        contributes_to=ExtensionPointID.UI_HOOKS_FACTORIES
    )

    #: The logger.
    _logger = Instance(logging.Logger)

    def __init__(self, *args, **kwargs):
        """Initialize the plugin and instantiate its exported factories.

        Unless an ``id`` keyword argument is given, the plugin id is built
        from ``get_producer()`` and ``get_identifier()``. If that fails, the
        plugin is marked as broken and no factories are loaded.

        Otherwise, ``get_factory_classes()`` is called and every returned
        class is instantiated (with this plugin as its argument) into the
        extension list matching its base factory type. If anything in this
        phase raises, the plugin is marked as broken and all the exported
        factory lists are emptied.

        In both failure cases the exception is logged and recorded in
        ``error_msg`` / ``error_tb`` rather than propagated.
        """
        id_error_msg = None
        id_error_tb = None

        if "id" not in kwargs:
            try:
                id_ = plugin_id(self.get_producer(), self.get_identifier())
            except Exception as e:
                self._logger.exception(e)
                # format_exc() must be called while the exception is being
                # handled, so capture it here and apply it after init.
                id_error_msg = str(e)
                id_error_tb = traceback.format_exc()
            else:
                kwargs["id"] = id_

        super(BaseExtensionPlugin, self).__init__(*args, **kwargs)

        if id_error_tb is not None:
            # Record the id failure on the declared error traits, the same
            # way a factory loading failure is recorded below.
            self.error_msg = id_error_msg
            self.error_tb = id_error_tb
            self.broken = True
            return

        try:
            self.factory_classes = self.get_factory_classes()
            self.mco_factories[:] = [
                cls(self)
                for cls in self._factory_by_type(BaseMCOFactory)]
            self.data_source_factories[:] = [
                cls(self)
                for cls in self._factory_by_type(BaseDataSourceFactory)]
            self.notification_listener_factories[:] = [
                cls(self)
                for cls in self._factory_by_type(
                    BaseNotificationListenerFactory)
            ]
            self.ui_hooks_factories[:] = [
                cls(self)
                for cls in self._factory_by_type(BaseUIHooksFactory)
            ]
        except Exception as e:
            self._logger.exception(e)
            self.error_msg = str(e)
            self.error_tb = traceback.format_exc()
            self.broken = True
            # Do not export a partially populated set of factories.
            self.mco_factories[:] = []
            self.data_source_factories[:] = []
            self.notification_listener_factories[:] = []
            self.ui_hooks_factories[:] = []

    def get_producer(self):
        """Must be reimplemented to return a string with the name of the
        company producing this plugin. Examples are "enthought", "itwm" etc.
        """

        raise NotImplementedError(
            "get_producer was not implemented in plugin {}".format(
                self.__class__))

    def get_identifier(self):
        """Must return a string with the name of the plugin the producer
        is releasing. The name must be unique, and it is the responsibility
        of the producer to guarantee that it does not conflict with
        another already existing plugin.
        """
        raise NotImplementedError(
            "get_identifier was not implemented in plugin {}".format(
                self.__class__))

    def get_factory_classes(self):
        """Must return a list of factory classes that this plugin exports.
        """
        raise NotImplementedError(
            "get_factory_classes was not implemented in plugin {}".format(
                self.__class__))

    def _factory_by_type(self, type_):
        """Return the classes in ``factory_classes`` that subclass type_."""
        return [cls for cls in self.factory_classes if issubclass(cls, type_)]

    def _id_default(self):
        """Override for base method that raises a warning we don't want to
        show. The default id is "<module>.<ClassName>" of the concrete
        plugin class."""
        return '%s.%s' % (type(self).__module__, type(self).__name__)

    def __logger_default(self):
        """Default logger, named after the concrete plugin class."""
        return logging.getLogger(self.__class__.__name__)