Skip to content

Commit bfd1d94

Browse files
committed
added PySide GUI implementation for live updating plot and plotting sub-epoch monitor updates
1 parent 2ea428f commit bfd1d94

File tree

1 file changed

+128
-18
lines changed

1 file changed

+128
-18
lines changed

pylearn2/train_extensions/live_monitoring.py

+128-18
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,44 @@
22
Training extension for allowing querying of monitoring values while an
33
experiment executes.
44
"""
5-
__authors__ = "Dustin Webb"
5+
__authors__ = "Dustin Webb, Adam Stone"
66
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
7-
__credits__ = ["Dustin Webb"]
7+
__credits__ = ["Dustin Webb, Adam Stone"]
88
__license__ = "3-clause BSD"
99
__maintainer__ = "LISA Lab"
1010
__email__ = "pylearn-dev@googlegroups"
1111

1212
import copy
13+
import logging
14+
log = logging.getLogger('LiveMonitor')
1315

1416
try:
1517
import zmq
1618
zmq_available = True
17-
except:
19+
except Exception:
1820
zmq_available = False
1921

22+
try:
23+
from PySide import QtCore, QtGui
24+
25+
import sys
26+
import matplotlib
27+
import numpy as np
28+
matplotlib.use('Qt4Agg')
29+
matplotlib.rcParams['backend.qt4'] = 'PySide'
30+
31+
from matplotlib.backends.backend_qt4agg import (
32+
FigureCanvasQTAgg as FigureCanvas)
33+
from matplotlib.figure import Figure
34+
35+
qt_available = True
36+
except Exception:
37+
qt_available = False
38+
2039
try:
2140
import matplotlib.pyplot as plt
2241
pyplot_available = True
23-
except:
42+
except Exception:
2443
pyplot_available = False
2544

2645
from functools import wraps
@@ -345,29 +364,120 @@ def update_channels(self, channel_list, start=-1, end=-1, step=1):
345364
chan.time_record += rsp_chan.time_record
346365
chan.val_record += rsp_chan.val_record
347366

348-
def follow_channels(self, channel_list):
367+
def follow_channels(self, channel_list, use_qt=False):
349368
"""
350369
Tracks and plots a specified set of channels in real time.
351370
352371
Parameters
353372
----------
354373
channel_list : list
355374
A list of the channels for which data has been requested.
375+
use_qt : bool
376+
Use a PySide GUI for plotting, if available.
356377
"""
357-
if not pyplot_available:
378+
if use_qt:
379+
if not qt_available:
380+
log.warning(
381+
'follow_channels called with use_qt=True, but PySide '
382+
'is not available. Falling back on matplotlib ion().')
383+
else:
384+
# only create new qt app if running the first time in session
385+
if not hasattr(self, 'gui'):
386+
self.gui = LiveMonitorGUI(self, channel_list)
387+
388+
self.gui.channel_list = channel_list
389+
self.gui.start()
390+
391+
elif not pyplot_available:
358392
raise ImportError('pyplot needs to be installed for '
359393
'this functionality.')
360-
plt.clf()
361-
plt.ion()
362-
while True:
363-
self.update_channel(channel_list)
394+
else:
364395
plt.clf()
365-
for channel_name in self.channels:
366-
plt.plot(
367-
self.channels[channel_name].epoch_record,
368-
self.channels[channel_name].val_record,
369-
label=channel_name
370-
)
371-
plt.legend()
372396
plt.ion()
373-
plt.draw()
397+
while True:
398+
self.update_channels(channel_list)
399+
plt.clf()
400+
for channel_name in self.channels:
401+
plt.plot(
402+
self.channels[channel_name].epoch_record,
403+
self.channels[channel_name].val_record,
404+
label=channel_name
405+
)
406+
plt.legend()
407+
plt.ion()
408+
plt.draw()
409+
410+
if qt_available:
411+
class LiveMonitorGUI(QtGui.QMainWindow):
412+
def __init__(self, lm, channel_list):
413+
"""
414+
PySide GUI implementation for live monitoring channels.
415+
416+
Parameters
417+
----------
418+
lm : LiveMonitor instance
419+
The LiveMonitor instance to which the GUI belongs.
420+
421+
channel_list : list
422+
A list of the channels to display.
423+
"""
424+
self.app = QtGui.QApplication(sys.argv)
425+
426+
super(LiveMonitorGUI, self).__init__()
427+
self.lm = lm
428+
self.channel_list = channel_list
429+
self.updaterThread = UpdaterThread(lm, channel_list)
430+
self.updaterThread.updated.connect(self.refresh)
431+
self.initUI()
432+
433+
def initUI(self):
434+
self.resize(300, 200)
435+
self.fig = Figure(figsize=(300, 200), dpi=72,
436+
facecolor=(1, 1, 1), edgecolor=(0, 0, 0))
437+
self.ax = self.fig.add_subplot(111)
438+
self.canvas = FigureCanvas(self.fig)
439+
self.setCentralWidget(self.canvas)
440+
441+
def refresh(self):
442+
self.ax.cla() # clear previous plot
443+
for channel_name in self.channel_list:
444+
445+
X = epoch_record = self.lm.channels[channel_name].epoch_record
446+
Y = val_record = self.lm.channels[channel_name].val_record
447+
448+
indices = np.nonzero(np.diff(epoch_record))[0] + 1
449+
epoch_record_split = np.split(epoch_record, indices)
450+
val_record_split = np.split(val_record, indices)
451+
452+
X = np.zeros(len(epoch_record))
453+
Y = np.zeros(len(epoch_record))
454+
455+
for i, epoch in enumerate(epoch_record_split):
456+
457+
j = i*len(epoch_record_split[0])
458+
X[j: j + len(epoch)] = (
459+
1.*np.arange(len(epoch)) / len(epoch) + epoch[0])
460+
Y[j: j + len(epoch)] = val_record_split[i]
461+
462+
self.ax.plot(X, Y, label=channel_name)
463+
464+
self.ax.legend()
465+
self.canvas.draw()
466+
self.updaterThread.start()
467+
468+
def start(self):
469+
self.show()
470+
self.updaterThread.start()
471+
self.app.exec_()
472+
473+
class UpdaterThread(QtCore.QThread):
474+
updated = QtCore.Signal()
475+
476+
def __init__(self, lm, channel_list):
477+
super(UpdaterThread, self).__init__()
478+
self.lm = lm
479+
self.channel_list = channel_list
480+
481+
def run(self):
482+
self.lm.update_channels(self.channel_list) # blocking
483+
self.updated.emit()

0 commit comments

Comments
 (0)