diff --git a/force_bdss/data_sources/base_data_source_model.py b/force_bdss/data_sources/base_data_source_model.py index d210f9492fb32518916802200028e273b286fd72..a5df537c9faf7a7c8768dfc7d9e4c5b7a788ad97 100644 --- a/force_bdss/data_sources/base_data_source_model.py +++ b/force_bdss/data_sources/base_data_source_model.py @@ -1,4 +1,6 @@ -from traits.api import ABCHasStrictTraits, Instance, List, Event +from traits.api import ( + ABCHasStrictTraits, Instance, List, Event, on_trait_change +) from force_bdss.core.input_slot_info import InputSlotInfo from force_bdss.core.output_slot_info import OutputSlotInfo @@ -49,3 +51,10 @@ class BaseDataSourceModel(ABCHasStrictTraits): x.__getstate__() for x in self.output_slot_info ] return state + + @on_trait_change("+changes_slots") + def _trigger_changes_slots(self, obj, name, new): + changes_slots = self.traits()[name].changes_slots + + if changes_slots: + self.changes_slots = True diff --git a/force_bdss/data_sources/tests/test_base_data_source_model.py b/force_bdss/data_sources/tests/test_base_data_source_model.py index a9be54b713f94d073c2659aa001a5000758b41f9..dde26ea0cd99e43c66747f793842acdb468445b8 100644 --- a/force_bdss/data_sources/tests/test_base_data_source_model.py +++ b/force_bdss/data_sources/tests/test_base_data_source_model.py @@ -1,7 +1,10 @@ import unittest +from traits.api import Int +from traits.testing.api import UnittestTools from force_bdss.core.input_slot_info import InputSlotInfo from force_bdss.core.output_slot_info import OutputSlotInfo +from force_bdss.data_sources.base_data_source_model import BaseDataSourceModel from force_bdss.tests.dummy_classes.data_source import DummyDataSourceModel try: @@ -13,9 +16,18 @@ from force_bdss.data_sources.base_data_source_factory import \ BaseDataSourceFactory -class TestBaseDataSourceModel(unittest.TestCase): +class ChangesSlotsModel(BaseDataSourceModel): + a = Int() + b = Int(changes_slots=True) + c = Int(changes_slots=False) + + +class TestBaseDataSourceModel(unittest.TestCase, UnittestTools): + def setUp(self): + self.mock_factory = mock.Mock(spec=BaseDataSourceFactory) + def test_getstate(self): - model = DummyDataSourceModel(mock.Mock(spec=BaseDataSourceFactory)) + model = DummyDataSourceModel(self.mock_factory) self.assertEqual( model.__getstate__(), { @@ -64,3 +76,15 @@ class TestBaseDataSourceModel(unittest.TestCase): } ] }) + + def test_changes_slots(self): + model = ChangesSlotsModel(self.mock_factory) + + with self.assertTraitDoesNotChange(model, "changes_slots"): + model.a = 5 + + with self.assertTraitChanges(model, "changes_slots"): + model.b = 5 + + with self.assertTraitDoesNotChange(model, "changes_slots"): + model.c = 5