Skip to content

Commit da8cb1a

Browse files
committed
fixed GUI hang and plotting sub-epoch monitor updates
1 parent 48ad0d8 commit da8cb1a

File tree

1 file changed

+77
-56
lines changed

1 file changed

+77
-56
lines changed

pylearn2/train_extensions/live_monitoring.py

+77-56
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,29 @@
22
Training extension for allowing querying of monitoring values while an
33
experiment executes.
44
"""
5-
__authors__ = "Dustin Webb"
5+
__authors__ = "Dustin Webb, Adam Stone"
66
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
7-
__credits__ = ["Dustin Webb"]
7+
__credits__ = ["Dustin Webb, Adam Stone"]
88
__license__ = "3-clause BSD"
99
__maintainer__ = "LISA Lab"
1010
__email__ = "pylearn-dev@googlegroups"
1111

1212
import copy
13+
import logging
14+
log = logging.getLogger('LiveMonitor')
1315

1416
try:
1517
import zmq
1618
zmq_available = True
17-
except:
19+
except Exception:
1820
zmq_available = False
1921

2022
try:
2123
from PySide import QtCore, QtGui
2224

2325
import sys
2426
import matplotlib
27+
import numpy as np
2528
matplotlib.use('Qt4Agg')
2629
matplotlib.rcParams['backend.qt4'] = 'PySide'
2730

@@ -30,13 +33,13 @@
3033
from matplotlib.figure import Figure
3134

3235
qt_available = True
33-
except:
36+
except Exception:
3437
qt_available = False
3538

3639
try:
3740
import matplotlib.pyplot as plt
3841
pyplot_available = True
39-
except:
42+
except Exception:
4043
pyplot_available = False
4144

4245
from functools import wraps
@@ -361,53 +364,48 @@ def update_channels(self, channel_list, start=-1, end=-1, step=1):
361364
chan.time_record += rsp_chan.time_record
362365
chan.val_record += rsp_chan.val_record
363366

364-
def follow_channels(self, channel_list):
367+
def follow_channels(self, channel_list, use_qt=False):
365368
"""
366369
Tracks and plots a specified set of channels in real time.
367370
368371
Parameters
369372
----------
370373
channel_list : list
371374
A list of the channels for which data has been requested.
375+
use_qt : bool
376+
Use a PySide GUI for plotting, if available.
372377
"""
373-
if not pyplot_available:
378+
if use_qt:
379+
if not qt_available:
380+
log.warning(
381+
'follow_channels called with use_qt=True, but PySide '
382+
'is not available. Falling back on matplotlib ion().')
383+
else:
384+
# only create new qt app if running the first time in session
385+
if not hasattr(self, 'gui'):
386+
self.gui = LiveMonitorGUI(self, channel_list)
387+
388+
self.gui.channel_list = channel_list
389+
self.gui.start()
390+
391+
elif not pyplot_available:
374392
raise ImportError('pyplot needs to be installed for '
375393
'this functionality.')
376-
plt.clf()
377-
plt.ion()
378-
while True:
379-
self.update_channels(channel_list)
394+
else:
380395
plt.clf()
381-
for channel_name in self.channels:
382-
plt.plot(
383-
self.channels[channel_name].epoch_record,
384-
self.channels[channel_name].val_record,
385-
label=channel_name
386-
)
387-
plt.legend()
388396
plt.ion()
389-
plt.draw()
390-
391-
def follow_channels_qt(self, channel_list):
392-
"""
393-
Tracks and plots a specified set of channels in real time using
394-
a PySide Qt GUI.
395-
396-
Parameters
397-
----------
398-
channel_list : list
399-
A list of the channels for which data has been requested.
400-
"""
401-
if not qt_available:
402-
raise ImportError('PySide needs to be installed for ' +
403-
'this functionality')
404-
405-
# only create qt app if running the first time
406-
if not hasattr(self, 'gui'):
407-
self.gui = LiveMonitorGUI(self, channel_list)
408-
409-
self.gui.channel_list = channel_list
410-
self.gui.start()
397+
while True:
398+
self.update_channels(channel_list)
399+
plt.clf()
400+
for channel_name in self.channels:
401+
plt.plot(
402+
self.channels[channel_name].epoch_record,
403+
self.channels[channel_name].val_record,
404+
label=channel_name
405+
)
406+
plt.legend()
407+
plt.ion()
408+
plt.draw()
411409

412410
if qt_available:
413411
class LiveMonitorGUI(QtGui.QMainWindow):
@@ -428,35 +426,58 @@ def __init__(self, lm, channel_list):
428426
super(LiveMonitorGUI, self).__init__()
429427
self.lm = lm
430428
self.channel_list = channel_list
429+
self.updaterThread = UpdaterThread(lm, channel_list)
430+
self.updaterThread.updated.connect(self.refresh)
431431
self.initUI()
432432

433433
def initUI(self):
434-
self.fig = Figure(figsize=(600, 600), dpi=72,
434+
self.resize(300, 200)
435+
self.fig = Figure(figsize=(300, 200), dpi=72,
435436
facecolor=(1, 1, 1), edgecolor=(0, 0, 0))
436437
self.ax = self.fig.add_subplot(111)
437438
self.canvas = FigureCanvas(self.fig)
438439
self.setCentralWidget(self.canvas)
439440

440-
def update(self):
441-
self.lm.update_channels(self.channel_list)
441+
def refresh(self):
442442
self.ax.cla() # clear previous plot
443443
for channel_name in self.channel_list:
444-
self.ax.plot(
445-
self.lm.channels[channel_name].epoch_record,
446-
self.lm.channels[channel_name].val_record,
447-
label=channel_name
448-
)
444+
445+
X = epoch_record = self.lm.channels[channel_name].epoch_record
446+
Y = val_record = self.lm.channels[channel_name].val_record
447+
448+
indices = np.nonzero(np.diff(epoch_record))[0] + 1
449+
epoch_record_split = np.split(epoch_record, indices)
450+
val_record_split = np.split(val_record, indices)
451+
452+
X = np.zeros(len(epoch_record))
453+
Y = np.zeros(len(epoch_record))
454+
455+
for i, epoch in enumerate(epoch_record_split):
456+
457+
j = i*len(epoch_record_split[0])
458+
X[j: j + len(epoch)] = (
459+
1.*np.arange(len(epoch)) / len(epoch) + epoch[0])
460+
Y[j: j + len(epoch)] = val_record_split[i]
461+
462+
self.ax.plot(X, Y, label=channel_name)
463+
449464
self.ax.legend()
450465
self.canvas.draw()
451-
452-
def closeEvent(self, event):
453-
self.updateTimer.stop()
454-
event.accept()
466+
self.updaterThread.start()
455467

456468
def start(self):
457-
self.updateTimer = QtCore.QTimer(self)
458-
self.updateTimer.timeout.connect(self.update)
459-
self.updateTimer.start(10000)
460469
self.show()
461-
self.update()
470+
self.updaterThread.start()
462471
self.app.exec_()
472+
473+
class UpdaterThread(QtCore.QThread):
474+
updated = QtCore.Signal()
475+
476+
def __init__(self, lm, channel_list):
477+
super(UpdaterThread, self).__init__()
478+
self.lm = lm
479+
self.channel_list = channel_list
480+
481+
def run(self):
482+
self.lm.update_channels(self.channel_list) # blocking
483+
self.updated.emit()

0 commit comments

Comments
 (0)