diff --git a/force_bdss/data_sources/base_data_source_factory.py b/force_bdss/data_sources/base_data_source_factory.py index 698cdb5755730ad719895a3b2767c2790aa31f3b..1ae4571e38d8b729b25942ab9f9f033462ea9e0d 100644 --- a/force_bdss/data_sources/base_data_source_factory.py +++ b/force_bdss/data_sources/base_data_source_factory.py @@ -40,11 +40,11 @@ class BaseDataSourceFactory(ABCHasStrictTraits): name = Str() #: The data source to be instantiated. Define this to your DataSource - data_source_class = Type(BaseDataSource) + data_source_class = Type(BaseDataSource, allow_none=False) #: The model associated to the data source. #: Define this to your DataSourceModel - model_class = Type(BaseDataSourceModel) + model_class = Type(BaseDataSourceModel, allow_none=False) #: Reference to the plugin that carries this factory #: This is automatically set by the system. you should not define it @@ -59,7 +59,17 @@ class BaseDataSourceFactory(ABCHasStrictTraits): self.model_class = self.get_model_class() self.name = self.get_name() identifier = self.get_identifier() - self.id = factory_id(self.plugin.id, identifier) + try: + id = factory_id(self.plugin.id, identifier) + except ValueError: + raise ValueError( + "Invalid identifier {} returned by " + "{}.get_identifier()".format( + identifier, + self.__class__.__name__ + ) + ) + self.id = id def get_data_source_class(self): """Must be reimplemented to return the DataSource class. @@ -89,7 +99,7 @@ class BaseDataSourceFactory(ABCHasStrictTraits): to be unique across the plugin data sources. """ raise NotImplementedError( - "get_name was not implemented in factory {}".format( + "get_identifier was not implemented in factory {}".format( self.__class__)) def create_data_source(self): @@ -101,13 +111,6 @@ class BaseDataSourceFactory(ABCHasStrictTraits): BaseDataSource The specific instance of the generated DataSource """ - if self.data_source_class is None: - msg = ("data_source_class cannot be None in {}. Either define " - "data_source_class or reimplement create_data_source on " - "your factory class.".format(self.__class__.__name__)) - log.error(msg) - raise RuntimeError(msg) - return self.data_source_class(self) def create_model(self, model_data=None): @@ -130,11 +133,4 @@ class BaseDataSourceFactory(ABCHasStrictTraits): if model_data is None: model_data = {} - if self.model_class is None: - msg = ("model_class cannot be None in {}. Either define " - "model_class or reimplement create_model on your " - "factory class.".format(self.__class__.__name__)) - log.error(msg) - raise RuntimeError(msg) - return self.model_class(self, **model_data) diff --git a/force_bdss/data_sources/tests/test_base_data_source_factory.py b/force_bdss/data_sources/tests/test_base_data_source_factory.py index 3215cf8b908f6285a9caf2a643a0433765697e02..f67bcb35933d8a04df6ded5b5d67cdf204e3f01c 100644 --- a/force_bdss/data_sources/tests/test_base_data_source_factory.py +++ b/force_bdss/data_sources/tests/test_base_data_source_factory.py @@ -1,5 +1,7 @@ import unittest +from traits.trait_errors import TraitError + from force_bdss.data_sources.tests.test_base_data_source import DummyDataSource from force_bdss.data_sources.tests.test_base_data_source_model import \ DummyDataSourceModel @@ -17,46 +19,64 @@ from force_bdss.data_sources.base_data_source_factory import \ class DummyDataSourceFactory(BaseDataSourceFactory): - id = "foo" - - name = "bar" - - def create_data_source(self): - pass - - def create_model(self, model_data=None): - pass + def get_identifier(self): + return "foo" + def get_name(self): + return "bar" -class DummyDataSourceFactoryFast(BaseDataSourceFactory): - id = "foo" + def get_model_class(self): + return DummyDataSourceModel - name = "bar" - - model_class = DummyDataSourceModel - - data_source_class = DummyDataSource + def get_data_source_class(self): + return DummyDataSource class TestBaseDataSourceFactory(unittest.TestCase): + def setUp(self): + self.plugin = mock.Mock(spec=Plugin, id="pid") + def test_initialization(self): - factory = DummyDataSourceFactory(mock.Mock(spec=Plugin)) - self.assertEqual(factory.id, 'foo') + factory = DummyDataSourceFactory(self.plugin) + self.assertEqual(factory.id, 'pid.factory.foo') self.assertEqual(factory.name, 'bar') - - def test_fast_specification(self): - factory = DummyDataSourceFactoryFast(mock.Mock(spec=Plugin)) + self.assertEqual(factory.model_class, DummyDataSourceModel) + self.assertEqual(factory.data_source_class, DummyDataSource) self.assertIsInstance(factory.create_data_source(), DummyDataSource) self.assertIsInstance(factory.create_model(), DummyDataSourceModel) - def test_fast_specification_errors(self): - factory = DummyDataSourceFactoryFast(mock.Mock(spec=Plugin)) - factory.model_class = None - factory.data_source_class = None + def test_initialization_errors_invalid_identifier(self): + class Broken(DummyDataSourceFactory): + def get_identifier(self): + return None with testfixtures.LogCapture(): - with self.assertRaises(RuntimeError): - factory.create_data_source() + with self.assertRaises(ValueError): + Broken(self.plugin) - with self.assertRaises(RuntimeError): - factory.create_model() + def test_initialization_errors_invalid_name(self): + class Broken(DummyDataSourceFactory): + def get_name(self): + return None + + with testfixtures.LogCapture(): + with self.assertRaises(TraitError): + Broken(self.plugin) + + def test_initialization_errors_invalid_model_class(self): + class Broken(DummyDataSourceFactory): + def get_model_class(self): + return None + + with testfixtures.LogCapture(): + with self.assertRaises(TraitError): + Broken(self.plugin) + + def test_initialization_errors_invalid_data_source_class(self): + class Broken(DummyDataSourceFactory): + def get_data_source_class(self): + return None + + with testfixtures.LogCapture(): + with self.assertRaises(TraitError): + Broken(self.plugin)