Skip to content
Snippets Groups Projects
Commit 0ced8db3 authored by Stefano Borini's avatar Stefano Borini
Browse files

Handle id differently to prevent issuing of warn due to bug in traits

parent a6d1aa67
No related branches found
No related tags found
1 merge request!130Safer plugin import - 2
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import traceback import traceback
from envisage.plugin import Plugin from envisage.plugin import Plugin
from traits.api import List, Unicode, Bool, Type, Either from traits.api import List, Unicode, Bool, Type, Either, Instance
from force_bdss.data_sources.base_data_source_factory import \ from force_bdss.data_sources.base_data_source_factory import \
BaseDataSourceFactory BaseDataSourceFactory
...@@ -19,8 +19,6 @@ from .mco.i_mco_factory import IMCOFactory ...@@ -19,8 +19,6 @@ from .mco.i_mco_factory import IMCOFactory
from .ui_hooks.i_ui_hooks_factory import IUIHooksFactory from .ui_hooks.i_ui_hooks_factory import IUIHooksFactory
logger = logging.getLogger(__name__)
class BaseExtensionPlugin(Plugin): class BaseExtensionPlugin(Plugin):
"""Base class for extension plugins, that is, plugins that are """Base class for extension plugins, that is, plugins that are
...@@ -48,7 +46,10 @@ class BaseExtensionPlugin(Plugin): ...@@ -48,7 +46,10 @@ class BaseExtensionPlugin(Plugin):
broken = Bool(False) broken = Bool(False)
#: The error that have been generated by the instantiations. #: The error that have been generated by the instantiations.
error = Unicode() error_msg = Unicode()
#: The error that have been generated by the instantiations.
error_tb = Unicode()
#: A list of all the factory classes to export. #: A list of all the factory classes to export.
factory_classes = List( factory_classes = List(
...@@ -82,11 +83,31 @@ class BaseExtensionPlugin(Plugin): ...@@ -82,11 +83,31 @@ class BaseExtensionPlugin(Plugin):
contributes_to=ExtensionPointID.UI_HOOKS_FACTORIES contributes_to=ExtensionPointID.UI_HOOKS_FACTORIES
) )
#: The logger.
_logger = Instance(logging.Logger)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
broken = False
error = ""
if "id" not in kwargs:
try:
id_ = plugin_id(self.get_producer(), self.get_identifier())
except Exception as e:
self._logger.exception(e)
error = traceback.format_exc()
broken = True
else:
kwargs["id"] = id_
super(BaseExtensionPlugin, self).__init__(*args, **kwargs) super(BaseExtensionPlugin, self).__init__(*args, **kwargs)
if broken:
self.broken = True
self.error = error
return
try: try:
self.id = plugin_id(self.get_producer(), self.get_identifier())
self.factory_classes = self.get_factory_classes() self.factory_classes = self.get_factory_classes()
self.mco_factories[:] = [ self.mco_factories[:] = [
cls(self) cls(self)
...@@ -101,12 +122,12 @@ class BaseExtensionPlugin(Plugin): ...@@ -101,12 +122,12 @@ class BaseExtensionPlugin(Plugin):
] ]
self.ui_hooks_factories[:] = [ self.ui_hooks_factories[:] = [
cls(self) cls(self)
for cls in self._factory_by_type( for cls in self._factory_by_type(BaseUIHooksFactory)
BaseUIHooksFactory)
] ]
except Exception as e: except Exception as e:
self.error = traceback.format_exc() self._logger.exception(e)
logger.exception(e) self.error_msg = str(e)
self.error_tb = traceback.format_exc()
self.broken = True self.broken = True
self.mco_factories[:] = [] self.mco_factories[:] = []
self.data_source_factories[:] = [] self.data_source_factories[:] = []
...@@ -140,4 +161,13 @@ class BaseExtensionPlugin(Plugin): ...@@ -140,4 +161,13 @@ class BaseExtensionPlugin(Plugin):
self.__class__)) self.__class__))
def _factory_by_type(self, type_): def _factory_by_type(self, type_):
"""Returns all the factories of the given type"""
return [cls for cls in self.factory_classes if issubclass(cls, type_)] return [cls for cls in self.factory_classes if issubclass(cls, type_)]
def _id_default(self):
"""Override for base method that raises a warning we don't want to
show"""
return '%s.%s' % (type(self).__module__, type(self).__name__)
def __logger_default(self):
return logging.getLogger(self.__class__.__name__)
...@@ -12,4 +12,5 @@ class TestBaseExtensionPlugin(unittest.TestCase): ...@@ -12,4 +12,5 @@ class TestBaseExtensionPlugin(unittest.TestCase):
self.assertEqual(len(plugin.mco_factories), 1) self.assertEqual(len(plugin.mco_factories), 1)
self.assertEqual(len(plugin.ui_hooks_factories), 1) self.assertEqual(len(plugin.ui_hooks_factories), 1)
self.assertFalse(plugin.broken) self.assertFalse(plugin.broken)
self.assertEqual(plugin.error, "") self.assertEqual(plugin.error_msg, "")
self.assertEqual(plugin.error_tb, "")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment