diff --git a/force_bdss/data_sources/base_data_source_model.py b/force_bdss/data_sources/base_data_source_model.py index d210f9492fb32518916802200028e273b286fd72..97a105028497ca9d878710a7fdb4ea97df4008b8 100644 --- a/force_bdss/data_sources/base_data_source_model.py +++ b/force_bdss/data_sources/base_data_source_model.py @@ -1,4 +1,6 @@ -from traits.api import ABCHasStrictTraits, Instance, List, Event +from traits.api import ( + ABCHasStrictTraits, Instance, List, Event, on_trait_change +) from force_bdss.core.input_slot_info import InputSlotInfo from force_bdss.core.output_slot_info import OutputSlotInfo @@ -49,3 +51,13 @@ class BaseDataSourceModel(ABCHasStrictTraits): x.__getstate__() for x in self.output_slot_info ] return state + + @on_trait_change("+changes_slots") + def _trigger_changes_slots(self, obj, name, new): + try: + changes_slots = self.traits()[name].changes_slots + except AttributeError: + return + + if changes_slots: + self.changes_slots = True diff --git a/force_bdss/data_sources/tests/test_base_data_source_model.py b/force_bdss/data_sources/tests/test_base_data_source_model.py index a9be54b713f94d073c2659aa001a5000758b41f9..893f730cc251a33661fe1474eabdafd9bb8f2a55 100644 --- a/force_bdss/data_sources/tests/test_base_data_source_model.py +++ b/force_bdss/data_sources/tests/test_base_data_source_model.py @@ -1,7 +1,10 @@ import unittest +from traits.api import Int +from traits.testing.api import UnittestTools from force_bdss.core.input_slot_info import InputSlotInfo from force_bdss.core.output_slot_info import OutputSlotInfo +from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel from force_bdss.tests.dummy_classes.data_source import DummyDataSourceModel try: @@ -13,9 +16,18 @@ from force_bdss.data_sources.base_data_source_factory import \ BaseDataSourceFactory -class TestBaseDataSourceModel(unittest.TestCase): +class ChangesSlotsModel(BaseDataSourceModel): + a = Int() + b = Int(changes_slots=True) + c = Int(changes_slots=False) + + +class TestBaseDataSourceModel(unittest.TestCase, UnittestTools): + def setUp(self): + self.mock_factory = mock.Mock(spec=BaseDataSourceFactory) + def test_getstate(self): - model = DummyDataSourceModel(mock.Mock(spec=BaseDataSourceFactory)) + model = DummyDataSourceModel(self.mock_factory) self.assertEqual( model.__getstate__(), { @@ -64,3 +76,16 @@ class TestBaseDataSourceModel(unittest.TestCase): } ] }) + + def test_changes_slots(self): + model = ChangesSlotsModel(self.mock_factory) + + with self.assertTraitDoesNotChange(model, "changes_slots"): + model.a = 5 + + with self.assertTraitChanges(model, "changes_slots"): + model.b = 5 + + with self.assertTraitDoesNotChange(model, "changes_slots"): + model.c = 5 +