From 34436a8ef8568a7d1b09daf5ae1926f043ec1569 Mon Sep 17 00:00:00 2001
From: Stefano Borini <sborini@enthought.com>
Date: Thu, 17 May 2018 14:31:41 +0100
Subject: [PATCH] Changed data source to new design

---
 .../data_sources/base_data_source_factory.py  | 32 ++++----
 .../tests/test_base_data_source_factory.py    | 78 ++++++++++++-------
 2 files changed, 63 insertions(+), 47 deletions(-)

diff --git a/force_bdss/data_sources/base_data_source_factory.py b/force_bdss/data_sources/base_data_source_factory.py
index 698cdb5..1ae4571 100644
--- a/force_bdss/data_sources/base_data_source_factory.py
+++ b/force_bdss/data_sources/base_data_source_factory.py
@@ -40,11 +40,11 @@ class BaseDataSourceFactory(ABCHasStrictTraits):
     name = Str()
 
     #: The data source to be instantiated. Define this to your DataSource
-    data_source_class = Type(BaseDataSource)
+    data_source_class = Type(BaseDataSource, allow_none=False)
 
     #: The model associated to the data source.
     #: Define this to your DataSourceModel
-    model_class = Type(BaseDataSourceModel)
+    model_class = Type(BaseDataSourceModel, allow_none=False)
 
     #: Reference to the plugin that carries this factory
     #: This is automatically set by the system. you should not define it
@@ -59,7 +59,17 @@ class BaseDataSourceFactory(ABCHasStrictTraits):
         self.model_class = self.get_model_class()
         self.name = self.get_name()
         identifier = self.get_identifier()
-        self.id = factory_id(self.plugin.id, identifier)
+        try:
+            id = factory_id(self.plugin.id, identifier)
+        except ValueError:
+            raise ValueError(
+                "Invalid identifier {} returned by "
+                "{}.get_identifier()".format(
+                    identifier,
+                    self.__class__.__name__
+                )
+            )
+        self.id = id
 
     def get_data_source_class(self):
         """Must be reimplemented to return the DataSource class.
@@ -89,7 +99,7 @@ class BaseDataSourceFactory(ABCHasStrictTraits):
         to be unique across the plugin data sources.
         """
         raise NotImplementedError(
-            "get_name was not implemented in factory {}".format(
+            "get_identifier was not implemented in factory {}".format(
                 self.__class__))
 
     def create_data_source(self):
@@ -101,13 +111,6 @@ class BaseDataSourceFactory(ABCHasStrictTraits):
         BaseDataSource
             The specific instance of the generated DataSource
         """
-        if self.data_source_class is None:
-            msg = ("data_source_class cannot be None in {}. Either define "
-                   "data_source_class or reimplement create_data_source on "
-                   "your factory class.".format(self.__class__.__name__))
-            log.error(msg)
-            raise RuntimeError(msg)
-
         return self.data_source_class(self)
 
     def create_model(self, model_data=None):
@@ -130,11 +133,4 @@ class BaseDataSourceFactory(ABCHasStrictTraits):
         if model_data is None:
             model_data = {}
 
-        if self.model_class is None:
-            msg = ("model_class cannot be None in {}. Either define "
-                   "model_class or reimplement create_model on your "
-                   "factory class.".format(self.__class__.__name__))
-            log.error(msg)
-            raise RuntimeError(msg)
-
         return self.model_class(self, **model_data)
diff --git a/force_bdss/data_sources/tests/test_base_data_source_factory.py b/force_bdss/data_sources/tests/test_base_data_source_factory.py
index 3215cf8..f67bcb3 100644
--- a/force_bdss/data_sources/tests/test_base_data_source_factory.py
+++ b/force_bdss/data_sources/tests/test_base_data_source_factory.py
@@ -1,5 +1,7 @@
 import unittest
 
+from traits.trait_errors import TraitError
+
 from force_bdss.data_sources.tests.test_base_data_source import DummyDataSource
 from force_bdss.data_sources.tests.test_base_data_source_model import \
     DummyDataSourceModel
@@ -17,46 +19,64 @@ from force_bdss.data_sources.base_data_source_factory import \
 
 
 class DummyDataSourceFactory(BaseDataSourceFactory):
-    id = "foo"
-
-    name = "bar"
-
-    def create_data_source(self):
-        pass
-
-    def create_model(self, model_data=None):
-        pass
+    def get_identifier(self):
+        return "foo"
 
+    def get_name(self):
+        return "bar"
 
-class DummyDataSourceFactoryFast(BaseDataSourceFactory):
-    id = "foo"
+    def get_model_class(self):
+        return DummyDataSourceModel
 
-    name = "bar"
-
-    model_class = DummyDataSourceModel
-
-    data_source_class = DummyDataSource
+    def get_data_source_class(self):
+        return DummyDataSource
 
 
 class TestBaseDataSourceFactory(unittest.TestCase):
+    def setUp(self):
+        self.plugin = mock.Mock(spec=Plugin, id="pid")
+
     def test_initialization(self):
-        factory = DummyDataSourceFactory(mock.Mock(spec=Plugin))
-        self.assertEqual(factory.id, 'foo')
+        factory = DummyDataSourceFactory(self.plugin)
+        self.assertEqual(factory.id, 'pid.factory.foo')
         self.assertEqual(factory.name, 'bar')
-
-    def test_fast_specification(self):
-        factory = DummyDataSourceFactoryFast(mock.Mock(spec=Plugin))
+        self.assertEqual(factory.model_class, DummyDataSourceModel)
+        self.assertEqual(factory.data_source_class, DummyDataSource)
         self.assertIsInstance(factory.create_data_source(), DummyDataSource)
         self.assertIsInstance(factory.create_model(), DummyDataSourceModel)
 
-    def test_fast_specification_errors(self):
-        factory = DummyDataSourceFactoryFast(mock.Mock(spec=Plugin))
-        factory.model_class = None
-        factory.data_source_class = None
+    def test_initialization_errors_invalid_identifier(self):
+        class Broken(DummyDataSourceFactory):
+            def get_identifier(self):
+                return None
 
         with testfixtures.LogCapture():
-            with self.assertRaises(RuntimeError):
-                factory.create_data_source()
+            with self.assertRaises(ValueError):
+                Broken(self.plugin)
 
-            with self.assertRaises(RuntimeError):
-                factory.create_model()
+    def test_initialization_errors_invalid_name(self):
+        class Broken(DummyDataSourceFactory):
+            def get_name(self):
+                return None
+
+        with testfixtures.LogCapture():
+            with self.assertRaises(TraitError):
+                Broken(self.plugin)
+
+    def test_initialization_errors_invalid_model_class(self):
+        class Broken(DummyDataSourceFactory):
+            def get_model_class(self):
+                return None
+
+        with testfixtures.LogCapture():
+            with self.assertRaises(TraitError):
+                Broken(self.plugin)
+
+    def test_initialization_errors_invalid_data_source_class(self):
+        class Broken(DummyDataSourceFactory):
+            def get_data_source_class(self):
+                return None
+
+        with testfixtures.LogCapture():
+            with self.assertRaises(TraitError):
+                Broken(self.plugin)
-- 
GitLab