From 0f25066253062a8e4911c2b657b16c89ec5f6854 Mon Sep 17 00:00:00 2001
From: Stefano Borini <sborini@enthought.com>
Date: Wed, 9 Aug 2017 15:14:52 +0100
Subject: [PATCH] Introduced events and their rendering

---
 force_bdss/core_mco_driver.py                 | 14 +++-------
 .../dummy/dummy_dakota/dakota_optimizer.py    | 15 +++++++----
 .../dummy_notification_listener.py            |  9 +++++--
 .../dummy/ui_notification/ui_notification.py  | 26 +++++++++++++++----
 force_bdss/mco/base_mco.py                    | 13 +++++-----
 force_bdss/mco/events.py                      | 19 ++++++++++++++
 utils/zmq_client.py                           |  8 +++---
 7 files changed, 71 insertions(+), 33 deletions(-)
 create mode 100644 force_bdss/mco/events.py

diff --git a/force_bdss/core_mco_driver.py b/force_bdss/core_mco_driver.py
index 6b33dea..36e0450 100644
--- a/force_bdss/core_mco_driver.py
+++ b/force_bdss/core_mco_driver.py
@@ -42,18 +42,10 @@ class CoreMCODriver(BaseCoreDriver):
         mco_factory = mco_model.factory
         return mco_factory.create_optimizer()
 
-    @on_trait_change("mco:started,mco:finished,mco:progress")
-    def _handle_mco_event(self, object, name, old, new):
-        if name == "started":
-            self._deliver_to_listeners("MCO_STARTED")
-        elif name == "finished":
-            self._deliver_to_listeners("MCO_FINISHED")
-        elif name == "progress":
-            self._deliver_to_listeners("MCO_PROGRESS")
-
-    def _deliver_to_listeners(self, message):
+    @on_trait_change("mco:event")
+    def _handle_mco_event(self, event):
         for listener in self.listeners:
-            listener.deliver(None, message)
+            listener.deliver(None, event)
 
     def _listeners_default(self):
         listeners = []
diff --git a/force_bdss/core_plugins/dummy/dummy_dakota/dakota_optimizer.py b/force_bdss/core_plugins/dummy/dummy_dakota/dakota_optimizer.py
index d97c414..9f9de5d 100644
--- a/force_bdss/core_plugins/dummy/dummy_dakota/dakota_optimizer.py
+++ b/force_bdss/core_plugins/dummy/dummy_dakota/dakota_optimizer.py
@@ -4,6 +4,8 @@ import itertools
 import collections
 
 from force_bdss.api import BaseMCO
+from force_bdss.mco.events import MCOStartEvent, MCOFinishEvent, \
+    MCOProgressEvent
 
 
 def rotated_range(start, stop, starting_value):
@@ -16,7 +18,8 @@ def rotated_range(start, stop, starting_value):
 
 class DummyDakotaOptimizer(BaseMCO):
     def run(self, model):
-        self.started = True
+        self.notify_event(MCOStartEvent())
+
         parameters = model.parameters
 
         values = []
@@ -41,8 +44,10 @@ class DummyDakotaOptimizer(BaseMCO):
 
             out = ps.communicate(
                 " ".join([str(v) for v in value]).encode("utf-8"))
-            print("{}: {}".format(" ".join([str(v) for v in value]),
-                                  out[0].decode("utf-8")))
-            self.progress = True
+            out_data = out[0].decode("utf-8").split()
+            self.notify_event(MCOProgressEvent(
+                input=tuple(value),
+                output=tuple(out_data),
+            ))
 
-        self.finished = True
+        self.notify_event(MCOFinishEvent())
diff --git a/force_bdss/core_plugins/dummy/dummy_notification_listener/dummy_notification_listener.py b/force_bdss/core_plugins/dummy/dummy_notification_listener/dummy_notification_listener.py
index 27fe660..6efc45a 100644
--- a/force_bdss/core_plugins/dummy/dummy_notification_listener/dummy_notification_listener.py
+++ b/force_bdss/core_plugins/dummy/dummy_notification_listener/dummy_notification_listener.py
@@ -1,9 +1,14 @@
 from force_bdss.api import BaseNotificationListener
+from force_bdss.mco.events import (
+    MCOStartEvent, MCOFinishEvent, MCOProgressEvent)
 
 
 class DummyNotificationListener(BaseNotificationListener):
-    def deliver(self, model, message):
-        print(message)
+    def deliver(self, model, event):
+        if isinstance(event, (MCOStartEvent, MCOFinishEvent)):
+            print(event.__class__.__name__)
+        elif isinstance(event, MCOProgressEvent):
+            print(event.__class__.__name__, event.input, event.output)
 
     def init_persistent_state(self, model):
         print("Initializing persistent state")
diff --git a/force_bdss/core_plugins/dummy/ui_notification/ui_notification.py b/force_bdss/core_plugins/dummy/ui_notification/ui_notification.py
index bbf7218..8e21584 100644
--- a/force_bdss/core_plugins/dummy/ui_notification/ui_notification.py
+++ b/force_bdss/core_plugins/dummy/ui_notification/ui_notification.py
@@ -1,10 +1,13 @@
 import errno
 import logging
-from traits.api import Any, List
+from traits.api import Any, List, Instance
 
 from force_bdss.api import BaseNotificationListener
 import zmq
 
+from force_bdss.mco.events import BaseMCOEvent, MCOStartEvent, MCOFinishEvent, \
+    MCOProgressEvent
+
 
 class UINotification(BaseNotificationListener):
     #: The ZMQ context.
@@ -20,7 +23,7 @@ class UINotification(BaseNotificationListener):
     #: The cache of messages as they are sent out.
     _msg_cache = List()
 
-    def deliver(self, model, message):
+    def deliver(self, model, event):
         try:
             data = self._rep_socket.recv(flags=zmq.NOBLOCK)
         except zmq.ZMQError as e:
@@ -33,9 +36,10 @@ class UINotification(BaseNotificationListener):
         if data and data[0:4] == "SYNC".encode("utf-8"):
             self._rep_socket.send_multipart(self._msg_cache)
 
-        msg = "ACTION {}".format(message).encode("utf-8")
-        self._msg_cache.append(msg)
-        self._pub_socket.send(msg)
+        msg = self._format_event(event)
+        if msg is not None:
+            self._msg_cache.append(msg)
+            self._pub_socket.send(msg)
 
     def init_persistent_state(self, model):
         self._context = zmq.Context()
@@ -44,3 +48,15 @@ class UINotification(BaseNotificationListener):
 
         self._rep_socket = self._context.socket(zmq.REP)
         self._rep_socket.bind("tcp://*:12346")
+
+    def _format_event(self, event):
+        if isinstance(event, MCOStartEvent):
+            data = "MCO_START"
+        elif isinstance(event, MCOFinishEvent):
+            data = "MCO_FINISH"
+        elif isinstance(event, MCOProgressEvent):
+            data = "MCO_PROGRESS {} {}".format(event.input, event.output)
+        else:
+            return None
+
+        return ("EVENT {}".format(data)).encode("utf-8")
diff --git a/force_bdss/mco/base_mco.py b/force_bdss/mco/base_mco.py
index 72b2b45..21d91bf 100644
--- a/force_bdss/mco/base_mco.py
+++ b/force_bdss/mco/base_mco.py
@@ -1,7 +1,8 @@
 import abc
 
-from traits.api import ABCHasStrictTraits, Instance, Event
+from traits.api import ABCHasStrictTraits, Instance, Event, Dict
 
+from force_bdss.mco.events import BaseMCOEvent
 from .i_mco_factory import IMCOFactory
 
 
@@ -13,11 +14,8 @@ class BaseMCO(ABCHasStrictTraits):
     #: A reference to the factory
     factory = Instance(IMCOFactory)
 
-    started = Event()
-
-    finished = Event()
-
-    progress = Event()
+    #: Must be triggered when an event occurs.
+    event = Event(BaseMCOEvent)
 
     def __init__(self, factory, *args, **kwargs):
         """Initializes the MCO.
@@ -41,3 +39,6 @@ class BaseMCO(ABCHasStrictTraits):
             An instance of the model information, as created from
             create_model()
         """
+
+    def notify_event(self, event):
+        self.event = event
diff --git a/force_bdss/mco/events.py b/force_bdss/mco/events.py
new file mode 100644
index 0000000..19a881b
--- /dev/null
+++ b/force_bdss/mco/events.py
@@ -0,0 +1,19 @@
+from traits.api import HasStrictTraits, Tuple
+
+
+class BaseMCOEvent(HasStrictTraits):
+    pass
+
+
+class MCOStartEvent(BaseMCOEvent):
+    pass
+
+
+class MCOFinishEvent(BaseMCOEvent):
+    pass
+
+
+class MCOProgressEvent(BaseMCOEvent):
+    input = Tuple()
+    output = Tuple()
+
diff --git a/utils/zmq_client.py b/utils/zmq_client.py
index 21e3687..ec54404 100644
--- a/utils/zmq_client.py
+++ b/utils/zmq_client.py
@@ -11,10 +11,10 @@ send_socket.connect("tcp://localhost:12346")
 send_socket.send("SYNC".encode("utf-8"))
 data = send_socket.recv_multipart()
 for d in data:
-    topic, messagedata = d.split()
-    print("SYNCED ", topic, messagedata)
+    split_data = d.split()
+    print("SYNCED ", split_data)
 
 while True:
     string = socket.recv()
-    topic, messagedata = string.split()
-    print(topic, messagedata)
+    split_data = string.split()
+    print(split_data)
-- 
GitLab