diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml
index 0c80283d..5a6a5503 100644
--- a/.github/workflows/CI.yaml
+++ b/.github/workflows/CI.yaml
@@ -20,7 +20,7 @@ jobs:
max-parallel: 1
fail-fast: false
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-latest, macos-latest] # [ubuntu-latest, macos-latest, windows-latest]
steps:
@@ -65,7 +65,7 @@ jobs:
- name: Test with pytest
run: |
- python -m pytest --cov=improv
+ python -m pytest -x -s -l --cov=improv
- name: Coveralls
uses: coverallsapp/github-action@v2
@@ -102,6 +102,6 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Close parallel build
- uses: coverallsapp/github-action@v1
+ uses: coverallsapp/github-action@v2
with:
parallel-finished: true
diff --git a/.gitignore b/.gitignore
index b76c5666..bedf9d58 100644
--- a/.gitignore
+++ b/.gitignore
@@ -128,6 +128,9 @@ dmypy.json
arrow
.vscode/
+.idea/
*.code-workspace
improv/_version.py
+
+*venv
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 00000000..13566b81
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 00000000..105ce2da
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 00000000..69d3d2cc
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 00000000..35eb1ddf
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/demos/1p_caiman/1p_demo.yaml b/demos/1p_caiman/1p_demo.yaml
index 6f754860..e60e469a 100644
--- a/demos/1p_caiman/1p_demo.yaml
+++ b/demos/1p_caiman/1p_demo.yaml
@@ -34,6 +34,3 @@ connections:
Processor.q_out: [Analysis.q_in]
Analysis.q_out: [Visual.q_in]
InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
\ No newline at end of file
diff --git a/demos/1p_caiman/actors/1p_processor.py b/demos/1p_caiman/actors/1p_processor.py
index 38305944..912290cd 100644
--- a/demos/1p_caiman/actors/1p_processor.py
+++ b/demos/1p_caiman/actors/1p_processor.py
@@ -58,13 +58,13 @@ def setup(self):
def stop(self):
super().stop()
- def runStep(self):
+ def run_step(self):
"""Run process. Runs once per frame.
Output is a location in the DS to continually
place the Estimates results, with ref number that
corresponds to the frame number
"""
- super().runStep()
+ super().run_step()
def putEstimates(self):
"""Put whatever estimates we currently have
diff --git a/demos/basic/Behavior_demo.py b/demos/basic/Behavior_demo.py
deleted file mode 100644
index 8117189c..00000000
--- a/demos/basic/Behavior_demo.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import logging
-
-# Matplotlib is overly verbose by default
-logging.getLogger("matplotlib").setLevel(logging.WARNING)
-from improv.nexus import Nexus
-
-loadFile = "./Behavior_demo.yaml"
-
-nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
-
-# All modules needed have been imported
-# so we can change the level of logging here
-# import logging
-# import logging.config
-# logging.config.dictConfig({
-# 'version': 1,
-# 'disable_existing_loggers': True,
-# })
-# logger = logging.getLogger("improv")
-# logger.setLevel(logging.INFO)
-
-nexus.startNexus()
diff --git a/demos/basic/Behavior_demo.yaml b/demos/basic/Behavior_demo.yaml
index 8f105507..992f1528 100644
--- a/demos/basic/Behavior_demo.yaml
+++ b/demos/basic/Behavior_demo.yaml
@@ -1,6 +1,3 @@
-settings:
- use_watcher: [Acquirer, Processor, Visual, Analysis, Behavior, Motion]
-
actors:
GUI:
package: actors.visual
diff --git a/demos/basic/actors/basic_processor.py b/demos/basic/actors/basic_processor.py
index 55013b37..e5568d66 100644
--- a/demos/basic/actors/basic_processor.py
+++ b/demos/basic/actors/basic_processor.py
@@ -61,7 +61,7 @@ def stop(self):
np.savetxt("output/timing/shape_time.txt", self.shape_time)
np.savetxt("output/timing/detect_time.txt", self.detect_time)
- def runStep(self):
+ def run_step(self):
"""Run process. Runs once per frame.
Output is a location in the DS to continually
place the Estimates results, with ref number that
diff --git a/demos/basic/actors/behavior.py b/demos/basic/actors/behavior.py
index 2486b2be..0e4b14e3 100644
--- a/demos/basic/actors/behavior.py
+++ b/demos/basic/actors/behavior.py
@@ -41,7 +41,7 @@ def setup(self):
else:
raise FileNotFoundError
- def runStep(self):
+ def run_step(self):
if self.done:
pass
elif self.frame_num < self.data.shape[2]:
@@ -95,7 +95,7 @@ def setup(self):
else:
raise FileNotFoundError
- def runStep(self):
+ def run_step(self):
if self.done:
pass
elif self.frame_num < len(self.data):
diff --git a/demos/basic/basic_demo.py b/demos/basic/basic_demo.py
deleted file mode 100644
index 953948ef..00000000
--- a/demos/basic/basic_demo.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import logging
-
-# Matplotlib is overly verbose by default
-logging.getLogger("matplotlib").setLevel(logging.WARNING)
-from improv.nexus import Nexus
-
-loadFile = "./basic_demo.yaml"
-
-nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
-
-# All modules needed have been imported
-# so we can change the level of logging here
-# import logging
-# import logging.config
-# logging.config.dictConfig({
-# 'version': 1,
-# 'disable_existing_loggers': True,
-# })
-# logger = logging.getLogger("improv")
-# logger.setLevel(logging.INFO)
-
-nexus.startNexus()
diff --git a/demos/basic/basic_demo.yaml b/demos/basic/basic_demo.yaml
index b007147c..1736bd64 100644
--- a/demos/basic/basic_demo.yaml
+++ b/demos/basic/basic_demo.yaml
@@ -33,6 +33,3 @@ connections:
Processor.q_out: [Analysis.q_in]
Analysis.q_out: [Visual.q_in]
InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
\ No newline at end of file
diff --git a/demos/bubblewrap/README.md b/demos/bubblewrap/README.md
index 5b570d24..6889d6b5 100644
--- a/demos/bubblewrap/README.md
+++ b/demos/bubblewrap/README.md
@@ -4,13 +4,13 @@ After installing _improv_, install additional dependencies (JAX on CPU by defaul
- `pip install -r requirements.txt`
-To download the sample data from the paper, do:
+To download the sample data from the paper, run the following from the `improv` directory:
- `python demos/bubblewrap/actors/utils.py`
-This may take a few minutes. After data is downloaded, run the GUI with:
+This may take a few minutes. After data is downloaded, run the demo with:
-- `python demos/bubblewrap/bubble_demo.py`
+- `improv run demos/bubblewrap/bubble_demo.yaml`
A GUI will pop up with two buttons named "setup" and "run". First hit "setup" and wait ~5 seconds, then hit "run". Bubblewrap will perform dimensionality reduction of ~180 neurons to 2 dimensions, represented by grey dots popping up on the plot, and coarsely tile the space with red bubbles to represent transitions in the low-dimension space. All in real-time!
diff --git a/demos/bubblewrap/actors/acquire.py b/demos/bubblewrap/actors/acquire.py
index ca7ed53e..53435f67 100644
--- a/demos/bubblewrap/actors/acquire.py
+++ b/demos/bubblewrap/actors/acquire.py
@@ -4,17 +4,42 @@
import time
import logging
import traceback
-import time
+import sys, os
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
+def load_data(filename):
+ """
+ Load filename, which may be a relative path.
+ Return data.
+ """
+ if os.path.isfile(filename):
+ full_fname = filename
+ else:
+ for p in reversed(sys.path):
+ # traverse path in reverse order, on the theory that
+ # the yaml file comes near the end
+ full_fname = os.path.join(p, filename)
+ if os.path.isfile(full_fname):
+ break
+
+ try:
+ data = mat73.loadmat(full_fname)
+ except FileNotFoundError:
+ logger.error("Bubblewrap data file not found!")
+
+ return data
+
+
+
class Acquirer(Actor):
def __init__(self, *args, filename=None, **kwargs):
super().__init__(*args, **kwargs)
- if not filename: logger.error('Error: Filename not specified')
+ if not filename:
+ logger.error('Error: Filename not specified')
self.file = filename
self.frame_num = 0
self.done = False
@@ -28,7 +53,7 @@ def setup(self):
Note: A utility function that downloads the required data file can be found in utils.py
"""
# get unsorted vs sorted units
- data_dict = mat73.loadmat(self.file)
+ data_dict = load_data(self.file)
units_unsorted = []
units_sorted = []
for ch_curr in data_dict['spikes']:
@@ -58,7 +83,7 @@ def setup(self):
self.num_iters = np.floor((self.data.shape[0] - l1 - self.l)/self.l).astype('int')
#send to dim reduction
- init_id = self.client.put([self.data.shape[0], self.data[:l1, :]], "init_data")
+ init_id = self.client.put([self.data.shape[0], self.data[:l1, :]])
logger.info("Putted init data")
self.q_out.put(init_id)
@@ -66,7 +91,7 @@ def stop(self):
logger.info(f"Stopped running Acquire, avg time per frame: {np.mean(self.total_times)}")
logger.info(f"Acquire got through {self.frame_num} frames")
- def runStep(self):
+ def run_step(self):
"""Send data to dim reduction one frame at a time"""
if self.done:
pass
@@ -74,7 +99,7 @@ def runStep(self):
start, end = self.t, self.t + 1
frame = self.data[start:end, :]
t = time.time()
- id = self.client.put([self.t, frame], "acq_bubble" + str(self.frame_num))
+ id = self.client.put([self.t, frame])
self.timestamp.append([time.time(), self.frame_num])
try:
self.q_out.put([str(self.frame_num), id])
diff --git a/demos/bubblewrap/actors/bubble.py b/demos/bubblewrap/actors/bubble.py
index c63c1e3c..c2ecd7e5 100644
--- a/demos/bubblewrap/actors/bubble.py
+++ b/demos/bubblewrap/actors/bubble.py
@@ -4,7 +4,7 @@
from bubblewrap import Bubblewrap
from improv.actor import Actor
import logging
-
+logging.getLogger("jax").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -21,7 +21,10 @@ def setup(self):
while shape_id is None:
try:
shape_id = self.q_in.get(timeout=0.0005)
- except Empty: pass
+ except Empty:
+ pass
+ except TimeoutError:
+ pass
dat_shape_0 = self.client.get(shape_id)
# init bubblewrap
M = 20
@@ -47,23 +50,24 @@ def setup(self):
while id is None:
try:
id = self.q_in.get(timeout = 0.0005)
- except Empty: pass
- init_data = self.client.getID(id)
+ except Empty:
+ pass
+ except TimeoutError:
+ pass
+ init_data = self.client.get(id)
for i in np.arange(0, M):
self.bw.observe(init_data[i])
self.bw.init_nodes()
logger.info("Nodes initialized")
- self._getStoreInterface()
-
- def runStep(self):
+ def run_step(self):
"""Observe new data from dim reduction and update bubblewrap"""
try:
ids = self.q_in.get(timeout=0.0005)
# expect ids of size 2 containing data location and frame number
- new_data = self.client.getID(ids[1])
+ new_data = self.client.get(ids[1])
self.frame_number = ids[0]
self.bw.observe(new_data)
@@ -72,19 +76,17 @@ def runStep(self):
self.putOutput()
except Empty:
pass
+ except TimeoutError:
+ pass
def putOutput(self):
"""Function for putting updated results into the store"""
ids = []
- ids.append(self.client.put(np.array(self.bw.A), "A" + str(self.frame_number)))
- ids.append(self.client.put(np.array(self.bw.L), "L" + str(self.frame_number)))
- ids.append(self.client.put(np.array(self.bw.mu), "mu" + str(self.frame_number)))
- ids.append(self.client.put(
- np.array(self.bw.n_obs), "n_obs" + str(self.frame_number)))
- ids.append(self.client.put(
- np.array(self.bw.pred), "pred" + str(self.frame_number)))
- ids.append(self.client.put(
- np.array(self.bw.entropy_list), "entropy" + str(self.frame_number)))
- ids.append(self.client.put(
- np.array(self.bw.dead_nodes), "dead_nodes" + str(self.frame_number)))
+ ids.append(self.client.put(np.array(self.bw.A)))
+ ids.append(self.client.put(np.array(self.bw.L)))
+ ids.append(self.client.put(np.array(self.bw.mu)))
+ ids.append(self.client.put(np.array(self.bw.n_obs)))
+ ids.append(self.client.put(np.array(self.bw.pred)))
+ ids.append(self.client.put(np.array(self.bw.entropy_list)))
+ ids.append(self.client.put(np.array(self.bw.dead_nodes)))
self.q_out.put([self.frame_number, ids])
diff --git a/demos/bubblewrap/actors/dimension_reduction.py b/demos/bubblewrap/actors/dimension_reduction.py
index 2af87fca..7dd09089 100644
--- a/demos/bubblewrap/actors/dimension_reduction.py
+++ b/demos/bubblewrap/actors/dimension_reduction.py
@@ -1,6 +1,6 @@
from improv.actor import Actor
import numpy as np
-import scipy.signal as signal
+import scipy.signal.windows as signal
from proSVD import proSVD
from queue import Empty
import logging
@@ -20,13 +20,16 @@ def setup(self):
while init_id is None:
try:
init_id = self.q_in.get(timeout = 0.0005)
- except Empty: pass
+ except Empty:
+ pass
+ except TimeoutError:
+ pass
logger.info("Got init data")
- my_list = self.client.getID(init_id)
+ my_list = self.client.get(init_id)
dat_shape_0 = my_list[0]
dat_init = np.array(my_list[1])
- bw_id = self.client.put(dat_shape_0, "dat_shape_bw")
+ bw_id = self.client.put(dat_shape_0)
self.q_out.put(bw_id)
# proSVD params
@@ -51,18 +54,18 @@ def setup(self):
# storing dimension-reduced data
self.data_red = np.zeros((dat_shape_0, k))
self.data_red[:l1, :] = data_init_smooth @ self.pro.Q
- bw_id = self.client.put(self.data_red[:M], "bw_data")
+ bw_id = self.client.put(self.data_red[:M])
#send to bubblewrap
self.q_out.put(bw_id)
self.pro_diffs = []
self.smooth_window = dat_init[l1-len(self.smooth_filt):l1, :]
- def runStep(self):
+ def run_step(self):
"""update proSVD at each step using data from Acquirer and send to bubblewrap"""
try:
res = self.q_in.get(timeout=0.0005)
- data_curr = self.client.getID(res[1])[1]
- self.t = self.client.getID(res[1])[0]
+ data_curr = self.client.get(res[1])[1]
+ self.t = self.client.get(res[1])[0]
start, end = self.t, self.t+self.pro.w_len
self.smooth_window[:-1, :] = self.smooth_window[1:, :]
self.smooth_window[-1, :] = data_curr
@@ -76,7 +79,7 @@ def runStep(self):
self.data_red[start:end, :] = dat_smooth @ self.pro.Q
# send to bubblewrap
try:
- id = self.client.put(self.data_red[self.t], "dim_bubble" + str(self.t))
+ id = self.client.put(self.data_red[self.t])
self.q_out.put([int(self.t), id])
self.links['v_out'].put([int(self.t), id])
except Exception as e:
@@ -84,4 +87,6 @@ def runStep(self):
logger.error(traceback.format_exc())
except Empty:
return None
+ except TimeoutError:
+ return None
diff --git a/demos/bubblewrap/actors/front_end.py b/demos/bubblewrap/actors/front_end.py
index 9034c69b..0d8b430d 100644
--- a/demos/bubblewrap/actors/front_end.py
+++ b/demos/bubblewrap/actors/front_end.py
@@ -4,26 +4,20 @@
from PyQt5.QtGui import QColor
from . import improv_bubble
from improv.actor import Signal
-from math import atan2, floor
+from math import atan2
from PyQt5.QtWidgets import QMessageBox
-import logging
import traceback
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
class FrontEnd(QtWidgets.QMainWindow, improv_bubble.Ui_MainWindow):
- def __init__(self, visual, comm, q_sig, parent=None):
+ def __init__(self, state, parent=None):
"""Setup GUI
Setup and start Nexus controls
"""
- logger.info("Setup and start Nexus controls")
- self.visual = visual
- self.comm = comm # Link back to Nexus for transmitting signals
- self.q_sig = q_sig
+ state.logger.info("Setup and start Nexus controls")
+ self.state = state
self.prev = 0
self.n = 300
@@ -51,35 +45,41 @@ def __init__(self, visual, comm, q_sig, parent=None):
self.pushButton_2.clicked.connect(_call(self._runProcess))
self.pushButton_2.clicked.connect(_call(self.update)) # Tell Nexus to start
+ # Check for Nexus input on a timer
+ self.timer = QtCore.QTimer(self)
+ self.timer.timeout.connect(state.signal_check)
+ self.timer.setInterval(10) # Call every 10 ms
+ self.timer.start()
+
def update(self):
"""Check if get data is successful, call plotting function and update GUI"""
try:
- if self.visual.getData():
+ if self.state.getData():
self.plotBw()
except Exception as e:
- logger.error('Front End Exception: {}'.format(e))
- logger.error(traceback.format_exc())
+ self.state.logger.error('Front End Exception: {}'.format(e))
+ self.state.logger.error(traceback.format_exc())
QtCore.QTimer.singleShot(10, self.update)
def plotBw(self):
"""Function for plotting dim reduced trajectories and bubbles"""
self.plt.clear()
# Dim reduced data plotting
- newDat = np.array([self.visual.data[0], self.visual.data[1]])
+ newDat = np.array([self.state.data[0], self.state.data[1]])
self.data_red = np.vstack([self.data_red, newDat])
self.scatter.setData(pos=self.data_red)
self.plt.addItem(self.scatter)
# bubble plotting
- for n in np.arange(self.visual.bw_L.shape[0]):
- if n not in self.visual.bw_dead_nodes: #ignore dead nodes
- el = np.linalg.inv(self.visual.bw_L[n])
+ for n in np.arange(self.state.bw_L.shape[0]):
+ if n not in self.state.bw_dead_nodes: #ignore dead nodes
+ el = np.linalg.inv(self.state.bw_L[n])
sig = el.T @ el
u,s,v = np.linalg.svd(sig)
width, height = np.sqrt(s[0])*3, np.sqrt(s[1])*3
angle = atan2(v[0,1],v[0,0])*360 / (2*np.pi)
alpha_mat = 0.4
- x = self.visual.bw_mu[n,0]
- y = self.visual.bw_mu[n,1]
+ x = self.state.bw_mu[n,0]
+ y = self.state.bw_mu[n,1]
el = QtWidgets.QGraphicsEllipseItem(x-(width/2), y-(height/2), width, height, self.plt)
el.setBrush(pyqtgraph.mkBrush(QColor(237, 103, 19, int(alpha_mat/1*255))))
el.setPen(pyqtgraph.mkPen(None))
@@ -87,23 +87,22 @@ def plotBw(self):
el.setRotation(angle)
self.plt.addItem(el)
- mask = np.ones(self.visual.bw_mu.shape[0], dtype=bool)
- mask[self.visual.bw_n_obs < .1] = False
- mask[self.visual.bw_dead_nodes] = False
- self.bw_center.setData(x = self.visual.bw_mu[mask, 0], y = self.visual.bw_mu[mask, 1])
+ mask = np.ones(self.state.bw_mu.shape[0], dtype=bool)
+ mask[self.state.bw_n_obs < .1] = False
+ mask[self.state.bw_dead_nodes] = False
+ self.bw_center.setData(x = self.state.bw_mu[mask, 0], y = self.state.bw_mu[mask, 1])
self.plt.addItem(self.bw_center)
def _runProcess(self):
- logger.info("------------------------- put run in comm")
- self.comm.put([Signal.run()])
+ self.state.logger.info("GUI sent run command")
+ self.state.send(Signal.run())
def _setup(self):
- logger.info("------------------------- put setup in comm")
- self.comm.put([Signal.setup()])
- self.visual.setup()
+ self.state.logger.info("GUI sent setup command")
+ self.state.send(Signal.setup())
def closeEvent(self, event):
"""Clicked x/close on window
@@ -117,9 +116,9 @@ def closeEvent(self, event):
QMessageBox.No,
)
if confirm == QMessageBox.Yes:
- self.comm.put([Signal.quit()])
- # print('Visual broke, avg time per frame: ', np.mean(self.visual.total_times, axis=0))
- print("Visual got through ", self.visual.frame_num, " frames")
+ self.state.send(Signal.quit())
+ # print('Visual broke, avg time per frame: ', np.mean(self.state.total_times, axis=0))
+ self.state.logger.info("Visual got through ", self.state.frame_num, " frames")
# print('GUI avg time ', np.mean(self.total_times))
event.accept()
else:
diff --git a/demos/bubblewrap/actors/visual.py b/demos/bubblewrap/actors/visual.py
index f85eed74..bbbe61f7 100644
--- a/demos/bubblewrap/actors/visual.py
+++ b/demos/bubblewrap/actors/visual.py
@@ -1,8 +1,8 @@
-from improv.actor import Actor, Signal
+from improv.actor import Actor, RunManager
from PyQt5 import QtWidgets
-import numpy as np
from queue import Empty
from .front_end import FrontEnd
+from improv.messaging import ActorSignalMsg
import logging
import traceback
@@ -12,50 +12,81 @@
class Visual(Actor):
"""Class used to run a GUI + Visual as a single Actor"""
- def setup(self, visual):
- # self.visual is CaimanVisual
- self.visual = visual
- self.visual.setup()
- logger.info("Running setup for " + self.name)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if "name" in kwargs:
+ self.name = kwargs["name"]
+
+ self.state = GUIState(self)
def run(self):
- logger.info("Loading FrontEnd")
- self.app = QtWidgets.QApplication([])
- self.viewer = FrontEnd(self.visual, self.q_comm, self.q_sig)
- self.viewer.show()
- logger.info("GUI ready")
- self.q_comm.put([Signal.ready()])
- self.visual.q_comm.put([Signal.ready()])
- self.app.exec_()
- logger.info("Done running GUI")
+ self.setup_logging()
+ self.register_with_nexus()
+ self.state.add_logger(self.improv_logger)
+ self.state.add_signal_check()
+ self.register_with_broker()
+ self.setup_links()
+ self.q_comm = self.links["q_comm"]
+ self.q_sig = self.links["q_sig"]
-class CaimanVisual(Actor):
- """Class for displaying data from caiman processor"""
+ self.get_store_interface()
- def __init__(self, *args, showConnectivity=False):
- super().__init__(*args)
+ self.improv_logger.info("Loading FrontEnd")
+ self.app = QtWidgets.QApplication([])
+ self.viewer = FrontEnd(self.state)
+ self.viewer.show()
+ self.app.exec_()
+ self.improv_logger.info("Done running GUI")
+
+ def run_step(self):
+ self.viewer.update()
- def setup(self):
+class GUIState:
+ def __init__(self, gui):
+ self.gui = gui
self.data = None
+ self.bw_mu = None
self.bw_L = None
+ self.bw_dead_nodes = None
+ self.bw_n_obs = None
+ self.frame_num = None
- def run(self):
- pass # NOTE: Special case here, tied to GUI
-
+
def getData(self):
"""Load data from dim reduction and bubblewrap, returns false on timeout"""
try:
- bw_res = self.links['bw_in'].get(timeout=0.0005)
- res = self.q_in.get(timeout=0.0005)
- self.data = self.client.getID(res[1])
- self.bw_L = self.client.getID(bw_res[1][1])
- self.bw_mu = self.client.getID(bw_res[1][2])
- self.bw_n_obs = self.client.getID(bw_res[1][3])
- self.bw_dead_nodes = self.client.getID(bw_res[1][6])
- except Empty as e:
+ bw_res = self.gui.links['bw_in'].get(timeout=0.0005)
+ res = self.gui.q_in.get(timeout=0.0005)
+ self.data = self.gui.client.get(res[1])
+ self.bw_L = self.gui.client.get(bw_res[1][1])
+ self.bw_mu = self.gui.client.get(bw_res[1][2])
+ self.bw_n_obs = self.gui.client.get(bw_res[1][3])
+ self.bw_dead_nodes = self.gui.client.get(bw_res[1][6])
+ except Empty:
+ return False
+ except TimeoutError:
return False
except Exception as e:
- logger.error('Visual: Exception in get data: {}'.format(e))
- logger.error(traceback.format_exc())
+ self.logger.error('Visual: Exception in get data: {}'.format(e))
+ self.logger.error(traceback.format_exc())
return True
+
+ def send(self, signal):
+ actor_signal = ActorSignalMsg(
+ self.gui.name,
+ signal,
+ f"Sending signal {signal} to nexus",
+ )
+ return self.gui.q_comm.send(actor_signal)
+
+ def add_logger(self, logger):
+ self.logger = logger
+
+ def add_signal_check(self):
+ # make function to perform intermittent signal checking from Nexus
+ gui = self.gui
+ rm = RunManager(
+ gui.name, gui.actions, gui.links, gui.nexus_sig_port, self.logger
+ )
+ self.signal_check = rm.loop_logic
diff --git a/demos/bubblewrap/bubble_demo.yaml b/demos/bubblewrap/bubble_demo.yaml
index b0e49114..dd4e73e8 100644
--- a/demos/bubblewrap/bubble_demo.yaml
+++ b/demos/bubblewrap/bubble_demo.yaml
@@ -2,7 +2,7 @@ actors:
GUI:
package: actors.visual
class: Visual
- visual: Visual
+ method: "spawn"
Acquirer:
package: actors.acquire
@@ -18,13 +18,27 @@ actors:
package: actors.bubble
class: Bubble
- Visual:
- package: actors.visual
- class: CaimanVisual
-
connections:
- Acquirer.q_out: [DimReduction.q_in]
- DimReduction.q_out: [Bubblewrap.q_in]
- DimReduction.v_out: [Visual.q_in]
- Bubblewrap.q_out: [Visual.bw_in]
-
+ Acquirer-Reducer:
+ sources:
+ - Acquirer.q_out
+ sinks:
+ - DimReduction.q_in
+
+ Reducer-Wrapper:
+ sources:
+ - DimReduction.q_out
+ sinks:
+ - Bubblewrap.q_in
+
+ Reducer-GUI:
+ sources:
+ - DimReduction.v_out
+ sinks:
+ - GUI.q_in
+
+ Bubblewrapper-GUI:
+ sources:
+ - Bubblewrap.q_out
+ sinks:
+ - GUI.bw_in
diff --git a/demos/bubblewrap/bubble_demo_spawn.yaml b/demos/bubblewrap/bubble_demo_spawn.yaml
deleted file mode 100644
index 3e9be24f..00000000
--- a/demos/bubblewrap/bubble_demo_spawn.yaml
+++ /dev/null
@@ -1,30 +0,0 @@
-actors:
- GUI:
- package: actors.visual
- class: Visual
- visual: Visual
-
- Acquirer:
- package: actors.acquire
- class: Acquirer
- filename: data/indy_20160407_02.mat
-
- DimReduction:
- package: actors.dimension_reduction
- class: DimReduction
-
- Bubblewrap:
- package: actors.bubble
- class: Bubble
- dimension: 2
- method: spawn
-
- Visual:
- package: actors.visual
- class: CaimanVisual
-
-connections:
- Acquirer.q_out: [DimReduction.q_in]
- DimReduction.q_out: [Visual.q_in, Bubblewrap.q_in]
- Bubblewrap.q_out: [Visual.bw_in]
-
diff --git a/demos/live/live_demo.py b/demos/live/live_demo.py
deleted file mode 100644
index fff6839f..00000000
--- a/demos/live/live_demo.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import logging
-
-# Matplotlib is overly verbose by default
-logging.getLogger("matplotlib").setLevel(logging.WARNING)
-from improv.nexus import Nexus
-
-loadFile = "./live_demo.yaml"
-
-nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
-
-# All modules needed have been imported
-# so we can change the level of logging here
-# import logging
-# import logging.config
-# logging.config.dictConfig({
-# 'version': 1,
-# 'disable_existing_loggers': True,
-# })
-# logger = logging.getLogger("improv")
-# logger.setLevel(logging.INFO)
-
-nexus.startNexus()
diff --git a/demos/zmq/actors/zmq_ps_sample_generator.py b/demos/minimal/actors/sample_generator_zmq.py
similarity index 54%
rename from demos/zmq/actors/zmq_ps_sample_generator.py
rename to demos/minimal/actors/sample_generator_zmq.py
index 42b0afc4..e0a75655 100644
--- a/demos/zmq/actors/zmq_ps_sample_generator.py
+++ b/demos/minimal/actors/sample_generator_zmq.py
@@ -1,15 +1,14 @@
+from improv.actor import ZmqActor
+from datetime import date # used for saving
import numpy as np
import logging
-from demos.sample_actors.zmqActor import ZmqActor
-
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Generator(ZmqActor):
- """Sample actor to generate data to pass into a sample processor
- using sync ZMQ to communicate.
+ """Sample actor to generate data to pass into a sample processor.
Intended for use along with sample_processor.py.
"""
@@ -25,39 +24,53 @@ def __str__(self):
def setup(self):
"""Generates an array that serves as an initial source of data.
- Sets up a ZmqPSActor to send data to the processor.
Initial array is a 100 row, 5 column numpy matrix that contains
integers from 1-99, inclusive.
"""
- logger.info("Beginning setup for Generator")
self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
- logger.info("Completed setup for Generator")
+ self.improv_logger.info("Completed setup for Generator")
+
+ # def run(self):
+ # """ Send array into the store.
+ # """
+ # self.fcns = {}
+ # self.fcns['setup'] = self.setup
+ # self.fcns['run'] = self.run_step
+ # self.fcns['stop'] = self.stop
+
+ # with RunManager(self.name, self.fcns, self.links) as rm:
+ # logger.info(rm)
def stop(self):
"""Save current randint vector to a file."""
- logger.info("Generator stopping")
- np.save("sample_generator_data.npy", self.data)
+ self.improv_logger.info("Generator stopping")
+ np.save(f"sample_generator_data", self.data)
+ # This is not the best example of a save function,
+ # will overwrite previous files with the same name.
return 0
- def runStep(self):
+ def run_step(self):
"""Generates additional data after initial setup data is exhausted.
- Sends data to the processor using a ZmqPSActor.
Data is of a different form as the setup data in that although it is
the same size (5x1 vector), it is uniformly distributed in [1, 10]
instead of in [1, 100]. Therefore, the average over time should
converge to 5.5.
"""
+
if self.frame_num < np.shape(self.data)[0]:
- data_id = self.client.put(self.data[self.frame_num], str(f"Gen_raw: {self.frame_num}"))
+ data_id = self.client.put(self.data[self.frame_num])
try:
- self.put(data_id) #[data_id, str(self.frame_num)])
+ self.q_out.put(data_id)
+ # self.improv_logger.info(f"Sent {self.data[self.frame_num]} with key {data_id}")
self.frame_num += 1
+
except Exception as e:
- logger.error(f"---------Generator Exception: {e}")
+ self.improv_logger.error(f"Generator Exception: {e}")
else:
- new_data = np.asmatrix(np.random.randint(10, size=(1, 5)))
- self.data = np.concatenate((self.data, new_data), axis=0)
+ self.data = np.concatenate(
+ (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0
+ )
diff --git a/demos/minimal/actors/sample_persistence_generator.py b/demos/minimal/actors/sample_persistence_generator.py
new file mode 100644
index 00000000..5f6afbf1
--- /dev/null
+++ b/demos/minimal/actors/sample_persistence_generator.py
@@ -0,0 +1,78 @@
+import random
+import time
+
+from improv.actor import ZmqActor
+from datetime import date # used for saving
+import numpy as np
+import logging
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class Generator(ZmqActor):
+ """Sample actor to generate data to pass into a sample processor.
+
+ Intended for use along with sample_processor.py.
+ """
+
+ def __init__(self, output_filename, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.data = None
+ self.name = "Generator"
+ self.frame_num = 0
+ self.output_filename = output_filename
+
+ def __str__(self):
+ return f"Name: {self.name}, Data: {self.data}"
+
+ def setup(self):
+ """Generates an array that serves as an initial source of data.
+
+ Initial array is a 100 row, 5 column numpy matrix that contains
+ integers from 1-99, inclusive.
+ """
+
+ self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
+ self.improv_logger.info("Completed setup for Generator")
+
+ def stop(self):
+ """Save current randint vector to a file."""
+
+ self.improv_logger.info("Generator stopping")
+ return 0
+
+ def run_step(self):
+ """Generates additional data after initial setup data is exhausted.
+
+ Data is of a different form as the setup data in that although it is
+ the same size (5x1 vector), it is uniformly distributed in [1, 10]
+ instead of in [1, 100]. Therefore, the average over time should
+ converge to 5.5.
+ """
+
+
+ device_time = time.time_ns()
+ time.sleep(0.0003)
+ acquired_time = time.time_ns() # mock the time the generator "actually received" the data
+
+ if self.frame_num < np.shape(self.data)[0]:
+
+ with open(self.output_filename, "a+") as f:
+ device_data = self.data[self.frame_num]
+ packaged_data = (acquired_time, (device_time, device_data))
+ # save the data to the flat file before sending it downstream
+ f.write(f"{packaged_data[0]}, {packaged_data[1][0]}, {packaged_data[1][1]}\n")
+
+ data_id = self.client.put(packaged_data)
+ try:
+ self.q_out.put(data_id)
+ # self.improv_logger.info(f"Sent {self.data[self.frame_num]} with key {data_id}")
+ self.frame_num += 1
+
+ except Exception as e:
+ self.improv_logger.error(f"Generator Exception: {e}")
+ else:
+ self.data = np.concatenate(
+ (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0
+ )
diff --git a/demos/minimal/actors/sample_processor.py b/demos/minimal/actors/sample_persistence_processor.py
similarity index 61%
rename from demos/minimal/actors/sample_processor.py
rename to demos/minimal/actors/sample_persistence_processor.py
index 40fbf4d5..3a38d3fc 100644
--- a/demos/minimal/actors/sample_processor.py
+++ b/demos/minimal/actors/sample_persistence_processor.py
@@ -1,4 +1,4 @@
-from improv.actor import Actor
+from improv.actor import ZmqActor
import numpy as np
import logging
@@ -6,7 +6,7 @@
logger.setLevel(logging.INFO)
-class Processor(Actor):
+class Processor(ZmqActor):
"""Sample processor used to calculate the average of an array of integers.
Intended for use with sample_generator.py.
@@ -14,6 +14,8 @@ class Processor(Actor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ if "name" in kwargs:
+ self.name = kwargs["name"]
def setup(self):
"""Initializes all class variables.
@@ -24,18 +26,20 @@ def setup(self):
self.frame_num (int): index of current frame.
"""
- self.name = "Processor"
+ if not hasattr(self, "name"):
+ self.name = "Processor"
self.frame = None
self.avg_list = []
self.frame_num = 1
- logger.info("Completed setup for Processor")
+ self.improv_logger.info("Completed setup for Processor")
def stop(self):
"""Trivial stop function for testing purposes."""
- logger.info("Processor stopping")
+ self.improv_logger.info("Processor stopping")
+ return 0
- def runStep(self):
+ def run_step(self):
"""Gets from the input queue and calculates the average.
Receives an ObjectID, references data in the store using that
@@ -46,22 +50,17 @@ def runStep(self):
frame = None
try:
frame = self.q_in.get(timeout=0.05)
-
- except Exception:
- logger.error("Could not get frame!")
+ except Exception as e:
+ # logger.error(f"{self.name} could not get frame! At {self.frame_num}: {e}")
pass
if frame is not None and self.frame_num is not None:
self.done = False
- if self.store_loc:
- self.frame = self.client.getID(frame[0][0])
- else:
- self.frame = self.client.get(frame)
- avg = np.mean(self.frame[0])
-
- logger.info(f"Average: {avg}")
+ self.frame = self.client.get(frame)
+ device_data = self.frame[1][1]
+ avg = np.mean(device_data)
+ # self.improv_logger.info(f"Average: {avg}")
self.avg_list.append(avg)
- logger.info(f"Overall Average: {np.mean(self.avg_list)}")
- logger.info(f"Frame number: {self.frame_num}")
-
+ # self.improv_logger.info(f"Overall Average: {np.mean(self.avg_list)}")
+ # self.improv_logger.info(f"Frame number: {self.frame_num}")
self.frame_num += 1
diff --git a/demos/minimal/actors/sample_spawn_processor.py b/demos/minimal/actors/sample_processor_zmq.py
similarity index 64%
rename from demos/minimal/actors/sample_spawn_processor.py
rename to demos/minimal/actors/sample_processor_zmq.py
index cdc8bbb7..3b84a399 100644
--- a/demos/minimal/actors/sample_spawn_processor.py
+++ b/demos/minimal/actors/sample_processor_zmq.py
@@ -1,13 +1,13 @@
-from improv.actor import Actor
+from improv.actor import ZmqActor
import numpy as np
-from queue import Empty
import logging
+import time
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
-class Processor(Actor):
+class Processor(ZmqActor):
"""Sample processor used to calculate the average of an array of integers.
Intended for use with sample_generator.py.
@@ -15,6 +15,8 @@ class Processor(Actor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ if "name" in kwargs:
+ self.name = kwargs["name"]
def setup(self):
"""Initializes all class variables.
@@ -25,21 +27,20 @@ def setup(self):
self.frame_num (int): index of current frame.
"""
- self.name = "Processor"
+ if not hasattr(self, "name"):
+ self.name = "Processor"
self.frame = None
self.avg_list = []
self.frame_num = 1
- logger.info("Completed setup for Processor")
-
- self._getStoreInterface()
+ self.improv_logger.info("Completed setup for Processor")
def stop(self):
"""Trivial stop function for testing purposes."""
- logger.info("Processor stopping")
+ self.improv_logger.info("Processor stopping")
return 0
- def runStep(self):
+ def run_step(self):
"""Gets from the input queue and calculates the average.
Receives an ObjectID, references data in the store using that
@@ -50,23 +51,17 @@ def runStep(self):
frame = None
try:
frame = self.q_in.get(timeout=0.05)
- except Empty:
- pass
- except Exception:
- logger.error("Could not get frame!")
+ except Exception as e:
+ logger.error(f"{self.name} could not get frame! At {self.frame_num}: {e}")
pass
if frame is not None and self.frame_num is not None:
self.done = False
- if self.store_loc:
- self.frame = self.client.getID(frame[0][0])
- else:
- self.frame = self.client.get(frame)
+ self.frame = self.client.get(frame)
avg = np.mean(self.frame[0])
-
- logger.info(f"Average: {avg}")
+ self.improv_logger.info(f"Average: {avg}")
self.avg_list.append(avg)
- logger.info(f"Overall Average: {np.mean(self.avg_list)}")
- logger.info(f"Frame number: {self.frame_num}")
-
+ self.improv_logger.info(f"Overall Average: {np.mean(self.avg_list)}")
+ self.improv_logger.info(f"Frame number: {self.frame_num}")
self.frame_num += 1
+ time.sleep(1)
diff --git a/demos/minimal/minimal.yaml b/demos/minimal/minimal.yaml
index c8e0b24e..d752b415 100644
--- a/demos/minimal/minimal.yaml
+++ b/demos/minimal/minimal.yaml
@@ -1,14 +1,15 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
-
-redis_config:
- port: 6379
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
diff --git a/demos/minimal/minimal_persistence.yaml b/demos/minimal/minimal_persistence.yaml
new file mode 100644
index 00000000..65226762
--- /dev/null
+++ b/demos/minimal/minimal_persistence.yaml
@@ -0,0 +1,16 @@
+actors:
+ Generator:
+ package: actors.sample_persistence_generator
+ class: Generator
+ output_filename: test_persistence.csv
+
+ Processor:
+ package: actors.sample_persistence_processor
+ class: Processor
+
+connections:
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
diff --git a/demos/minimal/minimal_plasma.yaml b/demos/minimal/minimal_plasma.yaml
deleted file mode 100644
index ab869678..00000000
--- a/demos/minimal/minimal_plasma.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-actors:
- Generator:
- package: actors.sample_generator
- class: Generator
-
- Processor:
- package: actors.sample_processor
- class: Processor
-
-connections:
- Generator.q_out: [Processor.q_in]
-
-plasma_config:
\ No newline at end of file
diff --git a/demos/minimal/minimal_spawn.yaml b/demos/minimal/minimal_spawn.yaml
index 6f22e653..3bc2f677 100644
--- a/demos/minimal/minimal_spawn.yaml
+++ b/demos/minimal/minimal_spawn.yaml
@@ -1,15 +1,16 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_spawn_processor
+ package: actors.sample_processor_zmq
class: Processor
method: spawn
connections:
- Generator.q_out: [Processor.q_in]
-
-redis_config:
- port: 6378
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
\ No newline at end of file
diff --git a/demos/naumann/actors/analysis_model.py b/demos/naumann/actors/analysis_model.py
index db47c781..692324f7 100644
--- a/demos/naumann/actors/analysis_model.py
+++ b/demos/naumann/actors/analysis_model.py
@@ -75,7 +75,7 @@ def run(self):
self.fit_times = []
with RunManager(
- self.name, self.runStep, self.setup, self.q_sig, self.q_comm
+ self.name, self.run_step, self.setup, self.q_sig, self.q_comm
) as rm:
logger.info(rm)
@@ -97,7 +97,7 @@ def run(self):
np.savetxt("output/used_stims.txt", self.currStimID)
- def runStep(self):
+ def run_step(self):
"""Take numpy estimates and frame_number
Create X and Y for plotting
"""
diff --git a/demos/naumann/actors/processor.py b/demos/naumann/actors/processor.py
index 7ebaaf72..a6f91037 100644
--- a/demos/naumann/actors/processor.py
+++ b/demos/naumann/actors/processor.py
@@ -87,7 +87,7 @@ def stop(self):
- def runStep(self):
+ def run_step(self):
"""Run process. Runs once per frame.
Output is a location in the DS to continually
place the Estimates results, with ref number that
diff --git a/demos/neurofinder/neurofind_demo.py b/demos/neurofinder/neurofind_demo.py
deleted file mode 100644
index 5f49399c..00000000
--- a/demos/neurofinder/neurofind_demo.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import logging
-
-# Matplotlib is overly verbose by default
-logging.getLogger("matplotlib").setLevel(logging.WARNING)
-from improv.nexus import Nexus
-
-loadFile = "./neurofind_demo.yaml"
-
-nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
-
-# All modules needed have been imported
-# so we can change the level of logging here
-# import logging
-# import logging.config
-# logging.config.dictConfig({
-# 'version': 1,
-# 'disable_existing_loggers': True,
-# })
-# logger = logging.getLogger("improv")
-# logger.setLevel(logging.INFO)
-
-nexus.startNexus()
diff --git a/demos/neurofinder/neurofind_demo.yaml b/demos/neurofinder/neurofind_demo.yaml
index eb7796bc..483f6dc8 100644
--- a/demos/neurofinder/neurofind_demo.yaml
+++ b/demos/neurofinder/neurofind_demo.yaml
@@ -1,6 +1,3 @@
-settings:
- use_watcher: [Acquirer, Processor, Visual, Analysis]
-
actors:
GUI:
package: demos.basic.actors.visual
diff --git a/demos/sample_actors/acquire.py b/demos/sample_actors/acquire.py
index 28dfb4f7..1fa9f513 100644
--- a/demos/sample_actors/acquire.py
+++ b/demos/sample_actors/acquire.py
@@ -5,7 +5,7 @@
import numpy as np
from skimage.io import imread
-from improv.actor import Actor
+from improv.actor import ZmqActor
import logging
@@ -15,7 +15,7 @@
# Classes: File Acquirer, Stim, Behavior, Tiff
-class FileAcquirer(Actor):
+class FileAcquirer(ZmqActor):
"""Class to import data from files and output
frames in a buffer, or discrete.
"""
@@ -64,7 +64,7 @@ def stop(self):
np.savetxt("output/timing/acquire_frame_time.txt", np.array(self.total_times))
np.savetxt("output/timing/acquire_timestamp.txt", np.array(self.timestamp))
- def runStep(self):
+ def run_step(self):
"""While frames exist in location specified during setup,
grab frame, save, put in store
"""
@@ -109,7 +109,7 @@ def saveFrame(self, frame):
self.f.flush()
-class StimAcquirer(Actor):
+class StimAcquirer(ZmqActor):
"""Class to load visual stimuli data from file
and stream into the pipeline
"""
@@ -139,7 +139,7 @@ def setup(self):
else:
raise FileNotFoundError
- def runStep(self):
+ def run_step(self):
"""Check for input from behavioral control"""
if self.n < len(self.stim):
# s = self.stim[self.sID]
@@ -149,7 +149,7 @@ def runStep(self):
self.n += 1
-class BehaviorAcquirer(Actor):
+class BehaviorAcquirer(ZmqActor):
"""Actor that acquires information of behavioral stimulus
during the experiment
@@ -176,7 +176,7 @@ def setup(self):
else:
self.behaviors = [0, 1, 2, 3, 4, 5, 6, 7] # 8 sets of input stimuli
- def runStep(self):
+ def run_step(self):
"""Check for input from behavioral control"""
# Faking it for now.
if self.n % 50 == 0:
@@ -188,7 +188,7 @@ def runStep(self):
self.n += 1
-class FileStim(Actor):
+class FileStim(ZmqActor):
"""Actor that acquires information of behavioral stimulus
during the experiment from a file
"""
@@ -204,7 +204,7 @@ def setup(self):
self.data = np.loadtxt(self.file)
- def runStep(self):
+ def run_step(self):
"""Check for input from behavioral control"""
# Faking it for now.
if self.n % 50 == 0 and self.n < self.data.shape[1] * 50:
@@ -216,7 +216,7 @@ def runStep(self):
self.n += 1
-class TiffAcquirer(Actor):
+class TiffAcquirer(ZmqActor):
"""Loops through a TIF file."""
def __init__(self, *args, filename=None, framerate=30, **kwargs):
@@ -236,7 +236,7 @@ def setup(self):
self.imgs = imread(self.filename)
print(self.imgs.shape)
- def runStep(self):
+ def run_step(self):
t0 = time.time()
id_store = self.client.put(
self.imgs[self.n_frame], "acq_raw" + str(self.n_frame)
diff --git a/demos/sample_actors/analysis.py b/demos/sample_actors/analysis.py
index 664ddaf3..081aa32e 100644
--- a/demos/sample_actors/analysis.py
+++ b/demos/sample_actors/analysis.py
@@ -4,7 +4,7 @@
from queue import Empty
import os
-from improv.actor import Actor, RunManager
+from improv.actor import ZmqActor
from improv.store import ObjectNotFoundError
import logging
@@ -13,7 +13,7 @@
logger.setLevel(logging.INFO)
-class MeanAnalysis(Actor):
+class MeanAnalysis(ZmqActor):
# TODO: Add additional error handling
# TODO: this is too complex for a sample actor?
def __init__(self, *args, **kwargs):
@@ -74,7 +74,7 @@ def stop(self):
)
np.savetxt("output/timing/analysis_timestamp.txt", np.array(self.timestamp))
- def runStep(self):
+ def run_step(self):
"""Take numpy estimates and frame_number
Create X and Y for plotting
"""
diff --git a/demos/sample_actors/analysis_julia.py b/demos/sample_actors/analysis_julia.py
index 83d5b1d0..d37b59e2 100644
--- a/demos/sample_actors/analysis_julia.py
+++ b/demos/sample_actors/analysis_julia.py
@@ -55,7 +55,7 @@ def stop(self):
print("Julia Analysis broke, avg time per frame: ", np.mean(self.t_per_frame))
print("JuliaAnalysis got through ", self.frame_number, " frames.")
- def runStep(self):
+ def run_step(self):
t = time.time()
try:
obj_id = self.q_in.get(timeout=0.0001) # List
diff --git a/demos/sample_actors/process.py b/demos/sample_actors/process.py
index 603064c8..bdeb10aa 100644
--- a/demos/sample_actors/process.py
+++ b/demos/sample_actors/process.py
@@ -119,7 +119,7 @@ def stop(self):
print("type ", type(self.coords1[0]))
np.savetxt("output/contours.txt", np.array(self.coords1))
- def runStep(self):
+ def run_step(self):
"""Run process. Runs once per frame.
Output is a location in the DS to continually
place the Estimates results, with ref number that
diff --git a/demos/sample_actors/simple_analysis.py b/demos/sample_actors/simple_analysis.py
index ff350b58..0b43eb06 100644
--- a/demos/sample_actors/simple_analysis.py
+++ b/demos/sample_actors/simple_analysis.py
@@ -3,7 +3,7 @@
from queue import Empty
import os
-from improv.actor import Actor, RunManager
+from improv.actor import ZmqActor
from improv.store import ObjectNotFoundError
import logging
@@ -12,7 +12,7 @@
logger.setLevel(logging.INFO)
-class SimpleAnalysis(Actor):
+class SimpleAnalysis(ZmqActor):
# TODO: Add additional error handling
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -72,7 +72,7 @@ def stop(self):
)
np.savetxt("output/timing/analysis_timestamp.txt", np.array(self.timestamp))
- def runStep(self):
+ def run_step(self):
"""Take numpy estimates and frame_number
Create X and Y for plotting
"""
diff --git a/demos/sample_actors/zmqActor.py b/demos/sample_actors/zmqActor.py
deleted file mode 100644
index 6455b0b4..00000000
--- a/demos/sample_actors/zmqActor.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import asyncio
-import time
-
-import zmq
-from zmq import (
- PUB,
- SUB,
- SUBSCRIBE,
- REQ,
- REP,
- LINGER,
- Again,
- NOBLOCK,
- ZMQError,
- EAGAIN,
- ETERM,
-)
-from zmq.log.handlers import PUBHandler
-import zmq.asyncio
-
-from improv.actor import Actor
-
-import logging
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class ZmqActor(Actor):
- """
- Zmq actor with pub/sub or rep/req pattern.
- """
-
- def __init__(self, *args, type="PUB", ip="127.0.0.1", port=5555, **kwargs):
- super().__init__(*args, **kwargs)
- logger.info("Constructed Zmq Actor")
- if str(type) in "PUB" or str(type) in "SUB":
- self.pub_sub_flag = True # default
- else:
- self.pub_sub_flag = False
- self.rep_req_flag = not self.pub_sub_flag
- self.ip = ip
- self.port = port
- self.address = "tcp://{}:{}".format(self.ip, self.port)
-
- self.send_socket = None
- self.recv_socket = None
- self.req_socket = None
- self.rep_socket = None
-
- self.context = zmq.Context.instance()
-
- def sendMsg(self, msg, msg_type="pyobj"):
- """
- Sends a message to the controller.
- """
- if not self.send_socket:
- self.setSendSocket()
-
- if msg_type == "multipart":
- self.send_socket.send_multipart(msg)
- if msg_type == "pyobj":
- self.send_socket.send_pyobj(msg)
- elif msg_type == "single":
- self.send_socket.send(msg)
-
- def recvMsg(self, msg_type="pyobj", flags=0):
- """
- Receives a message from the controller.
-
- NOTE: default flag=0 instead of flag=NOBLOCK
- """
- if not self.recv_socket:
- self.setRecvSocket()
-
- while True:
- try:
- if msg_type == "multipart":
- recv_msg = self.recv_socket.recv_multipart(flags=flags)
- elif msg_type == "pyobj":
- recv_msg = self.recv_socket.recv_pyobj(flags=flags)
- elif msg_type == "single":
- recv_msg = self.recv_socket.recv(flags=flags)
- break
- except Again:
- pass
- except ZMQError as e:
- logger.info(f"ZMQ error: {e}")
- if e.errno == ETERM:
- pass # interrupted - pass or break if in try loop
- if e.errno == EAGAIN:
- pass # no message was ready (yet!)
-
- # self.recv_socket.close()
- return recv_msg
-
- def requestMsg(self, msg):
- """Safe version of send/receive with controller.
- Based on the Lazy Pirate pattern [here]
- (https://zguide.zeromq.org/docs/chapter4/#Client-Side-Reliability-Lazy-Pirate-Pattern)
- """
- REQUEST_TIMEOUT = 500
- REQUEST_RETRIES = 3
- retries_left = REQUEST_RETRIES
-
- self.setReqSocket()
-
- reply = None
- try:
- logger.debug(f"Sending {msg} to controller.")
- self.req_socket.send_pyobj(msg)
-
- while True:
- ready = self.req_socket.poll(REQUEST_TIMEOUT)
-
- if ready:
- reply = self.req_socket.recv_pyobj()
- logger.debug(f"Received {reply} from controller.")
- break
- else:
- retries_left -= 1
- logger.debug("No response from server.")
-
- # try to close and reconnect
- self.req_socket.setsockopt(LINGER, 0)
- self.req_socket.close()
- if retries_left == 0:
- logger.debug("Server seems to be offline. Giving up.")
- break
-
- logger.debug("Attempting to reconnect to server...")
-
- self.setReqSocket()
-
- logger.debug(f"Resending {msg} to controller.")
- self.req_socket.send_pyobj(msg)
-
- except asyncio.CancelledError:
- pass
-
- self.req_socket.close()
- return reply
-
- def replyMsg(self, reply, delay=0.0001):
- """
- Safe version of receive/reply with controller.
- """
-
- self.setRepSocket()
-
- msg = self.rep_socket.recv_pyobj()
- time.sleep(delay)
- self.rep_socket.send_pyobj(reply)
- self.rep_socket.close()
-
- return msg
-
- def put(self, msg=None):
- logger.debug(f"Putting message {msg}")
- if self.pub_sub_flag:
- logger.debug(f"putting message {msg} using pub/sub")
- return self.sendMsg(msg)
- else:
- logger.debug(f"putting message {msg} using rep/req")
- return self.requestMsg(msg)
-
- def get(self, reply=None):
- if self.pub_sub_flag:
- logger.debug(f"getting message with pub/sub")
- return self.recvMsg()
- else:
- logger.debug(f"getting message using reply {reply} with pub/sub")
- return self.replyMsg(reply)
-
- def setSendSocket(self, timeout=1.001):
- """
- Sets up the send socket for the actor.
- """
- self.send_socket = self.context.socket(PUB)
- self.send_socket.bind(self.address)
- time.sleep(timeout)
-
- def setRecvSocket(self, timeout=1.001):
- """
- Sets up the receive socket for the actor.
- """
- self.recv_socket = self.context.socket(SUB)
- self.recv_socket.connect(self.address)
- self.recv_socket.setsockopt(SUBSCRIBE, b"")
- time.sleep(timeout)
-
- def setReqSocket(self, timeout=0.0001):
- """
- Sets up the request socket for the actor.
- """
- self.req_socket = self.context.socket(REQ)
- self.req_socket.connect(self.address)
- time.sleep(timeout)
-
- def setRepSocket(self, timeout=0.0001):
- """
- Sets up the reply socket for the actor.
- """
- self.rep_socket = self.context.socket(REP)
- self.rep_socket.bind(self.address)
- time.sleep(timeout)
diff --git a/demos/spike/spike_demo.py b/demos/spike/spike_demo.py
deleted file mode 100644
index ca4f110d..00000000
--- a/demos/spike/spike_demo.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import logging
-
-# Matplotlib is overly verbose by default
-logging.getLogger("matplotlib").setLevel(logging.WARNING)
-from improv.nexus import Nexus
-
-loadFile = "./spike_demo.yaml"
-
-nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
-
-# All modules needed have been imported
-# so we can change the level of logging here
-# import logging
-# import logging.config
-# logging.config.dictConfig({
-# 'version': 1,
-# 'disable_existing_loggers': True,
-# })
-# logger = logging.getLogger("improv")
-# logger.setLevel(logging.INFO)
-
-nexus.startNexus()
diff --git a/demos/spike/spike_demo.yaml b/demos/spike/spike_demo.yaml
index 6aad9992..8b46f7a7 100644
--- a/demos/spike/spike_demo.yaml
+++ b/demos/spike/spike_demo.yaml
@@ -1,6 +1,3 @@
-# settings:
-# use_watcher: [Acquirer, Analysis]
-
actors:
# GUI:
# package: actors.visual
diff --git a/demos/zmq/README.md b/demos/zmq/README.md
deleted file mode 100644
index f02edc26..00000000
--- a/demos/zmq/README.md
+++ /dev/null
@@ -1,14 +0,0 @@
-# ZMQ Demo
-
-This demo is intended to show how actors can communicate using zmq. There are two options for the demo, each corresponding to a different config file. The difference between these two options is in how they send and receive messages. One uses the publish/send concept, while the other uses the request/reply concept.
-
-## Instructions
-
-Just like any other demo, you can run `improv run ` in order to run the demo. For example, from within the `improv/demos/zmq` directory, running `improv run zmq_rr_demo.yaml` will run the zmq_rr demo.
-
-After running the demo, a tui (text user interface) should show up. From here, we can type `setup` followed by `run` to run the demo. After we are done, we can type `stop` to pause, or `quit` to exit.
-
-## Expected Output
-
-Currently, only logging output is supported. There will be no live output during the run.
-
diff --git a/demos/zmq/actors/zmq_ps_sample_processor.py b/demos/zmq/actors/zmq_ps_sample_processor.py
deleted file mode 100644
index f54c00c8..00000000
--- a/demos/zmq/actors/zmq_ps_sample_processor.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import numpy as np
-import logging
-
-from demos.sample_actors.zmqActor import ZmqActor
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class Processor(ZmqActor):
- """Sample processor used to calculate the average of an array of integers
- using sync ZMQ to communicate.
-
- Intended for use with sample_generator.py.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def setup(self):
- """Initializes all class variables.
- Sets up a ZmqRRActor to receive data from the generator.
-
- self.name (string): name of the actor.
- self.frame (ObjectID): Store object id referencing data from the store.
- self.avg_list (list): list that contains averages of individual vectors.
- self.frame_num (int): index of current frame.
- """
- self.name = "Processor"
- self.frame = None
- self.avg_list = []
- self.frame_num = 1
- logger.info("Completed setup for Processor")
-
- def stop(self):
- """Trivial stop function for testing purposes."""
- logger.info("Processor stopping; have received {} frames so far".format(self.frame_num))
-
- def runStep(self):
- """Gets from the input queue and calculates the average.
- Receives data from the generator using a ZmqRRActor.
-
- Receives an ObjectID, references data in the store using that
- ObjectID, calculates the average of that data, and finally prints
- to stdout.
- """
-
- frame = None
- try:
- frame = self.get()
-
- except:
- logger.error("Could not get frame!")
- pass
-
- if frame is not None:
- self.done = False
- self.frame = self.client.getID(frame)
- avg = np.mean(self.frame[0])
-
- # logger.info(f"Average: {avg}")
- self.avg_list.append(avg)
- logger.info(f"Overall Average: {np.mean(self.avg_list)}")
- # logger.info(f"Frame number: {self.frame_num}")
- self.frame_num += 1
diff --git a/demos/zmq/actors/zmq_rr_sample_generator.py b/demos/zmq/actors/zmq_rr_sample_generator.py
deleted file mode 100644
index ef941ae6..00000000
--- a/demos/zmq/actors/zmq_rr_sample_generator.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import numpy as np
-import logging
-
-from demos.sample_actors.zmqActor import ZmqActor
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-class Generator(ZmqActor):
- """Sample actor to generate data to pass into a sample processor
- using async ZMQ to communicate.
-
- Intended for use along with sample_processor.py.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.data = None
- self.name = "Generator"
- self.frame_num = 0
-
- def __str__(self):
- return f"Name: {self.name}, Data: {self.data}"
-
- def setup(self):
- """Generates an array that serves as an initial source of data.
- Sets up a ZmqRRActor to send data to the processor.
-
- Initial array is a 100 row, 5 column numpy matrix that contains
- integers from 1-99, inclusive.
- """
- logger.info("Beginning setup for Generator")
- self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
- logger.info("Completed setup for Generator")
-
- def stop(self):
- """Save current randint vector to a file."""
- logger.info("Generator stopping")
- np.save("sample_generator_data.npy", self.data)
- return 0
-
- def runStep(self):
- """Generates additional data after initial setup data is exhausted.
- Sends data to the processor using a ZmqRRActor.
-
- Data is of a different form as the setup data in that although it is
- the same size (5x1 vector), it is uniformly distributed in [1, 10]
- instead of in [1, 100]. Therefore, the average over time should
- converge to 5.5.
- """
- if self.frame_num < np.shape(self.data)[0]:
- data_id = self.client.put(self.data[self.frame_num], str(f"Gen_raw_{self.frame_num}"))
- try:
- self.put(data_id)
- self.frame_num += 1
- except Exception as e:
- logger.error(f"Generator Exception: {e}")
- else:
- self.data = np.concatenate((self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0)
diff --git a/demos/zmq/config/zmq_ps_demo.yaml b/demos/zmq/config/zmq_ps_demo.yaml
deleted file mode 100644
index 230282d1..00000000
--- a/demos/zmq/config/zmq_ps_demo.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-actors:
- Generator:
- package: actors.sample_generator
- class: Generator
-
- Processor:
- package: actors.sample_processor
- class: Processor
-
-connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
diff --git a/demos/zmq/config/zmq_rr_demo.yaml b/demos/zmq/config/zmq_rr_demo.yaml
deleted file mode 100644
index 230282d1..00000000
--- a/demos/zmq/config/zmq_rr_demo.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-actors:
- Generator:
- package: actors.sample_generator
- class: Generator
-
- Processor:
- package: actors.sample_processor
- class: Processor
-
-connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
diff --git a/demos/zmq/zmq_ps_demo.yaml b/demos/zmq/zmq_ps_demo.yaml
deleted file mode 100644
index 63fca91c..00000000
--- a/demos/zmq/zmq_ps_demo.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-actors:
- Generator:
- package: actors.zmq_ps_sample_generator
- class: Generator
- ip: 127.0.0.1
- port: 5555
- type: PUB
-
- Processor:
- package: actors.zmq_ps_sample_processor
- class: Processor
- ip: 127.0.0.1
- port: 5555
- type: SUB
-
-connections:
- Generator.q_out: [Processor.q_in]
diff --git a/demos/zmq/zmq_rr_demo.yaml b/demos/zmq/zmq_rr_demo.yaml
deleted file mode 100644
index 585cebd1..00000000
--- a/demos/zmq/zmq_rr_demo.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-actors:
- Generator:
- package: actors.zmq_rr_sample_generator
- class: Generator
- ip: 127.0.0.1
- port: 5555
- type: REQ
-
- Processor:
- package: actors.zmq_rr_sample_processor
- class: Processor
- ip: 127.0.0.1
- port: 5555
- type: REP
-
-connections:
- Generator.q_out: [Processor.q_in]
diff --git a/docs/_toc.yml b/docs/_toc.yml
index 426f566d..c1216d55 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -10,5 +10,6 @@ chapters:
- file: design
- file: actors
- file: signals
+- file: logging
- file: bibliography
- file: autoapi/index
diff --git a/docs/logging.md b/docs/logging.md
new file mode 100644
index 00000000..2ba53053
--- /dev/null
+++ b/docs/logging.md
@@ -0,0 +1,31 @@
+(page:logging)=
+# Logging and debugging
+
+The asynchronous and distributed nature of _improv_ can lead to significant difficulties in debugging, since errors in any part of the system may lead to silent failures.[^1] To combat this, _improv_ is designed for extensive logging in every part of the system.
+
+## The global log
+
+The global log file (default: `global.log`) is specified by the `-f` flag on the command line. This is a file to which all outputs of the logger, an autonomous messaging server, are written. Because the logger is started as a part of the Nexus startup, it misses messages generated as part of the command line and prior Nexus setup but should include everything thereafter. The default logging level of the global log file is set to the Python `logging` module's `logging.INFO`.
+
+All messages broadcast by the logging server are also echoed to the Log Messages pane inside the text user interface. The level of logging messages displayed can be toggled between `logging.INFO` and `logging.DEBUG` within that interface.
+
+## `improv-debug.log`
+
+This file, generated in the directory in which the `improv` command is invoked, contains debug-level information generated by the `improv` command-line interface, the Nexus server, and any actor started from Nexus by the `fork` method. It often contains better error information than the global log file, which is particularly important when errors hang the system. It does not contain logging information generated by the text user interface, the logging server, the message broker, or actors started via the `spawn` method.
+
+## Other log files
+
+Invoking `improv` also generates other log files. When debuggin, it is often beneficial to also check these files for error or warning messages:
+
+- `log_server.log`: Information, debugging, and error messages generated by the log server itself.
+- `tui.log`: Logging messages from the text user interface.
+
+```{warning}
+Because _improv_'s default is to preserve information at all costs, it _will not_ delete log files, and the default when starting a new run is to _append_ to existing log files. Different runs should be clearly separated by server start timestamps, but users should keep the following in mind:
+
+1. Log files can get quite large very quickly, particularly when user-defined actors generate many log messages (for example, every time `run_step` executes). In this case, users might want to consider whether these messages should be at the `INFO` level (the default written to the global log) or the `DEBUG` level (not written by default).
+
+2. Because `improv` needs to open and read the global log file, very large logs can slow down system performance. Users should either move previous log files to separate folders or periodically delete logs to ensure that these files do not continue to grow.
+```
+
+[^1]: This is a result of an important design choice: Rather than allow actors to crash the system, we capture and log exceptions, rather than allow them to bubble up.
\ No newline at end of file
diff --git a/docs/running.ipynb b/docs/running.ipynb
index 47b7b2b0..2ed13a09 100644
--- a/docs/running.ipynb
+++ b/docs/running.ipynb
@@ -19,10 +19,10 @@
"languageId": "shellscript"
}
},
- "outputs": [],
"source": [
"!improv --help"
- ]
+ ],
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -43,10 +43,10 @@
"languageId": "shellscript"
}
},
- "outputs": [],
"source": [
"!improv run --help"
- ]
+ ],
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -110,40 +110,10 @@
"languageId": "shellscript"
}
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "usage: improv server [-h] [-c CONTROL_PORT] [-o OUTPUT_PORT] [-l LOGGING_PORT]\n",
- " [-f LOGFILE] [-a ACTOR_PATH]\n",
- " configfile\n",
- "\n",
- "Start the improv server\n",
- "\n",
- "positional arguments:\n",
- " configfile YAML file specifying improv pipeline\n",
- "\n",
- "options:\n",
- " -h, --help show this help message and exit\n",
- " -c CONTROL_PORT, --control-port CONTROL_PORT\n",
- " local port on which control signals are received\n",
- " -o OUTPUT_PORT, --output-port OUTPUT_PORT\n",
- " local port on which output messages are broadcast\n",
- " -l LOGGING_PORT, --logging-port LOGGING_PORT\n",
- " local port on which logging messages are broadcast\n",
- " -f LOGFILE, --logfile LOGFILE\n",
- " name of log file\n",
- " -a ACTOR_PATH, --actor-path ACTOR_PATH\n",
- " search path to add to sys.path when looking for\n",
- " actors; defaults to the directory containing\n",
- " configfile\n"
- ]
- }
- ],
"source": [
"!improv server --help"
- ]
+ ],
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -160,30 +130,10 @@
"languageId": "shellscript"
}
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "usage: improv client [-h] [-c CONTROL_PORT] [-s SERVER_PORT] [-l LOGGING_PORT]\n",
- "\n",
- "Start the improv client\n",
- "\n",
- "options:\n",
- " -h, --help show this help message and exit\n",
- " -c CONTROL_PORT, --control-port CONTROL_PORT\n",
- " address on which control signals are sent to the\n",
- " server\n",
- " -s SERVER_PORT, --server-port SERVER_PORT\n",
- " address on which messages from the server are received\n",
- " -l LOGGING_PORT, --logging-port LOGGING_PORT\n",
- " address on which logging messages are broadcast\n"
- ]
- }
- ],
"source": [
"!improv client --help"
- ]
+ ],
+ "outputs": []
},
{
"cell_type": "markdown",
diff --git a/docs/signals.md b/docs/signals.md
index 345b61bb..70972fd3 100644
--- a/docs/signals.md
+++ b/docs/signals.md
@@ -20,7 +20,7 @@ In [](tables:signals) we list the signals defined in `Signal` along with the `Ac
| `actor.Signal` received | `ManagedActor` method called |
|---|---|
| `setup` | `setup` |
-| `run` | `runStep` |
+| `run` | `run_step` |
| `pause` | not yet implemented |
| `resume` | not yet implemented |
| `reset` | not yet implemented |
diff --git a/improv/actor.py b/improv/actor.py
index 3d6a2a23..e59bcdfb 100644
--- a/improv/actor.py
+++ b/improv/actor.py
@@ -1,10 +1,17 @@
+from __future__ import annotations
+
import time
import signal
import asyncio
import traceback
from queue import Empty
-import improv.store
+import zmq
+from zmq import SocketOption
+
+from improv import log
+from improv.link import ZmqLink
+from improv.messaging import ActorStateMsg, ActorStateReplyMsg, NexusSignalReplyMsg
from improv.store import StoreInterface
import logging
@@ -13,6 +20,17 @@
logger.setLevel(logging.INFO)
+# TODO construct log handler within post-fork setup based on
+# TODO arguments sent to actor by nexus
+# TODO and add it below; fine to leave logger here as-is
+
+
+class LinkInfo:
+ def __init__(self, link_name, link_topic):
+ self.link_name = link_name
+ self.link_topic = link_topic
+
+
class AbstractActor:
"""Base class for an actor that Nexus
controls and interacts with.
@@ -20,9 +38,7 @@ class AbstractActor:
Also needs to be responsive to sent Signals (e.g. run, setup, etc)
"""
- def __init__(
- self, name, store_loc=None, method="fork", store_port_num=None, *args, **kwargs
- ):
+ def __init__(self, name, method="fork", store_port_num=None, *args, **kwargs):
"""Require a name for multiple instances of the same actor/class
Create initial empty dict of Links for easier referencing
"""
@@ -31,7 +47,6 @@ def __init__(
self.links = {}
self.method = method
self.client = None
- self.store_loc = store_loc
self.lower_priority = False
self.store_port_num = store_port_num
@@ -49,7 +64,7 @@ def __repr__(self):
"""
return self.name + ": " + str(self.links.keys())
- def setStoreInterface(self, client):
+ def set_store_interface(self, client):
"""Sets the client interface to the store
Args:
@@ -57,17 +72,14 @@ def setStoreInterface(self, client):
"""
self.client = client
- def _getStoreInterface(self):
+ def get_store_interface(self):
# TODO: Where do we require this be run? Add a Signal and include in RM?
if not self.client:
store = None
- if StoreInterface == improv.store.RedisStoreInterface:
- store = StoreInterface(self.name, self.store_port_num)
- else:
- store = StoreInterface(self.name, self.store_loc)
- self.setStoreInterface(store)
+ store = StoreInterface(self.name, self.store_port_num)
+ self.set_store_interface(store)
- def setLinks(self, links):
+ def set_links(self, links):
"""General full dict set for links
Args:
@@ -75,7 +87,7 @@ def setLinks(self, links):
"""
self.links = links
- def setCommLinks(self, q_comm, q_sig):
+ def set_comm_links(self, q_comm, q_sig):
"""Set explicit communication links to/from Nexus (q_comm, q_sig)
Args:
@@ -86,7 +98,7 @@ def setCommLinks(self, q_comm, q_sig):
self.q_sig = q_sig
self.links.update({"q_comm": self.q_comm, "q_sig": self.q_sig})
- def setLinkIn(self, q_in):
+ def set_link_in(self, q_in):
"""Set the dedicated input queue
Args:
@@ -95,7 +107,7 @@ def setLinkIn(self, q_in):
self.q_in = q_in
self.links.update({"q_in": self.q_in})
- def setLinkOut(self, q_out):
+ def set_link_out(self, q_out):
"""Set the dedicated output queue
Args:
@@ -104,7 +116,7 @@ def setLinkOut(self, q_out):
self.q_out = q_out
self.links.update({"q_out": self.q_out})
- def setLinkWatch(self, q_watch):
+ def set_link_watch(self, q_watch):
"""Set the dedicated watchout queue
Args:
@@ -113,7 +125,7 @@ def setLinkWatch(self, q_watch):
self.q_watchout = q_watch
self.links.update({"q_watchout": self.q_watchout})
- def addLink(self, name, link):
+ def add_link(self, name, link):
"""Function provided to add additional data links by name
using same form as q_in or q_out
Must be done during registration and not during run
@@ -126,7 +138,7 @@ def addLink(self, name, link):
# User can then use: self.my_queue = self.links['my_queue'] in a setup fcn,
# or continue to reference it using self.links['my_queue']
- def getLinks(self):
+ def get_links(self):
"""Returns dictionary of links for the current actor
Returns:
@@ -134,24 +146,6 @@ def getLinks(self):
"""
return self.links
- def put(self, idnames, q_out=None, save=None):
- """TODO: This is deprecated? Prefer using Links explicitly"""
- if save is None:
- save = [False] * len(idnames)
-
- if len(save) < len(idnames):
- save = save + [False] * (len(idnames) - len(save))
-
- if q_out is None:
- q_out = self.q_out
-
- q_out.put(idnames)
-
- for i in range(len(idnames)):
- if save[i]:
- if self.q_watchout:
- self.q_watchout.put(idnames[i])
-
def setup(self):
"""Essenitally the registration process
Can also be an initialization for the actor
@@ -174,7 +168,7 @@ def stop(self):
"""
pass
- def changePriority(self):
+ def change_priority(self):
"""Try to lower this process' priority
Only changes priority if lower_priority is set
TODO: Only works on unix machines. Add Windows functionality
@@ -188,6 +182,18 @@ def changePriority(self):
logger.info("Lowered priority of this process: {}".format(self.name))
print("Lowered ", os.getpid(), " for ", self.name)
+ def register_with_nexus(self):
+ pass
+
+ def register_with_broker(self):
+ pass
+
+ def setup_links(self):
+ pass
+
+ def setup_logging(self):
+ raise NotImplementedError
+
class ManagedActor(AbstractActor):
def __init__(self, *args, **kwargs):
@@ -196,17 +202,157 @@ def __init__(self, *args, **kwargs):
# Define dictionary of actions for the RunManager
self.actions = {}
self.actions["setup"] = self.setup
- self.actions["run"] = self.runStep
+ self.actions["run"] = self.run_step
self.actions["stop"] = self.stop
+ self.nexus_sig_port: int | None = None
def run(self):
- with RunManager(self.name, self.actions, self.links):
+ self.setup_logging()
+ self.register_with_nexus()
+ self.register_with_broker()
+ self.setup_links()
+ self.get_store_interface()
+ with RunManager(
+ self.name, self.actions, self.links, self.nexus_sig_port, self.improv_logger
+ ):
pass
- def runStep(self):
+ def run_step(self):
raise NotImplementedError
+class ZmqActor(ManagedActor):
+ def __init__(
+ self,
+ nexus_comm_port,
+ broker_sub_port,
+ broker_pub_port,
+ log_pull_port,
+ outgoing_links,
+ incoming_links,
+ broker_host="localhost",
+ log_host="localhost",
+ nexus_host="localhost",
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.broker_pub_socket: zmq.Socket | None = None
+ self.zmq_context: zmq.Context | None = None
+ self.nexus_host: str = nexus_host
+ self.nexus_comm_port: int = nexus_comm_port
+ self.nexus_sig_port: int | None = None
+ self.nexus_comm_socket: zmq.Socket | None = None
+ self.nexus_sig_socket: zmq.Socket | None = None
+ self.broker_sub_port: int = broker_sub_port
+ self.broker_pub_port: int = broker_pub_port
+ self.broker_host: str = broker_host
+ self.log_pull_port: int = log_pull_port
+ self.log_host: str = log_host
+ self.outgoing_links: list[LinkInfo] = outgoing_links
+ self.incoming_links: list[LinkInfo] = incoming_links
+ self.incoming_sockets: dict[str, zmq.Socket] = dict()
+
+ # Redefine dictionary of actions for the RunManager
+ self.actions = {"setup": self.setup, "run": self.run_step, "stop": self.stop}
+
+ def register_with_nexus(self):
+ self.improv_logger.info(f"Actor {self.name} registering with nexus")
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ # create a REQ socket pointed at nexus' global actor in port
+ self.nexus_comm_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_comm_socket.connect(
+ f"tcp://{self.nexus_host}:{self.nexus_comm_port}"
+ )
+
+ # create a REP socket for comms from Nexus and save its state
+ self.nexus_sig_socket = self.zmq_context.socket(zmq.REP)
+ self.nexus_sig_socket.bind("tcp://*:0") # find any available port
+ sig_socket_addr = self.nexus_sig_socket.getsockopt_string(
+ SocketOption.LAST_ENDPOINT
+ )
+ self.nexus_sig_port = int(sig_socket_addr.split(":")[-1])
+
+ # build and send a message to nexus
+ actor_state = ActorStateMsg(
+ self.name,
+ Signal.waiting(),
+ self.nexus_sig_port,
+ "comms opened and actor ready to initialize",
+ )
+
+ self.nexus_comm_socket.send_pyobj(actor_state)
+
+ rep: ActorStateReplyMsg = self.nexus_comm_socket.recv_pyobj()
+ self.improv_logger.info(
+ f"Actor {self.name} got response from nexus:\n"
+ f"Status: {rep.status}\n"
+ f"Info: {rep.info}\n"
+ )
+
+ self.links["q_comm"] = ZmqLink(self.nexus_comm_socket, f"{self.name}.q_comm")
+ self.links["q_sig"] = ZmqLink(self.nexus_sig_socket, f"{self.name}.q_sig")
+
+ self.improv_logger.info(f"Actor {self.name} registered with Nexus")
+
+ def register_with_broker(self): # really opening sockets here
+ self.improv_logger.info(f"Actor {self.name} registering with broker")
+ if len(self.outgoing_links) > 0:
+ self.broker_pub_socket = self.zmq_context.socket(zmq.PUB)
+ self.broker_pub_socket.connect(
+ f"tcp://{self.broker_host}:{self.broker_sub_port}"
+ )
+
+ for incoming_link in self.incoming_links:
+ new_socket: zmq.Socket = self.zmq_context.socket(zmq.SUB)
+ new_socket.connect(f"tcp://{self.broker_host}:{self.broker_pub_port}")
+ new_socket.subscribe(incoming_link.link_topic)
+ self.incoming_sockets[incoming_link.link_name] = new_socket
+ self.improv_logger.info(f"Actor {self.name} registered with broker")
+
+ def setup_links(self):
+ self.improv_logger.info(f"Actor {self.name} setting up links")
+ for outgoing_link in self.outgoing_links:
+ self.links[outgoing_link.link_name] = ZmqLink(
+ self.broker_pub_socket,
+ outgoing_link.link_name,
+ outgoing_link.link_topic,
+ )
+
+ for incoming_link in self.incoming_links:
+ self.links[incoming_link.link_name] = ZmqLink(
+ self.incoming_sockets[incoming_link.link_name],
+ incoming_link.link_name,
+ incoming_link.link_topic,
+ )
+
+ if "q_out" in self.links.keys():
+ self.q_out = self.links["q_out"]
+ if "q_in" in self.links.keys():
+ self.q_in = self.links["q_in"]
+ self.improv_logger.info(f"Actor {self.name} finished setting up links")
+
+ @property
+ def improv_logger(self):
+ try:
+ return self._logger
+ except AttributeError as e:
+ err_str = f"Caught exception {e} in {self.name}. "
+ "Did you forget to call setup_logging?"
+ logger = logging.getLogger(self.name)
+ logger.error(err_str)
+
+ def setup_logging(self):
+ self._logger = logging.getLogger(self.name)
+ self._logger.setLevel(logging.INFO)
+ for handler in logger.handlers:
+ self._logger.addHandler(handler)
+ self._logger.addHandler(
+ log.ZmqLogHandler(self.log_host, self.log_pull_port, self.zmq_context)
+ )
+
+
class AsyncActor(AbstractActor):
def __init__(self, *args, **kwargs):
super().__init__(*args)
@@ -214,7 +360,7 @@ def __init__(self, *args, **kwargs):
# Define dictionary of actions for the RunManager
self.actions = {}
self.actions["setup"] = self.setup
- self.actions["run"] = self.runStep
+ self.actions["run"] = self.run_step
self.actions["stop"] = self.stop
def run(self):
@@ -225,13 +371,13 @@ def run(self):
return result
async def setup(self):
- """Essenitally the registration process
+ """Essentially the registration process
Can also be an initialization for the actor
options is a list of options, can be empty
"""
pass
- async def runStep(self):
+ async def run_step(self):
raise NotImplementedError
async def stop(self):
@@ -239,17 +385,28 @@ async def stop(self):
# Aliasing
-Actor = ManagedActor
+Actor = ZmqActor
class RunManager:
- def __init__(self, name, actions, links, runStoreInterface=None, timeout=1e-6):
+ def __init__(
+ self,
+ name,
+ actions,
+ links,
+ nexus_sig_port,
+ improv_logger,
+ runStoreInterface=None,
+ timeout=1e-6,
+ ):
self.run = False
self.stop = False
self.config = False
+ self.nexus_sig_port = nexus_sig_port
+ self.improv_logger = improv_logger
self.actorName = name
- logger.debug("RunManager for {} created".format(self.actorName))
+ self.improv_logger.debug("RunManager for {} created".format(self.actorName))
self.actions = actions
self.links = links
@@ -261,66 +418,103 @@ def __init__(self, name, actions, links, runStoreInterface=None, timeout=1e-6):
def __enter__(self):
self.start = time.time()
- an = self.actorName
- while True:
- # Run any actions given a received Signal
- if self.run:
- try:
- self.actions["run"]()
- except Exception as e:
- logger.error("Actor {} error in run: {}".format(an, e))
- logger.error(traceback.format_exc())
- elif self.stop:
- try:
- self.actions["stop"]()
- except Exception as e:
- logger.error("Actor {} error in stop: {}".format(an, e))
- logger.error(traceback.format_exc())
- self.stop = False # Run once
- elif self.config:
- try:
- if self.runStoreInterface:
- self.runStoreInterface()
- self.actions["setup"]()
- self.q_comm.put([Signal.ready()])
- except Exception as e:
- logger.error("Actor {} error in setup: {}".format(an, e))
- logger.error(traceback.format_exc())
- self.config = False
-
- # Check for new Signals received from Nexus
- try:
- signal = self.q_sig.get(timeout=self.timeout)
- logger.debug("{} received Signal {}".format(self.actorName, signal))
- if signal == Signal.run():
- self.run = True
- logger.warning("Received run signal, begin running")
- elif signal == Signal.setup():
- self.config = True
- elif signal == Signal.stop():
- self.run = False
- self.stop = True
- logger.warning(f"actor {self.actorName} received stop signal")
- elif signal == Signal.quit():
- logger.warning("Received quit signal, aborting")
- break
- elif signal == Signal.pause():
- logger.warning("Received pause signal, pending...")
- self.run = False
- elif signal == Signal.resume(): # currently treat as same as run
- logger.warning("Received resume signal, resuming")
- self.run = True
- except KeyboardInterrupt:
- break
- except Empty:
- pass # No signal from Nexus
+ do_loop = True
+ while do_loop:
+ do_loop = self.loop_logic()
return None
+ def loop_logic(self):
+ # broken out into a separate function for use elsewhere
+ an = self.actorName
+ keep_going = True
+
+ # Run any actions given a received Signal
+ if self.run:
+ try:
+ self.actions["run"]()
+ except Exception as e:
+ self.improv_logger.error("Actor {} error in run: {}".format(an, e))
+ self.improv_logger.error(traceback.format_exc())
+ elif self.stop:
+ try:
+ self.actions["stop"]()
+ except Exception as e:
+ self.improv_logger.error("Actor {} error in stop: {}".format(an, e))
+ self.improv_logger.error(traceback.format_exc())
+ self.stop = False # Run once
+ elif self.config:
+ try:
+ if self.runStoreInterface:
+ self.runStoreInterface()
+ self.actions["setup"]()
+ self.q_comm.put(
+ ActorStateMsg(
+ self.actorName, Signal.ready(), self.nexus_sig_port, ""
+ )
+ )
+ res = self.q_comm.get()
+ self.improv_logger.info(
+ f"Actor {res.actor_name} got state update reply:\n"
+ f"Status: {res.status}\n"
+ f"Info: {res.info}\n"
+ )
+ except Exception as e:
+ self.improv_logger.error("Actor {} error in setup: {}".format(an, e))
+ self.improv_logger.error(traceback.format_exc())
+ self.config = False
+
+ # Check for new Signals received from Nexus
+ try:
+ signal_msg = self.q_sig.get(timeout=self.timeout)
+ signal = signal_msg.signal
+ self.q_sig.put(NexusSignalReplyMsg(an, signal, "OK", "OK"))
+ self.improv_logger.info(
+ "{} received signal {}".format(self.actorName, signal)
+ )
+ if signal == Signal.run():
+ self.run = True
+ self.improv_logger.info(
+ f"{self.actorName} received run signal, begin running"
+ )
+ elif signal == Signal.setup():
+ self.config = True
+ elif signal == Signal.stop():
+ self.run = False
+ self.stop = True
+ self.improv_logger.info(f"Actor {self.actorName} received stop signal")
+ elif signal == Signal.quit():
+ self.improv_logger.info(
+ f"{self.actorName} received quit signal, aborting"
+ )
+ keep_going = False
+ elif signal == Signal.pause():
+ self.improv_logger.info(
+ f"{self.actorName} received pause signal, pending..."
+ )
+ self.run = False
+ elif signal == Signal.resume(): # currently treat as same as run
+ self.improv_logger.info(
+ f"{self.actorName} Received resume signal, resuming"
+ )
+ self.run = True
+ elif signal == Signal.status():
+ self.improv_logger.info(f"{self.actorName} received status request")
+ except KeyboardInterrupt:
+ keep_going = False
+ except Empty:
+ pass # No signal from Nexus
+ except TimeoutError:
+ pass # No signal from Nexus over zmq
+
+ return keep_going
+
def __exit__(self, type, value, traceback):
- logger.info("Ran for " + str(time.time() - self.start) + " seconds")
- logger.warning("Exiting RunManager")
+ self.improv_logger.info(
+ f"{self.actorName} ran for " + str(time.time() - self.start) + " seconds"
+ )
+ self.improv_logger.info("Exiting RunManager")
return None
@@ -384,24 +578,24 @@ async def run_actor(self):
# Check for new Signals received from Nexus
try:
signal = self.q_sig.get(timeout=self.timeout)
- logger.debug("{} received Signal {}".format(self.actorName, signal))
+ logger.debug("{} received signal {}".format(self.actorName, signal))
if signal == Signal.run():
self.run = True
- logger.warning("Received run signal, begin running")
+ logger.info("Received run signal, begin running")
elif signal == Signal.setup():
self.config = True
elif signal == Signal.stop():
self.run = False
self.stop = True
- logger.warning(f"actor {self.actorName} received stop signal")
+ logger.info(f"actor {self.actorName} received stop signal")
elif signal == Signal.quit():
- logger.warning("Received quit signal, aborting")
+ logger.info("Received quit signal, aborting")
break
elif signal == Signal.pause():
- logger.warning("Received pause signal, pending...")
+ logger.info("Received pause signal, pending...")
self.run = False
elif signal == Signal.resume(): # currently treat as same as run
- logger.warning("Received resume signal, resuming")
+ logger.info("Received resume signal, resuming")
self.run = True
except KeyboardInterrupt:
break
@@ -416,7 +610,7 @@ async def __aenter__(self):
async def __aexit__(self, type, value, traceback):
logger.info("Ran for {} seconds".format(time.time() - self.start))
- logger.warning("Exiting AsyncRunManager")
+ logger.info("Exiting AsyncRunManager")
return None
@@ -473,3 +667,11 @@ def stop():
@staticmethod
def stop_success():
return "stop success"
+
+ @staticmethod
+ def status():
+ return "status"
+
+ @staticmethod
+ def waiting():
+ return "waiting"
diff --git a/improv/broker.py b/improv/broker.py
new file mode 100644
index 00000000..e9055080
--- /dev/null
+++ b/improv/broker.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+import logging
+import signal
+
+import zmq
+from zmq import SocketOption
+
+from improv.messaging import BrokerInfoMsg
+
+DEBUG = True
+
+local_log = logging.getLogger(__name__)
+
+
+def bootstrap_broker(nexus_hostname, nexus_port):
+ if DEBUG:
+ local_log.addHandler(logging.FileHandler("broker_server.log"))
+ try:
+ broker = PubSubBroker(nexus_hostname, nexus_port)
+ broker.register_with_nexus()
+ broker.serve(broker.read_and_pub_message)
+ except Exception as e:
+ local_log.error(e)
+ for handler in local_log.handlers:
+ handler.close()
+
+
+class PubSubBroker:
+ def __init__(self, nexus_hostname, nexus_comm_port):
+ self.running = True
+ self.nexus_hostname: str = nexus_hostname
+ self.nexus_comm_port: int = nexus_comm_port
+ self.zmq_context: zmq.Context | None = None
+ self.nexus_socket: zmq.Socket | None = None
+ self.pub_port: int
+ self.sub_port: int
+ self.pub_socket: zmq.Socket | None = None
+ self.sub_socket: zmq.Socket | None = None
+
+ signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+ for s in signals:
+ signal.signal(s, self.stop)
+
+ def register_with_nexus(self):
+ # connect to nexus
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ self.nexus_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_socket.connect(f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}")
+
+ self.sub_socket = self.zmq_context.socket(zmq.SUB)
+ self.sub_socket.bind("tcp://*:0")
+ sub_port_string = self.sub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ self.sub_port = int(sub_port_string.split(":")[-1])
+ self.sub_socket.subscribe("") # receive all incoming messages
+
+ self.pub_socket = self.zmq_context.socket(zmq.PUB)
+ self.pub_socket.bind("tcp://*:0")
+ pub_port_string = self.pub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ self.pub_port = int(pub_port_string.split(":")[-1])
+
+ port_info = BrokerInfoMsg(
+ "broker",
+ self.pub_port,
+ self.sub_port,
+ "Ports up and running, ready to serve messages",
+ )
+
+ self.nexus_socket.send_pyobj(port_info)
+
+ local_log.info("broker attempting to get message from nexus")
+ msg_available = 0
+ retries = 3
+ while (retries > 3) and (msg_available == 0):
+ msg_available = self.nexus_socket.poll(timeout=1000)
+ if msg_available == 0:
+ local_log.info(
+ "broker didn't get a reply from nexus. cycling socket and resending"
+ )
+ self.nexus_socket.close(linger=0)
+ self.nexus_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_socket.connect(
+ f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}"
+ )
+ self.nexus_socket.send_pyobj(port_info)
+ local_log.info("broker resent message")
+ retries -= 1
+
+ self.nexus_socket.recv_pyobj()
+
+ local_log.info("broker got message back from nexus")
+
+ return
+
+ def serve(self, message_process_func):
+ local_log.info("broker serving")
+ while self.running:
+ # this is more testable but may have a performance overhead
+ message_process_func()
+ self.shutdown()
+
+ def read_and_pub_message(self):
+ try:
+ msg_ready = self.sub_socket.poll(timeout=0)
+ if msg_ready != 0:
+ msg = self.sub_socket.recv_multipart()
+ self.pub_socket.send_multipart(msg)
+ except zmq.error.ZMQError:
+ self.running = False
+
+ def shutdown(self):
+ for handler in local_log.handlers:
+ handler.close()
+
+ if self.sub_socket:
+ self.sub_socket.close(linger=0)
+
+ if self.pub_socket:
+ self.sub_socket.close(linger=0)
+
+ if self.nexus_socket:
+ self.sub_socket.close(linger=0)
+
+ if self.zmq_context:
+ self.zmq_context.destroy(linger=0)
+
+ def stop(self, signum, frame):
+ local_log.info(f"Log server shutting down due to signal {signum}")
+
+ self.running = False
diff --git a/improv/cli.py b/improv/cli.py
index 720d4e2f..d76d0e05 100644
--- a/improv/cli.py
+++ b/improv/cli.py
@@ -8,15 +8,10 @@
import psutil
import time
import datetime
-from zmq import SocketOption
-from zmq.log.handlers import PUBHandler
from improv.tui import TUI
from improv.nexus import Nexus
MAX_PORT = 2**16 - 1
-DEFAULT_CONTROL_PORT = "0"
-DEFAULT_OUTPUT_PORT = "0"
-DEFAULT_LOGGING_PORT = "0"
def file_exists(fname):
@@ -80,23 +75,26 @@ def parse_cli_args(args):
"-c",
"--control-port",
type=is_valid_port,
- default=DEFAULT_CONTROL_PORT,
help="local port on which control are sent to/from server",
)
run_parser.add_argument(
"-o",
"--output-port",
type=is_valid_port,
- default=DEFAULT_OUTPUT_PORT,
help="local port on which server output messages are broadcast",
)
run_parser.add_argument(
"-l",
"--logging-port",
type=is_valid_port,
- default=DEFAULT_LOGGING_PORT,
help="local port on which logging messages are broadcast",
)
+ run_parser.add_argument(
+ "-i",
+ "--logging-input-port",
+ type=is_valid_port,
+ help="address to which logging messages are submitted",
+ )
run_parser.add_argument(
"-f", "--logfile", default="global.log", help="name of log file"
)
@@ -121,23 +119,26 @@ def parse_cli_args(args):
"-c",
"--control-port",
type=is_valid_ip_addr,
- default=DEFAULT_CONTROL_PORT,
help="address on which control signals are sent to the server",
)
client_parser.add_argument(
"-s",
"--server-port",
type=is_valid_ip_addr,
- default=DEFAULT_OUTPUT_PORT,
help="address on which messages from the server are received",
)
client_parser.add_argument(
"-l",
"--logging-port",
type=is_valid_ip_addr,
- default=DEFAULT_LOGGING_PORT,
help="address on which logging messages are broadcast",
)
+ client_parser.add_argument(
+ "-i",
+ "--logging-input-port",
+ type=is_valid_ip_addr,
+ help="address to which logging messages are submitted",
+ )
client_parser.set_defaults(func=run_client)
server_parser = subparsers.add_parser(
@@ -147,23 +148,26 @@ def parse_cli_args(args):
"-c",
"--control-port",
type=is_valid_port,
- default=DEFAULT_CONTROL_PORT,
help="local port on which control signals are received",
)
server_parser.add_argument(
"-o",
"--output-port",
type=is_valid_port,
- default=DEFAULT_OUTPUT_PORT,
help="local port on which output messages are broadcast",
)
server_parser.add_argument(
"-l",
"--logging-port",
type=is_valid_port,
- default=DEFAULT_LOGGING_PORT,
help="local port on which logging messages are broadcast",
)
+ server_parser.add_argument(
+ "-i",
+ "--logging-input-port",
+ type=is_valid_port,
+ help="address to which logging messages are submitted",
+ )
server_parser.add_argument(
"-f", "--logfile", default="global.log", help="name of log file"
)
@@ -203,7 +207,9 @@ def default_invocation():
def run_client(args):
- app = TUI(args.control_port, args.server_port, args.logging_port)
+ app = TUI(
+ args.control_port, args.server_port, args.logging_port, args.logging_input_port
+ )
app.run()
@@ -212,18 +218,11 @@ def run_server(args):
"""
Runs the improv server in headless mode.
"""
- zmq_log_handler = PUBHandler("tcp://*:%s" % args.logging_port)
- # in case we bound to a random port (default), get port number
- logging_port = int(
- zmq_log_handler.socket.getsockopt_string(SocketOption.LAST_ENDPOINT).split(":")[
- -1
- ]
- )
logging.basicConfig(
level=logging.DEBUG,
format="%(name)s %(message)s",
- handlers=[logging.FileHandler(args.logfile), zmq_log_handler],
+ handlers=[logging.FileHandler("improv-debug.log")],
)
if not args.actor_path:
@@ -232,18 +231,26 @@ def run_server(args):
sys.path.extend(args.actor_path)
server = Nexus()
- control_port, output_port = server.createNexus(
+ control_port, output_port, log_port, log_input_port = server.create_nexus(
file=args.configfile,
control_port=args.control_port,
output_port=args.output_port,
+ log_server_pub_port=args.logging_port,
+ log_server_pull_port=args.logging_input_port,
+ logfile=args.logfile,
)
curr_dt = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(
- f"{curr_dt} Server running on (control, output, log) ports "
- f"({control_port}, {output_port}, {logging_port}).\n"
+ f"{curr_dt} Server running on (control, output, log, log input) ports "
+ f"({control_port}, {output_port}, {log_port}, {log_input_port}).\n"
f"Press Ctrl-C to quit."
)
- server.startNexus()
+ try:
+ server.start_nexus()
+ except Exception as e:
+ print(f"CLI-started server run encountered uncaught error {e}")
+ logging.error(f"CLI-started server run encountered uncaught error {e}")
+ raise e
if args.actor_path:
for p in args.actor_path:
@@ -254,7 +261,7 @@ def run_server(args):
def run_list(args, printit=True):
out_list = []
- pattern = re.compile(r"(improv (run|client|server)|plasma_store|redis-server)")
+ pattern = re.compile(r"(improv (run|client|server)|redis-server)")
# mp_pattern = re.compile(r"-c from multiprocessing") # TODO is this right?
for proc in psutil.process_iter(["pid", "name", "cmdline"]):
if proc.info["cmdline"]:
@@ -281,22 +288,38 @@ def run_cleanup(args, headless=False):
if res.lower() == "y":
for proc in proc_list:
- if not proc.status() == psutil.STATUS_STOPPED:
- logging.info(
- f"process {proc.pid} {proc.name()}"
- f" has status {proc.status()}. Interrupting."
- )
- try:
+ try:
+ if not proc.status() == psutil.STATUS_STOPPED:
+ logging.info(
+ f"process {proc.pid} {proc.name()}"
+ f" has status {proc.status()}. Interrupting."
+ )
proc.send_signal(signal.SIGINT)
- except psutil.NoSuchProcess:
- pass
+ except psutil.NoSuchProcess:
+ pass
gone, alive = psutil.wait_procs(proc_list, timeout=3)
for p in alive:
- p.send_signal(signal.SIGINT)
try:
+ p.terminate()
p.wait(timeout=10)
except psutil.TimeoutExpired as e:
logging.warning(f"{e}: Process did not exit on time.")
+ try:
+ p.kill()
+ except psutil.NoSuchProcess as e:
+ logging.warning(
+ f"{e}: Process exited after wait timeout"
+ f" but before kill signal attempted."
+ )
+ # this happens sometimes because Nexus uses gracious
+ # timeout periods.
+ except psutil.NoSuchProcess as e:
+ logging.warning(
+ f"{e}: Process exited after wait timeout"
+ f" but before kill signal attempted."
+ )
+ # this happens sometimes because Nexus uses gracious
+ # timeout periods.
else:
if not headless:
@@ -313,32 +336,55 @@ def run(args, timeout=10):
server_opts = [
"improv",
"server",
- "-c",
- str(args.control_port),
- "-o",
- str(args.output_port),
- "-l",
- str(args.logging_port),
"-f",
args.logfile,
]
+
+ if args.control_port:
+ server_opts.append("-c")
+ server_opts.append(str(args.control_port))
+
+ if args.output_port:
+ server_opts.append("-o")
+ server_opts.append(str(args.output_port))
+
+ if args.logging_port:
+ server_opts.append("-l")
+ server_opts.append(str(args.logging_port))
+
+ if args.logging_port:
+ server_opts.append("-i")
+ server_opts.append(str(args.logging_input_port))
+
server_opts.extend(apath_opts)
server_opts.append(args.configfile)
- with open(args.logfile, mode="a+") as logfile:
+ print(" ".join(server_opts))
+
+ with open("improv-debug.log", mode="a+") as logfile:
server = subprocess.Popen(server_opts, stdout=logfile, stderr=logfile)
# wait for server to start up
- ports = get_server_ports(args, timeout)
- if ports:
- control_port, output_port, logging_port = ports
- args.logging_port = logging_port
- args.control_port = control_port
- args.server_port = output_port
- run_client(args)
+ curr_dt = datetime.datetime.now().replace(microsecond=0)
+ ports = None
+ while not ports:
+ ports = get_server_ports(args, timeout, curr_dt)
+ if ports:
+ control_port, output_port, logging_port, logging_input_port = ports
+ args.logging_port = logging_port
+ args.logging_input_port = logging_input_port
+ args.control_port = control_port
+ args.server_port = output_port
+ run_client(args)
+ else:
+ reply = input("Do you want to keep waiting? (y/N) ")
+ if not reply.lower() == "y":
+ break
try:
- server.wait(timeout=2)
+ wait_timeout = 60
+ print(f"Waiting {wait_timeout} seconds for Nexus to complete shutdown.")
+ server.wait(timeout=wait_timeout)
except subprocess.TimeoutExpired:
print("Cleaning up the hard way. May have exited dirty.")
server.terminate()
@@ -346,17 +392,14 @@ def run(args, timeout=10):
run_cleanup(args, headless=True)
-def get_server_ports(args, timeout):
- # save current datetime so we can see when server has started up
- curr_dt = datetime.datetime.now().replace(microsecond=0)
-
+def get_server_ports(args, timeout, curr_dt):
increment = 0.05
time_now = 0
ports = None
while time_now < timeout:
- server_start_time = _server_start_logged(args.logfile)
+ server_start_time = _server_start_logged("improv-debug.log")
if server_start_time and server_start_time >= curr_dt:
- ports = _get_ports(args.logfile)
+ ports = _get_ports("improv-debug.log")
if ports:
break
@@ -365,12 +408,12 @@ def get_server_ports(args, timeout):
if not server_start_time:
print(
- f"Unable to read server start time from {args.logfile}.\n"
+ f"Unable to read server start time from {'improv-debug.log'}.\n"
"This may be because the server could not be started or "
"did not log its activity."
)
elif not ports:
- print(f"Unable to read ports from {args.logfile}.")
+ print(f"Unable to read ports from {'improv-debug.log'}.")
return ports
@@ -393,13 +436,18 @@ def _get_ports(logfile):
# read logfile to get ports
with open(logfile, mode="r") as logfile:
contents = logfile.read()
+ return _read_log_contents_for_ports(contents)
- pattern = re.compile(r"(?<=\(control, output, log\) ports \()\d*, \d*, \d*")
- # get most recent match (log file may contain old runs)
- port_str_list = pattern.findall(contents)
- if port_str_list:
- port_str = port_str_list[-1]
- return (int(p) for p in port_str.split(", "))
- else:
- return None
+def _read_log_contents_for_ports(logfile_contents):
+ pattern = re.compile(
+ r"(?<=\(control, output, log, log input\) ports \()\d*, \d*, \d*, \d*"
+ )
+
+ # get most recent match (log file may contain old runs)
+ port_str_list = pattern.findall(logfile_contents)
+ if port_str_list:
+ port_str = port_str_list[-1]
+ return (int(p) for p in port_str.split(", "))
+ else:
+ return None
diff --git a/improv/config.py b/improv/config.py
index 5f73c18b..d897d837 100644
--- a/improv/config.py
+++ b/improv/config.py
@@ -6,46 +6,48 @@
logger = logging.getLogger(__name__)
+class CannotCreateConfigException(Exception):
+ def __init__(self, msg):
+ super().__init__("Cannot create config: {}".format(msg))
+
+
class Config:
"""Handles configuration and logs of configs for
the entire server/processing pipeline.
"""
- def __init__(self, configFile):
- if configFile is None:
- logger.error("Need to specify a config file")
- raise Exception
- else:
- # Reading config from other yaml file
- self.configFile = configFile
-
- with open(self.configFile, "r") as ymlfile:
- cfg = yaml.safe_load(ymlfile)
-
- try:
- if "settings" in cfg:
- self.settings = cfg["settings"]
- else:
- self.settings = {}
+ def __init__(self, config_file):
+ self.actors = {}
+ self.connections = {}
+ self.hasGUI = False
+ self.config_file = config_file
- if "use_watcher" not in self.settings:
- self.settings["use_watcher"] = False
+ with open(self.config_file, "r") as ymlfile:
+ self.config = yaml.safe_load(ymlfile)
- except TypeError:
- if cfg is None:
- logger.error("Error: The config file is empty")
+ if self.config is None:
+ logger.error("The config file is empty")
+ raise CannotCreateConfigException("The config file is empty")
- if type(cfg) is not dict:
+ if type(self.config) is not dict:
logger.error("Error: The config file is not in dictionary format")
raise TypeError
- self.config = cfg
+ def parse_config(self):
+ self.populate_defaults()
+ self.validate_config()
- self.actors = {}
- self.connections = {}
- self.hasGUI = False
+ self.settings = self.config["settings"]
+ self.redis_config = self.config["redis_config"]
+
+ def populate_defaults(self):
+ self.populate_settings_defaults()
+ self.populate_redis_defaults()
- def createConfig(self):
+ def validate_config(self):
+ self.validate_redis_config()
+
+ def create_config(self):
"""Read yaml config file and create config for Nexus
TODO: check for config file compliance, error handle it
beyond what we have below.
@@ -53,9 +55,6 @@ def createConfig(self):
cfg = self.config
for name, actor in cfg["actors"].items():
- if name in self.actors.keys():
- raise RepeatedActorError(name)
-
packagename = actor.pop("package")
classname = actor.pop("class")
@@ -63,11 +62,15 @@ def createConfig(self):
__import__(packagename, fromlist=[classname])
mod = import_module(packagename)
- clss = getattr(mod, classname)
- sig = signature(clss)
- configModule = ConfigModule(name, packagename, classname, options=actor)
- sig.bind(configModule.options)
+ actor_class = getattr(mod, classname)
+ sig = signature(actor_class)
+ config_module = ConfigModule(
+ name, packagename, classname, options=actor
+ )
+ sig.bind(config_module.options)
+ # TODO: this is not trivial to test, since our code formatting
+ # tools won't allow a file with a syntax error to exist
except SyntaxError as e:
logger.error(f"Error: syntax error when initializing actor {name}: {e}")
return -1
@@ -87,101 +90,120 @@ def createConfig(self):
except TypeError:
logger.error("Error: Invalid arguments passed")
params = ""
- for parameter in sig.parameters:
- params = params + " " + parameter.name
+ for param_name, param in sig.parameters.items():
+ params = params + ", " + param.name
logger.warning("Expected Parameters:" + params)
return -1
- except Exception as e:
+ except Exception as e: # TODO: figure out how to test this
logger.error(f"Error: {e}")
return -1
- if "GUI" in name:
- logger.info(f"Config detected a GUI actor: {name}")
- self.hasGUI = True
- self.gui = configModule
- else:
- self.actors.update({name: configModule})
+ self.actors.update({name: config_module})
for name, conn in cfg["connections"].items():
- if name in self.connections.keys():
- raise RepeatedConnectionsError(name)
-
self.connections.update({name: conn})
- if "datastore" in cfg.keys():
- self.datastore = cfg["datastore"]
-
return 0
- def addParams(self, type, param):
- """Function to add paramter param of type type
- TODO: Future work
- """
- pass
-
- def saveActors(self):
+ def save_actors(self):
"""Saves the actors config to a specific file."""
wflag = True
- saveFile = self.configFile.split(".")[0]
+ saveFile = self.config_file.split(".")[0]
pathName = saveFile + "_actors.yaml"
for a in self.actors.values():
- wflag = a.saveConfigModules(pathName, wflag)
-
- def use_plasma(self):
- return "plasma_config" in self.config.keys()
-
- def get_redis_port(self):
- if self.redis_port_specified():
- return self.config["redis_config"]["port"]
- else:
- return Config.get_default_redis_port()
-
- def redis_port_specified(self):
- if "redis_config" in self.config.keys():
- return "port" in self.config["redis_config"]
- return False
-
- def redis_saving_enabled(self):
- if "redis_config" in self.config.keys():
- return (
- self.config["redis_config"]["enable_saving"]
- if "enable_saving" in self.config["redis_config"]
- else None
+ wflag = a.save_config_modules(pathName, wflag)
+
+ def populate_settings_defaults(self):
+ if "settings" not in self.config:
+ self.config["settings"] = {}
+
+ if "store_size" not in self.config["settings"]:
+ self.config["settings"]["store_size"] = 250_000_000
+ if "control_port" not in self.config["settings"]:
+ self.config["settings"]["control_port"] = 5555
+ if "output_port" not in self.config["settings"]:
+ self.config["settings"]["output_port"] = 5556
+ if "logging_port" not in self.config["settings"]:
+ self.config["settings"]["logging_port"] = 5557
+ if "logging_input_port" not in self.config["settings"]:
+ self.config["settings"]["logging_input_port"] = 5558
+ if "harvest_data_from_memory" not in self.config["settings"]:
+ self.config["settings"]["harvest_data_from_memory"] = None
+
+ def populate_redis_defaults(self):
+ if "redis_config" not in self.config:
+ self.config["redis_config"] = {}
+
+ if "enable_saving" not in self.config["redis_config"]:
+ self.config["redis_config"]["enable_saving"] = None
+ if "aof_dirname" not in self.config["redis_config"]:
+ self.config["redis_config"]["aof_dirname"] = None
+ if "generate_ephemeral_aof_dirname" not in self.config["redis_config"]:
+ self.config["redis_config"]["generate_ephemeral_aof_dirname"] = False
+ if "fsync_frequency" not in self.config["redis_config"]:
+ self.config["redis_config"]["fsync_frequency"] = None
+
+ # enable saving automatically if the user configured a saving option
+ if (
+ self.config["redis_config"]["aof_dirname"]
+ or self.config["redis_config"]["generate_ephemeral_aof_dirname"]
+ or self.config["redis_config"]["fsync_frequency"]
+ ) and self.config["redis_config"]["enable_saving"] is None:
+ self.config["redis_config"]["enable_saving"] = True
+
+ if "port" not in self.config["redis_config"]:
+ self.config["redis_config"]["port"] = 6379
+
+ def validate_redis_config(self):
+ fsync_name_dict = {
+ "every_write": "always",
+ "every_second": "everysec",
+ "no_schedule": "no",
+ }
+ if (
+ self.config["redis_config"]["aof_dirname"]
+ and self.config["redis_config"]["generate_ephemeral_aof_dirname"]
+ ):
+ logger.error(
+ "Cannot both generate a unique dirname and use the one provided."
)
-
- def generate_ephemeral_aof_dirname(self):
- if "redis_config" in self.config.keys():
- return (
- self.config["redis_config"]["generate_ephemeral_aof_dirname"]
- if "generate_ephemeral_aof_dirname" in self.config["redis_config"]
- else None
- )
- return False
-
- def get_redis_aof_dirname(self):
- if "redis_config" in self.config.keys():
- return (
- self.config["redis_config"]["aof_dirname"]
- if "aof_dirname" in self.config["redis_config"]
- else None
+ raise Exception("Cannot use unique dirname and use the one provided.")
+
+ if (
+ self.config["redis_config"]["aof_dirname"]
+ or self.config["redis_config"]["generate_ephemeral_aof_dirname"]
+ or self.config["redis_config"]["fsync_frequency"]
+ ):
+ if not self.config["redis_config"]["enable_saving"]:
+ logger.error(
+ "Invalid configuration. Cannot save to disk with saving disabled."
+ )
+ raise Exception("Cannot persist to disk with saving disabled.")
+
+ if self.config["redis_config"]["fsync_frequency"] and self.config[
+ "redis_config"
+ ]["fsync_frequency"] not in [
+ "every_write",
+ "every_second",
+ "no_schedule",
+ ]:
+ logger.error(
+ f"Cannot use unknown fsync frequency "
+ f'{self.config["redis_config"]["fsync_frequency"]}'
)
- return None
-
- def get_redis_fsync_frequency(self):
- if "redis_config" in self.config.keys():
- frequency = (
- self.config["redis_config"]["fsync_frequency"]
- if "fsync_frequency" in self.config["redis_config"]
- else None
+ raise Exception(
+ f"Cannot use unknown fsync frequency "
+ f'{self.config["redis_config"]["fsync_frequency"]}'
)
- return frequency
+ if self.config["redis_config"]["fsync_frequency"] is None:
+ self.config["redis_config"]["fsync_frequency"] = "no_schedule"
- @staticmethod
- def get_default_redis_port():
- return "6379"
+ self.config["redis_config"]["fsync_frequency"] = (fsync_name_dict)[
+ self.config["redis_config"]["fsync_frequency"]
+ ]
class ConfigModule:
@@ -191,11 +213,11 @@ def __init__(self, name, packagename, classname, options=None):
self.classname = classname
self.options = options
- def saveConfigModules(self, pathName, wflag):
+ def save_config_modules(self, path_name, wflag):
"""Loops through each actor to save the modules to the config file.
Args:
- pathName:
+ path_name:
wflag (bool):
Returns:
@@ -213,32 +235,7 @@ def saveConfigModules(self, pathName, wflag):
for key, value in self.options.items():
cfg[self.name].update({key: value})
- with open(pathName, writeOption) as file:
+ with open(path_name, writeOption) as file:
yaml.dump(cfg, file)
return wflag
-
-
-class RepeatedActorError(Exception):
- def __init__(self, repeat):
- super().__init__()
-
- self.name = "RepeatedActorError"
- self.repeat = repeat
-
- self.message = 'Actor name has already been used: "{}"'.format(repeat)
-
- def __str__(self):
- return self.message
-
-
-class RepeatedConnectionsError(Exception):
- def __init__(self, repeat):
- super().__init__()
- self.name = "RepeatedConnectionsError"
- self.repeat = repeat
-
- self.message = 'Connection name has already been used: "{}"'.format(repeat)
-
- def __str__(self):
- return self.message
diff --git a/improv/harvester.py b/improv/harvester.py
new file mode 100644
index 00000000..ba66bdf6
--- /dev/null
+++ b/improv/harvester.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+import logging
+import signal
+import time
+
+import zmq
+
+from improv.link import ZmqLink
+from improv.log import ZmqLogHandler
+from improv.store import RedisStoreInterface
+from zmq import SocketOption
+
+from improv.messaging import HarvesterInfoMsg
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def bootstrap_harvester(
+ nexus_hostname,
+ nexus_port,
+ redis_hostname,
+ redis_port,
+ broker_hostname,
+ broker_port,
+ logger_hostname,
+ logger_port,
+):
+ harvester = RedisHarvester(
+ nexus_hostname,
+ nexus_port,
+ redis_hostname,
+ redis_port,
+ broker_hostname,
+ broker_port,
+ logger_hostname,
+ logger_port,
+ )
+ harvester.establish_connections()
+ harvester.register_with_nexus()
+ harvester.serve(harvester.collect)
+
+
+class RedisHarvester:
+ def __init__(
+ self,
+ nexus_hostname,
+ nexus_comm_port,
+ redis_hostname,
+ redis_port,
+ broker_hostname,
+ broker_port,
+ logger_hostname,
+ logger_port,
+ ):
+ self.link: ZmqLink | None = None
+ self.running = True
+ self.nexus_hostname: str = nexus_hostname
+ self.nexus_comm_port: int = nexus_comm_port
+ self.redis_hostname: str = redis_hostname
+ self.redis_port: int = redis_port
+ self.broker_hostname: str = broker_hostname
+ self.broker_port: int = broker_port
+ self.zmq_context: zmq.Context | None = None
+ self.nexus_socket: zmq.Socket | None = None
+ self.sub_port: int | None = None
+ self.sub_socket: zmq.Socket | None = None
+ self.store_client: RedisStoreInterface | None = None
+ self.logger_hostname: str = logger_hostname
+ self.logger_port: int = logger_port
+
+ logger.addHandler(ZmqLogHandler(logger_hostname, logger_port))
+
+ signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+ for s in signals:
+ signal.signal(s, self.stop)
+
+ def establish_connections(self):
+ logger.info("Registering with Nexus")
+ # connect to nexus
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ self.nexus_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_socket.connect(f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}")
+
+ self.sub_socket = self.zmq_context.socket(zmq.SUB)
+ self.sub_socket.connect(f"tcp://{self.broker_hostname}:{self.broker_port}")
+ sub_port_string = self.sub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ self.sub_port = int(sub_port_string.split(":")[-1])
+ self.sub_socket.subscribe("") # receive all incoming messages
+
+ self.store_client = RedisStoreInterface(
+ "harvester", self.redis_port, self.redis_hostname
+ )
+
+ self.link = ZmqLink(self.sub_socket, "harvester", "")
+
+ # TODO: this is a small function that is currently easy to verify
+ # it could be unit-tested with a multiprocess that spins up a socket
+ # which this can communicate with, but isn't currenlty worth the time
+ def register_with_nexus(self):
+ port_info = HarvesterInfoMsg(
+ "harvester",
+ "Ports up and running, ready to serve messages",
+ )
+
+ self.nexus_socket.send_pyobj(port_info)
+ self.nexus_socket.recv_pyobj()
+
+ return
+
+ def serve(self, message_process_func, *args, **kwargs):
+ logger.info("Harvester beginning harvest")
+ while self.running:
+ message_process_func(*args, **kwargs)
+ self.shutdown()
+
+ def collect(self, *args, **kwargs):
+ db_info = self.store_client.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+ if used_max_ratio > 0.75:
+ while self.running and (used_max_ratio > 0.50):
+ try:
+ key = self.link.get(timeout=100) # 100ms
+ self.store_client.client.delete(key)
+ except TimeoutError:
+ pass
+ db_info = self.store_client.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+ time.sleep(0.1)
+ return
+
+ def shutdown(self):
+ for handler in logger.handlers:
+ if isinstance(handler, ZmqLogHandler):
+ handler.close()
+ logger.removeHandler(handler)
+
+ if self.sub_socket:
+ self.sub_socket.close(linger=0)
+
+ if self.nexus_socket:
+ self.nexus_socket.close(linger=0)
+
+ if self.zmq_context:
+ self.zmq_context.destroy(linger=0)
+
+ def stop(self, signum, frame):
+ self.running = False
+ logger.info(f"Harvester shutting down due to signal {signum}")
diff --git a/improv/link.py b/improv/link.py
index 5014204b..f0af0a0f 100644
--- a/improv/link.py
+++ b/improv/link.py
@@ -1,91 +1,27 @@
import asyncio
+import json
import logging
-from multiprocessing import Manager, cpu_count
+from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
-from concurrent.futures._base import CancelledError
+
+import zmq
+from zmq import SocketOption
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-def Link(name, start, end):
- """Function to construct a queue that Nexus uses for
- inter-process (actor) signaling and information passing.
-
- A Link has an internal queue that can be synchronous (put, get)
- as inherited from multiprocessing.Manager.Queue
- or asynchronous (put_async, get_async) using async executors.
-
- Args:
- See AsyncQueue constructor
-
- Returns:
- AsyncQueue: queue for communicating between actors and with Nexus
- """
-
- m = Manager()
- q = AsyncQueue(m.Queue(maxsize=0), name, start, end)
- return q
-
-
-class AsyncQueue(object):
- """Single-output and asynchronous queue class.
-
- Attributes:
- queue:
- real_executor:
- cancelled_join: boolean
- name:
- start:
- end:
- status:
- result:
- num:
- dict:
- """
-
- def __init__(self, q, name, start, end):
- """Constructor for the queue class.
-
- Args:
- q (Queue): A queue from the Manager class
- name (str): String description of this queue
- start (str): The producer (input) actor name for the queue
- end (str): The consumer (output) actor name for the queue
-
- """
- self.queue = q
- self.real_executor = None
- self.cancelled_join = False
-
+class ZmqLink:
+ def __init__(self, socket: zmq.Socket, name, topic=None):
self.name = name
- self.start = start
- self.end = end
-
+ self.real_executor = None
+ self.socket = socket # this must already be set up
+ self.socket_type = self.socket.getsockopt(SocketOption.TYPE)
+ if self.socket_type == zmq.PUB and topic is None:
+ raise Exception("Cannot open PUB link without topic")
self.status = "pending"
self.result = None
-
- def getStart(self):
- """Gets the starting actor.
-
- The starting actor is the actor that is at the tail of the link.
- This actor is the one that gives output.
-
- Returns:
- start (str): The starting actor name
- """
- return self.start
-
- def getEnd(self):
- """Gets the ending actor.
-
- The ending actor is the actor that is at the head of the link.
- This actor is the one that takes input.
-
- Returns:
- end (str): The ending actor name
- """
- return self.end
+ self.topic = topic
@property
def _executor(self):
@@ -93,195 +29,68 @@ def _executor(self):
self.real_executor = ThreadPoolExecutor(max_workers=cpu_count())
return self.real_executor
- def __getstate__(self):
- """Gets a dictionary of attributes.
-
- This function gets a dictionary, with keys being the names of
- the attributes, and values being the values of the attributes.
+ def get(self, timeout=None):
+ if timeout is None:
+ if self.socket_type == zmq.SUB:
+ res = self.socket.recv_multipart()
+ return json.loads(res[1].decode("utf-8"))
- Returns:
- self_dict (dict): A dictionary containing attributes.
- """
- self_dict = self.__dict__
- self_dict["_real_executor"] = None
- return self_dict
+ return self.socket.recv_pyobj()
- def __getattr__(self, name):
- """Gets the attribute specified by "name".
+ msg_ready = self.socket.poll(timeout=timeout)
+ if msg_ready == 0:
+ raise TimeoutError
- Args:
- name (str): Name of the attribute to be returned.
+ if self.socket_type == zmq.SUB:
+ res = self.socket.recv_multipart()
+ return json.loads(res[1])
- Raises:
- AttributeError: Restricts the available attributes to a
- specific list. This error is raised if a different attribute
- of the queue is requested.
+ return self.socket.recv_pyobj()
- TODO:
- Don't raise this?
+ def get_nowait(self):
+ return self.get(timeout=0)
- Returns:
- (object): Value of the attribute specified by "name".
- """
- if name in ["qsize", "empty", "full", "get", "get_nowait", "close"]:
- return getattr(self.queue, name)
+ def put(
+ self, item
+ ): # TODO: is it a problem we're implicitly handling inputs differently here?
+ if self.socket_type == zmq.PUB:
+ self.socket.send_multipart(
+ [self.topic.encode("utf-8"), json.dumps(item).encode("utf-8")]
+ )
else:
- cn = self.__class__.__name__
- raise AttributeError("{} object has no attribute {}".format(cn, name))
-
- def __repr__(self):
- """String representation for Link.
-
- Returns:
- (str): "Link" followed by the name given in the constructor.
- """
- return "Link " + self.name
+ self.socket.send_pyobj(item)
- def put(self, item):
- """Function wrapper for put.
-
- Args:
- item (object): Any item that can be sent through a queue
+ def send(self, item):
"""
- self.queue.put(item)
-
- def put_nowait(self, item):
- """Function wrapper for put without waiting
-
- Args:
- item (object): Any item that can be sent through a queue
+ This combines put and get in the case of a REQ/REP socket to ensure links aren't
+ left in an inconsistent state.
"""
- self.queue.put_nowait(item)
+ self.put(item)
+ if self.socket_type == zmq.SUB:
+ return None
+ else:
+ return self.get()
async def put_async(self, item):
- """Coroutine for an asynchronous put
-
- It adds the put request to the event loop and awaits.
-
- Args:
- item (object): Any item that can be sent through a queue
-
- Returns:
- Awaitable or result of the put
- """
loop = asyncio.get_event_loop()
- try:
- res = await loop.run_in_executor(self._executor, self.put, item)
- return res
- except EOFError:
- logger.warn("Link probably killed (EOF)")
- except FileNotFoundError:
- logger.warn("probably killed (file not found)")
+ res = await loop.run_in_executor(self._executor, self.put, item)
+ return res
async def get_async(self):
- """Coroutine for an asynchronous get
-
- It adds the get request to the event loop and awaits, setting
- the status to pending. Once the get has returned, it returns the
- result of the get and sets its status as done.
-
- Explicitly passes any exceptions to not hinder execution.
- Errors are logged with the get_async tag.
-
- Returns:
- Awaitable or result of the get.
-
- Raises:
- CancelledError: task is cancelled
- EOFError:
- FileNotFoundError:
- Exception:
- """
loop = asyncio.get_event_loop()
self.status = "pending"
- try:
- self.result = await loop.run_in_executor(self._executor, self.get)
- self.status = "done"
- return self.result
- except CancelledError:
- logger.info("Task {} Canceled".format(self.name))
- except EOFError:
- logger.info("probably killed")
- except FileNotFoundError:
- logger.info("probably killed")
- except Exception as e:
- logger.exception("Error in get_async: {}".format(e))
-
- def cancel_join_thread(self):
- """Function wrapper for cancel_join_thread."""
- self._cancelled_join = True
- self._queue.cancel_join_thread()
-
- def join_thread(self):
- """Function wrapper for join_thread."""
- self._queue.join_thread()
- if self._real_executor and not self._cancelled_join:
- self._real_executor.shutdown()
-
-
-def MultiLink(name, start, end):
- """Function to generate links for the multi-output queue case.
-
- Args:
- See constructor for AsyncQueue or MultiAsyncQueue
-
- Returns:
- MultiAsyncQueue: Producer end of the queue
- List: AsyncQueues for consumers
- """
- m = Manager()
-
- q_out = []
- for endpoint in end:
- q = AsyncQueue(m.Queue(maxsize=0), name, start, endpoint)
- q_out.append(q)
-
- q = MultiAsyncQueue(m.Queue(maxsize=0), q_out, name, start, end)
-
- return q, q_out
-
-
-class MultiAsyncQueue(AsyncQueue):
- """Extension of AsyncQueue to have multiple endpoints.
-
- Inherits from AsyncQueue.
- A single producer queue's 'put' is copied to multiple consumer's
- queues, q_in is the producer queue, q_out are the consumer queues.
-
- TODO:
- Test the async nature of this group of queues
- """
-
- def __init__(self, q_in, q_out, name, start, end):
- self.queue = q_in
- self.output = q_out
-
- self.real_executor = None
- self.cancelled_join = False
-
- self.name = name
- self.start = start
- self.end = end[0]
- self.status = "pending"
- self.result = None
-
- def __repr__(self):
- return "MultiLink " + self.name
-
- def __getattr__(self, name):
- # Remove put and put_nowait and define behavior specifically
- # TODO: remove get capability?
- if name in ["qsize", "empty", "full", "get", "get_nowait", "close"]:
- return getattr(self.queue, name)
- else:
- raise AttributeError(
- "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)
- )
-
- def put(self, item):
- for q in self.output:
- q.put(item)
-
- def put_nowait(self, item):
- for q in self.output:
- q.put_nowait(item)
+ # try:
+ self.result = await loop.run_in_executor(self._executor, self.get)
+ self.status = "done"
+ return self.result
+ # TODO: explicitly commenting these out because testing them is hard.
+ # It's better to let them bubble up so that we don't miss them
+ # due to being caught without a clear reproducible mechanism
+ # except CancelledError:
+ # logger.info("Task {} Canceled".format(self.name))
+ # except EOFError:
+ # logger.info("probably killed")
+ # except FileNotFoundError:
+ # logger.info("probably killed")
+ # except Exception as e:
+ # logger.exception("Error in get_async: {}".format(e))
diff --git a/improv/log.py b/improv/log.py
new file mode 100644
index 00000000..6369edb2
--- /dev/null
+++ b/improv/log.py
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import logging
+import signal
+from logging import handlers
+from logging.handlers import QueueHandler
+
+import zmq
+from zmq import SocketOption
+from zmq.log.handlers import PUBHandler
+
+from improv.messaging import LogInfoMsg
+
+local_log = logging.getLogger(__name__)
+
+DEBUG = True
+
+# TODO: ideally there should be some kind of drain at shutdown time
+# so we don't miss any log messages, but that would make shutdown
+# also take longer. TBD?
+
+
+def bootstrap_log_server(
+ nexus_hostname,
+ nexus_port,
+ log_filename="global.log",
+ logger_pub_port=None,
+ logger_pull_port=None,
+):
+ if DEBUG:
+ local_log.addHandler(logging.FileHandler("log_server.log"))
+ local_log.setLevel(logging.DEBUG)
+ try:
+ log_server = LogServer(
+ nexus_hostname, nexus_port, log_filename, logger_pub_port, logger_pull_port
+ )
+ log_server.register_with_nexus()
+ log_server.serve(log_server.read_and_log_message)
+ except Exception as e:
+ local_log.error(e)
+ for handler in local_log.handlers:
+ handler.close()
+
+
+class ZmqPullListener(handlers.QueueListener):
+ def __init__(self, ctx, pull_port, /, *handlers, **kwargs):
+ self.sentinel = False
+ self.ctx = ctx
+ self.listen_port = pull_port if pull_port else 0
+ self.pull_socket = self.ctx.socket(zmq.PULL)
+ self.pull_socket.bind(f"tcp://*:{self.listen_port}")
+ pull_port_string = self.pull_socket.getsockopt_string(
+ SocketOption.LAST_ENDPOINT
+ )
+ self.pull_port = int(pull_port_string.split(":")[-1])
+ super().__init__(self.pull_socket, *handlers, **kwargs)
+
+ def dequeue(self, block=True):
+ msg = None
+ while msg is None:
+ if self.sentinel:
+ return handlers.QueueListener._sentinel
+ msg_ready = self.queue.poll(timeout=1000)
+ if msg_ready != 0:
+ msg = self.queue.recv_json()
+ return logging.makeLogRecord(msg)
+
+ def enqueue_sentinel(self):
+ self.sentinel = True
+
+
+class ZmqLogHandler(QueueHandler):
+ def __init__(self, hostname, port, ctx=None):
+ self.ctx = ctx if ctx else zmq.Context()
+ self.ctx.setsockopt(SocketOption.LINGER, 0)
+ self.socket = self.ctx.socket(zmq.PUSH)
+ self.socket.connect(f"tcp://{hostname}:{port}")
+ super().__init__(self.socket)
+
+ def enqueue(self, record):
+ self.queue.send_json(record.__dict__)
+
+ def close(self):
+ self.queue.close(linger=0)
+
+
+class LogServer:
+ def __init__(
+ self, nexus_hostname, nexus_comm_port, log_filename, pub_port, pull_port
+ ):
+ self.running = True
+ self.pub_port: int | None = pub_port if pub_port else 0
+ self.pull_port: int | None = pull_port if pull_port else 0
+ self.pub_socket: zmq.Socket | None = None
+ self.log_filename = log_filename
+ self.nexus_hostname: str = nexus_hostname
+ self.nexus_comm_port: int = nexus_comm_port
+ self.zmq_context: zmq.Context | None = None
+ self.nexus_socket: zmq.Socket | None = None
+ self.pull_socket: zmq.Socket | None = None
+ self.listener: ZmqPullListener | None = None
+
+ signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+ for s in signals:
+ signal.signal(s, self.stop)
+
+ def register_with_nexus(self):
+ # connect to nexus
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ self.nexus_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_socket.connect(f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}")
+
+ self.pub_socket = self.zmq_context.socket(zmq.PUB)
+ self.pub_socket.bind(f"tcp://*:{self.pub_port}")
+ pub_port_string = self.pub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ self.pub_port = int(pub_port_string.split(":")[-1])
+
+ self.listener = ZmqPullListener(
+ self.zmq_context,
+ self.pull_port,
+ logging.FileHandler(self.log_filename),
+ PUBHandler(self.pub_socket, self.zmq_context, "nexus_logging"),
+ )
+
+ self.listener.start()
+
+ local_log.info("logger started listening")
+
+ port_info = LogInfoMsg(
+ "broker",
+ self.listener.pull_port,
+ self.pub_port,
+ "Port up and running, ready to log messages",
+ )
+
+ self.nexus_socket.send_pyobj(port_info)
+ local_log.info("logger sent message to nexus")
+ self.nexus_socket.recv_pyobj()
+
+ local_log.info("logger got message from nexus")
+
+ return
+
+ def serve(self, log_func):
+ local_log.info("logger serving")
+ while self.running:
+ log_func() # this is more testable but may have a performance overhead
+ self.shutdown()
+
+ def read_and_log_message(self): # receive and send back out
+ pass
+
+ def shutdown(self):
+ if self.listener:
+ self.listener.stop()
+
+ for handler in self.listener.handlers:
+ try:
+ handler.close()
+ except Exception as e:
+ local_log.error(e)
+
+ if self.pull_socket:
+ self.pull_socket.close(linger=0)
+
+ if self.nexus_socket:
+ self.nexus_socket.close(linger=0)
+
+ if self.pub_socket:
+ self.pub_socket.close(linger=0)
+
+ for handler in local_log.handlers:
+ handler.close()
+
+ if self.zmq_context:
+ self.zmq_context.destroy(linger=0)
+
+ def stop(self, signum, frame):
+ local_log.info(f"Log server shutting down due to signal {signum}")
+
+ self.running = False
diff --git a/improv/messaging.py b/improv/messaging.py
new file mode 100644
index 00000000..809bfe8f
--- /dev/null
+++ b/improv/messaging.py
@@ -0,0 +1,85 @@
+class ActorStateMsg:
+ def __init__(self, actor_name, status, nexus_in_port, info):
+ self.actor_name = actor_name
+ self.status = status
+ self.nexus_in_port = nexus_in_port
+ self.info = info
+
+
+class ActorStateReplyMsg:
+ def __init__(self, actor_name, status, info):
+ self.actor_name = actor_name
+ self.status = status
+ self.info = info
+
+
+class ActorSignalMsg:
+ def __init__(self, actor_name, signal, info):
+ self.actor_name = actor_name
+ self.signal = signal
+ self.info = info
+
+
+class ActorSignalReplyMsg:
+ def __init__(self, actor_name, signal, info):
+ self.actor_name = actor_name
+ self.signal = signal
+ self.info = info
+
+
+class NexusSignalMsg:
+ def __init__(self, actor_name, signal, info):
+ self.actor_name = actor_name
+ self.signal = signal
+ self.info = info
+
+
+class NexusSignalReplyMsg:
+ def __init__(self, actor_name, signal, status, info):
+ self.actor_name = actor_name
+ self.signal = signal
+ self.status = status
+ self.info = info
+
+
+class BrokerInfoMsg:
+ def __init__(self, name, pub_port, sub_port, info):
+ self.name = name
+ self.pub_port = pub_port
+ self.sub_port = sub_port
+ self.info = info
+
+
+class BrokerInfoReplyMsg:
+ def __init__(self, name, status, info):
+ self.name = name
+ self.status = status
+ self.info = info
+
+
+class LogInfoMsg:
+ def __init__(self, name, pull_port, pub_port, info):
+ self.name = name
+ self.pull_port = pull_port
+ self.pub_port = pub_port
+ self.info = info
+
+
+class LogInfoReplyMsg:
+ def __init__(self, name, status, info):
+ self.name = name
+ self.status = status
+ self.info = info
+
+
+class HarvesterInfoMsg:
+ def __init__(self, name, info):
+ self.name = name
+ self.info = info
+
+
+class HarvesterInfoReplyMsg:
+ def __init__(self, name, status, info):
+ self.name = name
+ self.status = status
+ self.info = info
diff --git a/improv/nexus.py b/improv/nexus.py
index b5fe19c4..a90cf40c 100644
--- a/improv/nexus.py
+++ b/improv/nexus.py
@@ -1,4 +1,6 @@
-import os
+from __future__ import annotations
+
+import multiprocessing
import time
import uuid
import signal
@@ -7,46 +9,118 @@
import concurrent
import subprocess
-from queue import Full
from datetime import datetime
-from multiprocessing import Process, get_context
+from multiprocessing import get_context
from importlib import import_module
+import zmq as zmq_sync
import zmq.asyncio as zmq
-from zmq import PUB, REP, SocketOption
-
-from improv.store import StoreInterface, RedisStoreInterface, PlasmaStoreInterface
-from improv.actor import Signal
+from zmq import PUB, REP, REQ, SocketOption
+
+from improv import log
+from improv.broker import bootstrap_broker
+from improv.harvester import bootstrap_harvester
+from improv.log import bootstrap_log_server
+from improv.messaging import (
+ ActorStateMsg,
+ ActorStateReplyMsg,
+ ActorSignalMsg,
+ ActorSignalReplyMsg,
+ NexusSignalMsg,
+ BrokerInfoReplyMsg,
+ BrokerInfoMsg,
+ LogInfoMsg,
+ LogInfoReplyMsg,
+ HarvesterInfoMsg,
+ HarvesterInfoReplyMsg,
+)
+from improv.store import StoreInterface, RedisStoreInterface
+from improv.actor import Signal, Actor, LinkInfo
from improv.config import Config
-from improv.link import Link, MultiLink
+
+ASYNC_DEBUG = False
logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)
+if logger.level == logging.DEBUG:
+ logger.addHandler(logging.StreamHandler())
+
+# TODO: redo docsctrings since things are pretty different now
+
+# TODO: socket setup can fail - need to check it
+
+# TODO: redo how actors register with nexus so that we get actor states
+# earlier
+
+
+class ConfigFileNotProvidedException(Exception):
+ def __init__(self):
+ super().__init__("Config file not provided")
-# TODO: Set up store.notify in async function (?)
+
+class ConfigFileNotValidException(Exception):
+ def __init__(self):
+ super().__init__("Config file not valid")
+
+
+class ActorState:
+ def __init__(
+ self, actor_name, status, nexus_in_port, hostname="localhost", sig_socket=None
+ ):
+ self.actor_name = actor_name
+ self.status = status
+ self.nexus_in_port = nexus_in_port
+ self.hostname = hostname
+ self.sig_socket = None
class Nexus:
"""Main server class for handling objects in improv"""
def __init__(self, name="Server"):
+ self.logger_in_port: int | None = None
+ self.zmq_sync_context: zmq_sync.Context | None = None
+ self.logfile: str | None = None
+ self.p_harvester: multiprocessing.Process | None = None
+ self.p_broker: multiprocessing.Process | None = None
+ self.in_socket: zmq.Socket | None = None
+ self.out_socket: zmq.Socket | None = None
+ self.zmq_context: zmq.Context | None = None
+ self.logger_pub_port: int | None = None
+ self.logger_pull_port: int | None = None
+ self.logger_in_socket: zmq.Socket | None = None
+ self.p_logger: multiprocessing.Process | None = None
+ self.broker_pub_port = None
+ self.broker_sub_port = None
+ self.broker_in_port: int | None = None
+ self.broker_in_socket: zmq_sync.Socket | None = None
+ self.actor_states: dict[str, ActorState | None] = dict()
self.redis_fsync_frequency = None
self.store = None
self.config = None
self.name = name
self.aof_dir = None
self.redis_saving_enabled = False
+ self.allow_setup = False
+ self.outgoing_topics = dict()
+ self.incoming_topics = dict()
+ self.data_queues = {}
+ self.actors = {}
+ self.flags = {}
+ self.processes: list[multiprocessing.Process] = []
def __str__(self):
return self.name
- def createNexus(
+ def create_nexus(
self,
file=None,
- use_watcher=None,
- store_size=10_000_000,
- control_port=0,
- output_port=0,
+ store_size=None,
+ control_port=None,
+ output_port=None,
+ log_server_pub_port=None,
+ log_server_pull_port=None,
+ logfile="global.log",
):
"""Function to initialize class variables based on config file.
@@ -56,7 +130,6 @@ def createNexus(
Args:
file (string): Name of the config file.
- use_watcher (bool): Whether to use watcher for the store.
store_size (int): initial store size
control_port (int): port number for input socket
output_port (int): port number for output socket
@@ -65,92 +138,62 @@ def createNexus(
string: "Shutting down", to notify start() that pollQueues has completed.
"""
+ self.logfile = logfile
+
curr_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"************ new improv server session {curr_dt} ************")
if file is None:
logger.exception("Need a config file!")
- raise Exception # TODO
- else:
- logger.info(f"Loading configuration file {file}:")
- self.loadConfig(file=file)
- with open(file, "r") as f: # write config file to log
- logger.info(f.read())
-
- # set config options loaded from file
- # in Python 3.9, can just merge dictionaries using precedence
- cfg = self.config.settings
- if "use_watcher" not in cfg:
- cfg["use_watcher"] = use_watcher
- if "store_size" not in cfg:
- cfg["store_size"] = store_size
- if "control_port" not in cfg or control_port != 0:
- cfg["control_port"] = control_port
- if "output_port" not in cfg or output_port != 0:
- cfg["output_port"] = output_port
-
- # set up socket in lieu of printing to stdout
- self.zmq_context = zmq.Context()
- self.zmq_context.setsockopt(SocketOption.LINGER, 1)
- self.out_socket = self.zmq_context.socket(PUB)
- self.out_socket.bind("tcp://*:%s" % cfg["output_port"])
- out_port_string = self.out_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
- cfg["output_port"] = int(out_port_string.split(":")[-1])
-
- self.in_socket = self.zmq_context.socket(REP)
- self.in_socket.bind("tcp://*:%s" % cfg["control_port"])
- in_port_string = self.in_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
- cfg["control_port"] = int(in_port_string.split(":")[-1])
-
- self.configure_redis_persistence()
-
- # default size should be system-dependent
- if self.config and self.config.use_plasma():
- self._startStoreInterface(store_size)
- else:
- self._startStoreInterface(store_size)
- logger.info("Redis server started")
-
- self.out_socket.send_string("StoreInterface started")
-
- # connect to store and subscribe to notifications
- logger.info("Create new store object")
- if self.config and self.config.use_plasma():
- self.store = PlasmaStoreInterface(store_loc=self.store_loc)
- else:
- self.store = StoreInterface(server_port_num=self.store_port)
- logger.info(f"Redis server connected on port {self.store_port}")
-
- self.store.subscribe()
+ raise ConfigFileNotProvidedException
+
+ logger.info(f"Loading configuration file {file}:")
+ self.config = Config(config_file=file)
+ self.config.parse_config()
+
+ with open(file, "r") as f: # write config file to log
+ logger.info(f.read())
+
+ logger.debug("Applying CLI parameter configuration overrides")
+ self.apply_cli_config_overrides(
+ store_size=store_size,
+ control_port=control_port,
+ output_port=output_port,
+ logging_port=log_server_pub_port,
+ logging_input_port=log_server_pull_port,
+ )
- # TODO: Better logic/flow for using watcher as an option
- self.p_watch = None
- if cfg["use_watcher"]:
- self.startWatcher()
+ logger.debug("Setting up sockets")
+ self.set_up_sockets()
- # Create dicts for reading config and creating actors
- self.comm_queues = {}
- self.sig_queues = {}
- self.data_queues = {}
- self.actors = {}
- self.flags = {}
- self.processes = []
+ logger.debug("Setting up services")
+ self.start_improv_services(
+ log_server_pub_port=self.config.settings["logging_port"],
+ log_server_pull_port=self.config.settings["logging_input_port"],
+ store_size=self.config.settings["store_size"],
+ )
- self.initConfig()
+ logger.debug("initializing config")
+ self.init_config()
self.flags.update({"quit": False, "run": False, "load": False})
self.allowStart = False
self.stopped = False
- return (cfg["control_port"], cfg["output_port"])
-
- def loadConfig(self, file):
- """Load configuration file.
- file: a YAML configuration file name
- """
- self.config = Config(configFile=file)
+ logger.info(
+ f"control: {self.config.settings['control_port']}, "
+ f"output: {self.config.settings['output_port']}, "
+ f"logging (pub): {self.logger_pub_port}, "
+ f"logging (pull): {self.logger_pull_port}"
+ )
+ return (
+ self.config.settings["control_port"],
+ self.config.settings["output_port"],
+ self.logger_pub_port,
+ self.logger_pull_port,
+ )
- def initConfig(self):
+ def init_config(self):
"""For each connection:
create a Link with a name (purpose), start, and end
Start links to one actor's name, end to the other.
@@ -168,121 +211,50 @@ def initConfig(self):
"""
# TODO load from file or user input, as in dialogue through FrontEnd?
- flag = self.config.createConfig()
+ logger.info("initializing config")
+ flag = self.config.create_config()
if flag == -1:
logger.error(
"An error occurred when loading the configuration file. "
"Please see the log file for more details."
)
+ self.destroy_nexus()
+ raise ConfigFileNotValidException
# create all data links requested from Config config
- self.createConnections()
-
- if self.config.hasGUI:
- # Have to load GUI first (at least with Caiman)
- name = self.config.gui.name
- m = self.config.gui # m is ConfigModule
- # treat GUI uniquely since user communication comes from here
- try:
- visualClass = m.options["visual"]
- # need to instantiate this actor
- visualActor = self.config.actors[visualClass]
- self.createActor(visualClass, visualActor)
- # then add links for visual
- for k, l in {
- key: self.data_queues[key]
- for key in self.data_queues.keys()
- if visualClass in key
- }.items():
- self.assignLink(k, l)
-
- # then give it to our GUI
- self.createActor(name, m)
- self.actors[name].setup(visual=self.actors[visualClass])
-
- self.p_GUI = Process(target=self.actors[name].run, name=name)
- self.p_GUI.daemon = True
- self.p_GUI.start()
-
- except Exception as e:
- logger.error(f"Exception in setting up GUI {name}: {e}")
-
- else:
- # have fake GUI for communications
- q_comm = Link("GUI_comm", "GUI", self.name)
- self.comm_queues.update({q_comm.name: q_comm})
+ self.create_connections()
+
+ # if self.config.hasGUI:
+ # # treat GUI uniquely since user communication comes from here
+ # # Have to load GUI first (at least with Caiman)
+ # name = self.config.gui.name
+ # m = self.config.gui # m is ConfigModule
+ # try:
+ # pass
+ # except Exception as e:
+ # logger.error(f"Exception in setting up GUI {name}: {e}")
# First set up each class/actor
for name, actor in self.config.actors.items():
if name not in self.actors.keys():
# Check for actors being instantiated twice
try:
- self.createActor(name, actor)
+ self.create_actor(name, actor)
logger.info(f"Setting up actor {name}")
except Exception as e:
logger.error(f"Exception in setting up actor {name}: {e}.")
self.quit()
-
- # Second set up each connection b/t actors
- # TODO: error handling for if a user tries to use q_in without defining it
- for name, link in self.data_queues.items():
- self.assignLink(name, link)
-
- if self.config.settings["use_watcher"]:
- watchin = []
- for name in self.config.settings["use_watcher"]:
- watch_link = Link(name + "_watch", name, "Watcher")
- self.assignLink(name + ".watchout", watch_link)
- watchin.append(watch_link)
- self.createWatcher(watchin)
+ raise e
def configure_redis_persistence(self):
# invalid configs: specifying filename and using an ephemeral filename,
# specifying that saving is off but providing either filename option
- aof_dirname = self.config.get_redis_aof_dirname()
- generate_unique_dirname = self.config.generate_ephemeral_aof_dirname()
- redis_saving_enabled = self.config.redis_saving_enabled()
- redis_fsync_frequency = self.config.get_redis_fsync_frequency()
-
- if aof_dirname and generate_unique_dirname:
- logger.error(
- "Cannot both generate a unique dirname and use the one provided."
- )
- raise Exception("Cannot use unique dirname and use the one provided.")
-
- if aof_dirname or generate_unique_dirname or redis_fsync_frequency:
- if redis_saving_enabled is None:
- redis_saving_enabled = True
- elif not redis_saving_enabled:
- logger.error(
- "Invalid configuration. Cannot save to disk with saving disabled."
- )
- raise Exception("Cannot persist to disk with saving disabled.")
-
- self.redis_saving_enabled = redis_saving_enabled
-
- if redis_fsync_frequency and redis_fsync_frequency not in [
- "every_write",
- "every_second",
- "no_schedule",
- ]:
- logger.error("Cannot use unknown fsync frequency ", redis_fsync_frequency)
- raise Exception(
- "Cannot use unknown fsync frequency ", redis_fsync_frequency
- )
-
- if redis_fsync_frequency is None:
- redis_fsync_frequency = "no_schedule"
-
- if redis_fsync_frequency == "every_write":
- self.redis_fsync_frequency = "always"
- elif redis_fsync_frequency == "every_second":
- self.redis_fsync_frequency = "everysec"
- elif redis_fsync_frequency == "no_schedule":
- self.redis_fsync_frequency = "no"
- else:
- logger.error("Unknown fsync frequency ", redis_fsync_frequency)
- raise Exception("Unknown fsync frequency ", redis_fsync_frequency)
+ aof_dirname = self.config.redis_config["aof_dirname"]
+ generate_unique_dirname = self.config.redis_config[
+ "generate_ephemeral_aof_dirname"
+ ]
+ self.redis_saving_enabled = self.config.redis_config["enable_saving"]
+ self.redis_fsync_frequency = self.config.redis_config["fsync_frequency"]
if aof_dirname:
self.aof_dir = aof_dirname
@@ -307,35 +279,34 @@ def configure_redis_persistence(self):
return
- def startNexus(self):
+ def start_nexus(self, *args, **kwargs):
"""
Puts all actors in separate processes and begins polling
to listen to comm queues
"""
for name, m in self.actors.items():
- if "GUI" not in name: # GUI already started
- if "method" in self.config.actors[name].options:
- meth = self.config.actors[name].options["method"]
- logger.info("This actor wants: {}".format(meth))
- ctx = get_context(meth)
- p = ctx.Process(target=m.run, name=name)
+ # if "GUI" not in name: # GUI already started
+ if "method" in self.config.actors[name].options:
+ meth = self.config.actors[name].options["method"]
+ logger.info("This actor wants: {}".format(meth))
+ ctx = get_context(meth)
+ p = ctx.Process(target=m.run, name=name)
+ else:
+ ctx = get_context("fork")
+ p = ctx.Process(target=self.run_actor, name=name, args=(m,))
+ if "daemon" in self.config.actors[name].options:
+ p.daemon = self.config.actors[name].options["daemon"]
+ logger.info("Setting daemon for {}".format(name))
else:
- ctx = get_context("fork")
- p = ctx.Process(target=self.runActor, name=name, args=(m,))
- if "Watcher" not in name:
- if "daemon" in self.config.actors[name].options:
- p.daemon = self.config.actors[name].options["daemon"]
- logger.info("Setting daemon for {}".format(name))
- else:
- p.daemon = True # default behavior
- self.processes.append(p)
+ p.daemon = True # default behavior
+ self.processes.append(p)
self.start()
loop = asyncio.get_event_loop()
+ res = ""
try:
- self.out_socket.send_string("Awaiting input:")
- res = loop.run_until_complete(self.pollQueues())
+ res = loop.run_until_complete(self.poll_queues())
except asyncio.CancelledError:
logger.info("Loop is cancelled")
@@ -344,12 +315,7 @@ def startNexus(self):
except Exception as e:
logger.info(f"Res failed to await: {e}")
- logger.info(f"Current loop: {asyncio.get_event_loop()}")
-
- loop.stop()
- loop.close()
- logger.info("Shutdown loop")
- self.zmq_context.destroy()
+ logger.info("End Nexus")
def start(self):
"""
@@ -363,30 +329,39 @@ def start(self):
logger.info("All processes started")
- def destroyNexus(self):
+ def destroy_nexus(self):
"""Method that calls the internal method
- to kill the process running the store (plasma server)
+ to kill the processes running the store
+ and the message broker
"""
logger.warning("Destroying Nexus")
- self._closeStoreInterface()
-
- if hasattr(self, "store_loc"):
- try:
- os.remove(self.store_loc)
- except FileNotFoundError:
- logger.warning(
- "StoreInterface file {} is already deleted".format(self.store_loc)
- )
- logger.warning("Delete the store at location {0}".format(self.store_loc))
+ self._shutdown_harvester()
+ self._close_store_interface()
- if hasattr(self, "out_socket"):
+ if self.out_socket:
self.out_socket.close(linger=0)
- if hasattr(self, "in_socket"):
+ if self.in_socket:
self.in_socket.close(linger=0)
- if hasattr(self, "zmq_context"):
+ if self.broker_in_socket:
+ self.broker_in_socket.close(linger=0)
+ if self.logger_in_socket:
+ self.logger_in_socket.close(linger=0)
+
+ self._shutdown_broker()
+ self._shutdown_logger()
+
+ for handler in logger.handlers:
+ # need to close this one since we're about to destroy the zmq context
+ if isinstance(handler, log.ZmqLogHandler):
+ handler.close()
+ logger.removeHandler(handler)
+
+ if self.zmq_context:
self.zmq_context.destroy(linger=0)
+ if self.zmq_sync_context:
+ self.zmq_sync_context.destroy(linger=0)
- async def pollQueues(self):
+ async def poll_queues(self):
"""
Listens to links and processes their signals.
@@ -400,239 +375,268 @@ async def pollQueues(self):
string: "Shutting down", Notifies start() that pollQueues has completed.
"""
self.actorStates = dict.fromkeys(self.actors.keys())
- if not self.config.hasGUI:
- # Since Visual is not started, it cannot send a ready signal.
- try:
- del self.actorStates["Visual"]
- except Exception as e:
- logger.info("Visual is not started: {0}".format(e))
- pass
- polling = list(self.comm_queues.values())
- pollingNames = list(self.comm_queues.keys())
- self.tasks = []
- for q in polling:
- self.tasks.append(asyncio.create_task(q.get_async()))
+ self.actor_states = dict.fromkeys(self.actors.keys(), None)
- self.tasks.append(asyncio.create_task(self.remote_input()))
- self.early_exit = False
+ self.tasks = []
+ self.tasks.append(asyncio.create_task(self.process_actor_message()))
# add signal handlers
loop = asyncio.get_event_loop()
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for s in signals:
loop.add_signal_handler(
- s, lambda s=s: self.stop_polling_and_quit(s, polling)
+ s, lambda s=s: asyncio.create_task(self.stop_polling_and_quit(s))
)
+ logger.info("Nexus signal handlers added")
+
while not self.flags["quit"]:
try:
done, pending = await asyncio.wait(
self.tasks, return_when=concurrent.futures.FIRST_COMPLETED
)
+
+ for i, t in enumerate(self.tasks):
+ if i == 0: # this index is the original task that processes input
+ if t in done:
+ self.tasks[i] = asyncio.create_task(
+ self.process_actor_message()
+ )
+
except asyncio.CancelledError:
pass
- # sort through tasks to see where we got input from
- # (so we can choose a handler)
- for i, t in enumerate(self.tasks):
- if i < len(polling):
- if t in done or polling[i].status == "done":
- # catch tasks that complete await wait/gather
- r = polling[i].result
- if r:
- if "GUI" in pollingNames[i]:
- self.processGuiSignal(r, pollingNames[i])
- else:
- self.processActorSignal(r, pollingNames[i])
- self.tasks[i] = asyncio.create_task(polling[i].get_async())
- elif t in done:
- logger.debug("t.result = " + str(t.result()))
- self.tasks[i] = asyncio.create_task(self.remote_input())
-
- if not self.early_exit: # don't run this again if we already have
- self.stop_polling(Signal.quit(), polling)
- logger.warning("Shutting down polling")
return "Shutting Down"
- def stop_polling_and_quit(self, signal, queues):
+ async def stop_polling_and_quit(self, stop_signal):
"""
quit the process and stop polling signals from queues
Args:
- signal (signal): Signal for handling async polling.
- One of: signal.SIGHUP, signal.SIGTERM, signal.SIGINT
- queues (improv.link.AsyncQueue): Comm queues for links.
+ stop_signal (signal): Signal for handling async polling.
+ One of: signal.SIGHUP, signal.SIGTERM, signal.SIGINT
"""
- logger.warn(
- "Shutting down via signal handler due to {}. \
- Steps may be out of order or dirty.".format(
- signal
- )
+ if stop_signal == signal.SIGHUP:
+ sig = "SIGHUP"
+ elif stop_signal == signal.SIGTERM:
+ sig = "SIGTERM"
+ elif stop_signal == signal.SIGINT:
+ sig = "SIGINT"
+ elif stop_signal == Signal.quit():
+ sig = "QUIT"
+
+ logger.warning(
+ f"Shutting down via signal handler due to {sig}. "
+ "Steps may be out of order or dirty."
)
- self.stop_polling(signal, queues)
+ await self.stop_polling()
+ logger.info("Nexus waiting for async tasks to have a chance to send")
+ await self.out_socket.send_string("QUIT")
self.flags["quit"] = True
- self.early_exit = True
self.quit()
- async def remote_input(self):
- msg = await self.in_socket.recv_multipart()
- command = msg[0].decode("utf-8")
- await self.in_socket.send_string("Awaiting input:")
- if command == Signal.quit():
- await self.out_socket.send_string("QUIT")
- self.processGuiSignal([command], "TUI_Nexus")
-
- def processGuiSignal(self, flag, name):
- """Receive flags from the Front End as user input"""
- name = name.split("_")[0]
- if flag:
- logger.info("Received signal from user: " + flag[0])
- if flag[0] == Signal.run():
- logger.info("Begin run!")
- # self.flags['run'] = True
- self.run()
- elif flag[0] == Signal.setup():
- logger.info("Running setup")
- self.setup()
- elif flag[0] == Signal.ready():
- logger.info("GUI ready")
- self.actorStates[name] = flag[0]
- elif flag[0] == Signal.quit():
- logger.warning("Quitting the program!")
- self.flags["quit"] = True
- self.quit()
- elif flag[0] == Signal.load():
- logger.info("Loading Config config from file " + flag[1])
- self.loadConfig(flag[1])
- elif flag[0] == Signal.pause():
- logger.info("Pausing processes")
- # TODO. Also resume, reset
-
- # temporary WiP
- elif flag[0] == Signal.kill():
- # TODO: specify actor to kill
- list(self.processes)[0].kill()
- elif flag[0] == Signal.revive():
- dead = [p for p in list(self.processes) if p.exitcode is not None]
- for pro in dead:
- name = pro.name
- m = self.actors[pro.name]
- actor = self.config.actors[name]
- if "GUI" not in name: # GUI hard to revive independently
- if "method" in actor.options:
- meth = actor.options["method"]
- logger.info("This actor wants: {}".format(meth))
- ctx = get_context(meth)
- p = ctx.Process(target=m.run, name=name)
- else:
- ctx = get_context("fork")
- p = ctx.Process(target=self.runActor, name=name, args=(m,))
- if "Watcher" not in name:
- if "daemon" in actor.options:
- p.daemon = actor.options["daemon"]
- logger.info("Setting daemon for {}".format(name))
- else:
- p.daemon = True
-
- # Setting the stores for each actor to be the same
- # TODO: test if this works for fork -- don't think it does?
- al = [act for act in self.actors.values() if act.name != pro.name]
- m.setStoreInterface(al[0].client)
- m.client = None
- m._getStoreInterface()
-
- self.processes.append(p)
- p.start()
- m.q_sig.put_nowait(Signal.setup())
- # TODO: ensure waiting for ready before run?
- m.q_sig.put_nowait(Signal.run())
-
- self.processes = [p for p in list(self.processes) if p.exitcode is None]
- elif flag[0] == Signal.stop():
- logger.info("Nexus received stop signal")
- self.stop()
- elif flag:
- logger.error("Unknown signal received from Nexus: {}".format(flag))
-
- def processActorSignal(self, sig, name):
- if sig is not None:
- logger.info("Received signal " + str(sig[0]) + " from " + name)
- state_val = self.actorStates.values()
- if not self.stopped and sig[0] == Signal.ready():
- self.actorStates[name.split("_")[0]] = sig[0]
- if all(val == Signal.ready() for val in state_val):
- self.allowStart = True
- # TODO: replace with q_sig to FE/Visual
- logger.info("Allowing start")
-
- elif self.stopped and sig[0] == Signal.stop_success():
- self.actorStates[name.split("_")[0]] = sig[0]
- if all(val == Signal.stop_success() for val in state_val):
- self.allowStart = True # TODO: replace with q_sig to FE/Visual
- self.stoppped = False
- logger.info("All stops were successful. Allowing start.")
-
- def setup(self):
- for q in self.sig_queues.values():
+ def process_actor_state_update(self, msg: ActorStateMsg):
+ actor_state = None
+ if msg.actor_name in self.actor_states.keys():
+ actor_state = self.actor_states[msg.actor_name]
+ if not actor_state:
+ logger.info(
+ f"Received state message from new actor {msg.actor_name}"
+ f" with info: {msg.info}\n"
+ )
+ self.actor_states[msg.actor_name] = ActorState(
+ msg.actor_name, msg.status, msg.nexus_in_port
+ )
+ actor_state = self.actor_states[msg.actor_name]
+ else:
+ logger.info(
+ f"Received state message from actor {msg.actor_name}"
+ f" with info: {msg.info}\n"
+ f"Current state:\n"
+ f"Name: {actor_state.actor_name}\n"
+ f"Status: {actor_state.status}\n"
+ f"Nexus control port: {actor_state.nexus_in_port}\n"
+ )
+ if msg.nexus_in_port != actor_state.nexus_in_port:
+ pass
+ # TODO: this actor's signal socket changed
+
+ actor_state.actor_name = msg.actor_name
+ actor_state.status = msg.status
+ actor_state.nexus_in_port = msg.nexus_in_port
+
+ logger.info(
+ "Updated actor state:\n"
+ f"Name: {actor_state.actor_name}\n"
+ f"Status: {actor_state.status}\n"
+ f"Nexus control port: {actor_state.nexus_in_port}\n"
+ )
+
+ if all(
+ [
+ actor_state is not None and actor_state.status == Signal.ready()
+ for actor_state in self.actor_states.values()
+ ]
+ ):
+ logger.info("All actors ready. Allowing run.")
+ self.allowStart = True
+
+ return True
+
+ async def process_actor_message(self):
+ msg = await self.in_socket.recv_pyobj()
+ if isinstance(msg, ActorStateMsg):
+ if self.process_actor_state_update(msg):
+ await self.in_socket.send_pyobj(
+ ActorStateReplyMsg(
+ msg.actor_name, "OK", "actor state updated successfully"
+ )
+ )
+ else:
+ await self.in_socket.send_pyobj(
+ ActorStateReplyMsg(
+ msg.actor_name, "ERROR", "actor state update failed"
+ )
+ )
+ if (not self.allow_setup) and all(
+ [
+ actor_state is not None and actor_state.status == Signal.waiting()
+ for actor_state in self.actor_states.values()
+ ]
+ ):
+ logger.info("All actors connected to Nexus. Allowing setup.")
+ self.allow_setup = True
+
+ if (not self.allowStart) and all(
+ [
+ actor_state is not None and actor_state.status == Signal.ready()
+ for actor_state in self.actor_states.values()
+ ]
+ ):
+ logger.info("All actors ready. Allowing run.")
+ self.allowStart = True
+
+ elif isinstance(msg, ActorSignalMsg):
+ if msg.signal == Signal.quit():
+ reply_str = ""
+ else:
+ reply_str = "Awaiting input:"
+
+ await self.in_socket.send_pyobj(
+ ActorSignalReplyMsg(
+ msg.actor_name,
+ msg.signal,
+ f"Signal {msg.signal} received.\n" + reply_str,
+ )
+ )
+ await self.process_actor_signal(msg)
+
+ else:
+ logger.warning(
+ f"Received message {msg} of unrecognized type {type(msg)}."
+ "Expected ActorStateMsg or ActorSignalMsg."
+ )
+
+ async def process_actor_signal(self, msg):
+ signal = msg.signal
+ if signal == Signal.setup():
+ logger.info("Running setup")
+ await self.setup()
+ elif signal == Signal.run():
+ logger.info("Begin run!")
+ await self.run()
+ elif signal == Signal.stop():
+ logger.info("Stop run!")
+ await self.stop()
+ elif signal == Signal.ready():
+ pass
+ elif signal == Signal.quit():
+ logger.warning("Quitting the program!")
+ task = asyncio.create_task(self.stop_polling_and_quit(Signal.quit()))
try:
- logger.info("Starting setup: " + str(q))
- q.put_nowait(Signal.setup())
- except Full:
- logger.warning("Signal queue" + q.name + "is full")
+ await task
+ except Exception as e:
+ logger.error(f"Caught exception {e} when trying to quit the program.")
+ else:
+ logger.warning(f"Unknown command {signal} from actor {msg.actor_name}")
- def run(self):
+ async def setup(self):
+ if not self.allow_setup:
+ logger.error(
+ "Not all actors connected to Nexus. Please wait, then try again."
+ )
+ return
+
+ for actor in self.actor_states.values():
+ logger.info("Starting setup: " + str(actor.actor_name))
+ actor.sig_socket = self.zmq_context.socket(REQ)
+ actor.sig_socket.connect(f"tcp://{actor.hostname}:{actor.nexus_in_port}")
+
+ await self.signal_to_actors(Signal.setup())
+
+ async def run(self):
if self.allowStart:
- for q in self.sig_queues.values():
- try:
- q.put_nowait(Signal.run())
- except Full:
- logger.warning("Signal queue" + q.name + "is full")
- # queue full, keep going anyway
- # TODO: add repeat trying as async task
+ await self.signal_to_actors(Signal.run())
else:
logger.error("Not all actors ready yet, please wait and then try again.")
def quit(self):
logger.warning("Killing child processes")
- self.out_socket.send_string("QUIT")
-
- for q in self.sig_queues.values():
- try:
- q.put_nowait(Signal.quit())
- except Full:
- logger.warning("Signal queue {} full, cannot quit".format(q.name))
- except FileNotFoundError:
- logger.warning("Queue {} corrupted.".format(q.name))
-
- if self.config.hasGUI:
- self.processes.append(self.p_GUI)
-
- if self.p_watch:
- self.processes.append(self.p_watch)
for p in self.processes:
p.terminate()
- p.join()
+ p.join(timeout=5)
+ if p.exitcode is None:
+ p.kill()
+ logger.error("Process did not exit in time. Kill signal sent.")
logger.warning("Actors terminated")
- self.destroyNexus()
+ self.destroy_nexus()
- def stop(self):
+ async def stop(self):
logger.warning("Starting stop procedure")
self.allowStart = False
- for q in self.sig_queues.values():
- try:
- q.put_nowait(Signal.stop())
- except Full:
- logger.warning("Signal queue" + q.name + "is full")
+ await self.signal_to_actors(Signal.stop())
+
self.allowStart = True
- def revive(self):
+ async def revive(self):
logger.warning("Starting revive")
- def stop_polling(self, stop_signal, queues):
+ await self.signal_to_actors(Signal.revive())
+
+ async def signal_to_actors(self, signal):
+ """Sends signal to actors safely (with error handling
+ and timeout).
+ """
+ for actor in self.actor_states.values():
+ try:
+ send_str = f"Nexus sending {signal} signal to {actor.actor_name}"
+ logger.info(send_str)
+ await actor.sig_socket.send_pyobj(
+ NexusSignalMsg(actor.actor_name, signal, send_str)
+ )
+ msg_ready = await actor.sig_socket.poll(timeout=1000)
+ if msg_ready == 0:
+ raise TimeoutError
+ await actor.sig_socket.recv_pyobj()
+ except TimeoutError:
+ logger.error(
+ f"Timed out waiting for reply to {signal} message "
+ f"from actor {actor.actor_name}. "
+ f"Closing connection."
+ )
+ actor.sig_socket.close(linger=0)
+ except Exception as e:
+ logger.error(
+ f"Unable to send {signal} message "
+ f"to actor {actor.actor_name}: "
+ f"{e}"
+ )
+
+ async def stop_polling(self):
"""Cancels outstanding tasks and fills their last request.
Puts a string into all active queues, then cancels their
@@ -643,15 +647,39 @@ def stop_polling(self, stop_signal, queues):
stop_signal (improv.actor.Signal): Signal for signal handler.
queues (improv.link.AsyncQueue): Comm queues for links.
"""
- logger.info("Received shutdown order")
+ logger.info("Received shutdown order. Time to stop polling.")
- logger.info(f"Stop signal: {stop_signal}")
shutdown_message = Signal.quit()
- for q in queues:
+ for actor in self.actor_states.values():
try:
- q.put(shutdown_message)
- except Exception:
- logger.info("Unable to send shutdown message to {}.".format(q.name))
+ logger.info(f"Sending quit signal to {actor.actor_name}")
+ await actor.sig_socket.send_pyobj(
+ NexusSignalMsg(
+ actor.actor_name, shutdown_message, "Nexus sending quit signal"
+ )
+ )
+ msg_ready = await actor.sig_socket.poll(timeout=1000)
+ if msg_ready == 0:
+ raise TimeoutError
+ else:
+ logger.info(
+ f"Preparing to receive response from {actor.actor_name}"
+ )
+ rep = await actor.sig_socket.recv_pyobj()
+ logger.info(f"Received reply {rep.info} from {actor.actor_name}")
+ except TimeoutError:
+ logger.info(
+ f"Timed out waiting for reply to quit message "
+ f"from actor {actor.actor_name}. "
+ f"Closing connection."
+ )
+ actor.sig_socket.close(linger=0)
+ except Exception as e:
+ logger.info(
+ f"Unable to send shutdown message "
+ f"to actor {actor.actor_name}: "
+ f"{e}"
+ )
logger.info("Canceling outstanding tasks")
@@ -659,19 +687,14 @@ def stop_polling(self, stop_signal, queues):
logger.info("Polling has stopped.")
- def createStoreInterface(self, name):
+ def create_store_interface(self):
"""Creates StoreInterface"""
- if self.config.use_plasma():
- return PlasmaStoreInterface(name, self.store_loc)
- else:
- return RedisStoreInterface(server_port_num=self.store_port)
-
- def _startStoreInterface(self, size, attempts=20):
- """Start a subprocess that runs the plasma store
- Raises a RuntimeError exception size is undefined
- Raises an Exception if the plasma store doesn't start
+ return RedisStoreInterface(server_port_num=self.store_port)
- #TODO: Generalize this to non-plasma stores
+ def _start_store_interface(self, size, attempts=20):
+ """Start a subprocess that runs the redis store
+ Raises a RuntimeError exception if size is undefined
+ Raises an Exception if the redis store doesn't start
Args:
size: in bytes
@@ -683,64 +706,27 @@ def _startStoreInterface(self, size, attempts=20):
"""
if size is None:
raise RuntimeError("Server size needs to be specified")
- self.use_plasma = False
- if self.config and self.config.use_plasma():
- self.use_plasma = True
- self.store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
- self.p_StoreInterface = subprocess.Popen(
- [
- "plasma_store",
- "-s",
- self.store_loc,
- "-m",
- str(size),
- "-e",
- "hashtable://test",
- ],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
- logger.info("StoreInterface start successful: {}".format(self.store_loc))
- else:
- logger.info("Setting up Redis store.")
- self.store_port = (
- self.config.get_redis_port()
- if self.config and self.config.redis_port_specified()
- else Config.get_default_redis_port()
+
+ logger.info("Setting up Redis store.")
+ self.store_port = self.config.redis_config["port"] if self.config else 6379
+ logger.info("Searching for open port starting at specified port.")
+ for attempt in range(attempts):
+ logger.info(
+ "Attempting to connect to Redis on port {}".format(self.store_port)
)
- if self.config and self.config.redis_port_specified():
- logger.info(
- "Attempting to connect to Redis on port {}".format(self.store_port)
- )
- # try with failure, incrementing port number
- self.p_StoreInterface = self.start_redis(size)
- time.sleep(3)
- if self.p_StoreInterface.poll():
- logger.error("Could not start Redis on specified port number.")
- raise Exception("Could not start Redis on specified port.")
+ # try with failure, incrementing port number
+ self.p_StoreInterface = self.start_redis(size)
+ time.sleep(1)
+ if self.p_StoreInterface.poll(): # Redis could not start
+ logger.info("Could not connect to port {}".format(self.store_port))
+ self.store_port = str(int(self.store_port) + 1)
else:
- logger.info("Redis port not specified. Searching for open port.")
- for attempt in range(attempts):
- logger.info(
- "Attempting to connect to Redis on port {}".format(
- self.store_port
- )
- )
- # try with failure, incrementing port number
- self.p_StoreInterface = self.start_redis(size)
- time.sleep(3)
- if self.p_StoreInterface.poll(): # Redis could not start
- logger.info(
- "Could not connect to port {}".format(self.store_port)
- )
- self.store_port = str(int(self.store_port) + 1)
- else:
- break
- else:
- logger.error("Could not start Redis on any tried port.")
- raise Exception("Could not start Redis on any tried ports.")
+ break
+ else:
+ logger.error("Could not start Redis on any tried port.")
+ raise Exception("Could not start Redis on any tried ports.")
- logger.info(f"StoreInterface start successful on port {self.store_port}")
+ logger.info(f"StoreInterface start successful on port {self.store_port}")
def start_redis(self, size):
subprocess_command = [
@@ -788,26 +774,27 @@ def start_redis(self, size):
stderr=subprocess.DEVNULL,
)
- def _closeStoreInterface(self):
+ def _close_store_interface(self):
"""Internal method to kill the subprocess
- running the store (plasma sever)
+ running the store
"""
if hasattr(self, "p_StoreInterface"):
try:
self.p_StoreInterface.send_signal(signal.SIGINT)
- self.p_StoreInterface.wait()
- logger.info(
- "StoreInterface close successful: {}".format(
- self.store_loc
- if self.config and self.config.use_plasma()
- else self.store_port
+ try:
+ self.p_StoreInterface.wait(timeout=30)
+ logger.info(
+ "StoreInterface close successful: {}".format(self.store_port)
)
- )
+ except subprocess.TimeoutExpired as e:
+ logger.error(e)
+ self.p_StoreInterface.send_signal(signal.SIGKILL)
+ logger.info("Killed datastore process")
except Exception as e:
logger.exception("Cannot close store {}".format(e))
- def createActor(self, name, actor):
+ def create_actor(self, name, actor):
"""Function to instantiate actor, add signal and comm Links,
and update self.actors dictionary
@@ -818,35 +805,32 @@ def createActor(self, name, actor):
# Instantiate selected class
mod = import_module(actor.packagename)
clss = getattr(mod, actor.classname)
- if self.config.use_plasma():
- instance = clss(actor.name, store_loc=self.store_loc, **actor.options)
- else:
- instance = clss(actor.name, store_port_num=self.store_port, **actor.options)
-
- if "method" in actor.options.keys():
- # check for spawn
- if "fork" == actor.options["method"]:
- # Add link to StoreInterface store
- store = self.createStoreInterface(actor.name)
- instance.setStoreInterface(store)
- else:
- # spawn or forkserver; can't pickle plasma store
- logger.info("No store for this actor yet {}".format(name))
- else:
- # Add link to StoreInterface store
- store = self.createStoreInterface(actor.name)
- instance.setStoreInterface(store)
-
- q_comm = Link(actor.name + "_comm", actor.name, self.name)
- q_sig = Link(actor.name + "_sig", self.name, actor.name)
- self.comm_queues.update({q_comm.name: q_comm})
- self.sig_queues.update({q_sig.name: q_sig})
- instance.setCommLinks(q_comm, q_sig)
+ outgoing_links = (
+ self.outgoing_topics[actor.name]
+ if actor.name in self.outgoing_topics
+ else []
+ )
+ incoming_links = (
+ self.incoming_topics[actor.name]
+ if actor.name in self.incoming_topics
+ else []
+ )
+ instance = clss(
+ name=actor.name,
+ nexus_comm_port=self.config.settings["control_port"],
+ broker_sub_port=self.broker_sub_port,
+ broker_pub_port=self.broker_pub_port,
+ log_pull_port=self.logger_pull_port,
+ outgoing_links=outgoing_links,
+ incoming_links=incoming_links,
+ store_port_num=self.store_port,
+ **actor.options,
+ )
# Update information
self.actors.update({name: instance})
- def runActor(self, actor):
+ def run_actor(self, actor: Actor):
"""Run the actor continually; used for separate processes
#TODO: hook into monitoring here?
@@ -855,55 +839,287 @@ def runActor(self, actor):
"""
actor.run()
- def createConnections(self):
- """Assemble links (multi or other)
- for later assignment
+ def create_connections(self):
+ for name, connection in self.config.connections.items():
+ sources = connection["sources"]
+ if not isinstance(sources, list):
+ sources = [sources]
+ sinks = connection["sinks"]
+ if not isinstance(sinks, list):
+ sinks = [sinks]
+
+ name_input = tuple(
+ ["inputs:"]
+ + [name for name in sources]
+ + ["outputs:"]
+ + [name for name in sinks]
+ )
+ name = str(
+ hash(name_input)
+ ) # space-efficient key for uniquely referring to this connection
+ logger.info(f"Created link name {name} for link {name_input}")
+
+ for source in sources:
+ source_actor = source.split(".")[0]
+ source_link = source.split(".")[1]
+ if source_actor not in self.outgoing_topics.keys():
+ self.outgoing_topics[source_actor] = [LinkInfo(source_link, name)]
+ else:
+ self.outgoing_topics[source_actor].append(
+ LinkInfo(source_link, name)
+ )
+
+ for sink in sinks:
+ sink_actor = sink.split(".")[0]
+ sink_link = sink.split(".")[1]
+ if sink_actor not in self.incoming_topics.keys():
+ self.incoming_topics[sink_actor] = [LinkInfo(sink_link, name)]
+ else:
+ self.incoming_topics[sink_actor].append(LinkInfo(sink_link, name))
+
+ def start_logger(self, log_server_pub_port, log_server_pull_port):
+ spawn_context = get_context("spawn")
+ self.p_logger = spawn_context.Process(
+ target=bootstrap_log_server,
+ args=(
+ "localhost",
+ self.logger_in_port,
+ self.logfile,
+ log_server_pub_port,
+ log_server_pull_port,
+ ),
+ )
+ logger.debug("logger created")
+ self.p_logger.start()
+ time.sleep(1)
+ logger.debug("logger started")
+ if not self.p_logger.is_alive():
+ logger.error(
+ "Logger process failed to start. "
+ "Please see the log server log file for more information. "
+ "The improv server will now exit."
+ )
+ self.quit()
+ raise Exception("Could not start log server.")
+ logger.debug("logger is alive")
+ poll_res = self.logger_in_socket.poll(timeout=5000)
+ if poll_res == 0:
+ logger.error(
+ "Never got reply from logger. Cannot proceed setting up Nexus."
+ )
+ try:
+ with open("log_server.log", "r") as file:
+ logger.debug(file.read())
+ except Exception as e:
+ logger.error(e)
+ self.destroy_nexus()
+ logger.error("exiting after destroy")
+ exit(1)
+ logger_info: LogInfoMsg = self.logger_in_socket.recv_pyobj()
+ self.logger_pull_port = logger_info.pull_port
+ self.logger_pub_port = logger_info.pub_port
+ self.logger_in_socket.send_pyobj(
+ LogInfoReplyMsg(logger_info.name, "OK", "registered logger information")
+ )
+ logger.debug("logger replied with setup message")
+
+ def start_message_broker(self):
+ spawn_context = get_context("spawn")
+ self.p_broker = spawn_context.Process(
+ target=bootstrap_broker, args=("localhost", self.broker_in_port)
+ )
+ logger.debug("broker created")
+ self.p_broker.start()
+ time.sleep(1)
+ logger.debug("broker started")
+ if not self.p_broker.is_alive():
+ logger.error(
+ "Broker process failed to start. "
+ "Please see the log file for more information. "
+ "The improv server will now exit."
+ )
+ self.quit()
+ raise Exception("Could not start message broker server.")
+ logger.debug("broker is alive")
+ poll_res = self.broker_in_socket.poll(timeout=5000)
+ if poll_res == 0:
+ logger.error("Never got reply from broker. Cannot proceed.")
+ try:
+ with open("broker_server.log", "r") as file:
+ logger.debug(file.read())
+ except Exception as e:
+ logger.error(e)
+ self.destroy_nexus()
+ logger.debug("exiting after destroy")
+ exit(1)
+ broker_info: BrokerInfoMsg = self.broker_in_socket.recv_pyobj()
+ self.broker_sub_port = broker_info.sub_port
+ self.broker_pub_port = broker_info.pub_port
+ self.broker_in_socket.send_pyobj(
+ BrokerInfoReplyMsg(broker_info.name, "OK", "registered broker information")
+ )
+ logger.debug("broker replied with setup message")
+
+ def _shutdown_broker(self):
+ """Internal method to kill the subprocess
+ running the message broker
"""
- for source, drain in self.config.connections.items():
- name = source.split(".")[0]
- # current assumption is connection goes from q_out to something(s) else
- if len(drain) > 1: # we need multiasyncqueue
- link, endLinks = MultiLink(name + "_multi", source, drain)
- self.data_queues.update({source: link})
- for i, e in enumerate(endLinks):
- self.data_queues.update({drain[i]: e})
- else: # single input, single output
- d = drain[0]
- d_name = d.split(".") # TODO: check if .anything, if not assume q_in
- link = Link(name + "_" + d_name[0], source, d)
- self.data_queues.update({source: link})
- self.data_queues.update({d: link})
-
- def assignLink(self, name, link):
- """Function to set up Links between actors
- for data location passing
- Actor must already be instantiated
-
- #NOTE: Could use this for reassigning links if actors crash?
-
- #TODO: Adjust to use default q_out and q_in vs being specified
+ if self.p_broker:
+ try:
+ self.p_broker.terminate()
+ self.p_broker.join(timeout=5)
+ if self.p_broker.exitcode is None:
+ self.p_broker.kill()
+ logger.error("Killed broker process")
+ else:
+ logger.info(
+ "Broker shutdown successful with exit code {}".format(
+ self.p_broker.exitcode
+ )
+ )
+ except Exception as e:
+ logger.exception(f"Unable to close broker {e}")
+
+ def _shutdown_logger(self):
+ """Internal method to kill the subprocess
+ running the logger
"""
- classname = name.split(".")[0]
- linktype = name.split(".")[1]
- if linktype == "q_out":
- self.actors[classname].setLinkOut(link)
- elif linktype == "q_in":
- self.actors[classname].setLinkIn(link)
- elif linktype == "watchout":
- self.actors[classname].setLinkWatch(link)
- else:
- self.actors[classname].addLink(linktype, link)
+ if self.p_logger:
+ try:
+ self.p_logger.terminate()
+ self.p_logger.join(timeout=5)
+ if self.p_logger.exitcode is None:
+ self.p_logger.kill()
+ logger.error("Killed logger process")
+ else:
+ logger.info("Logger shutdown successful")
+ except Exception as e:
+ logger.exception(f"Unable to close logger: {e}")
- # TODO: StoreInterface access here seems wrong, need to test
- def startWatcher(self):
- from improv.watcher import Watcher
+ def _shutdown_harvester(self):
+ """Internal method to kill the subprocess
+ running the logger
+ """
+ if self.p_harvester:
+ try:
+ self.p_harvester.terminate()
+ self.p_harvester.join(timeout=5)
+ if self.p_harvester.exitcode is None:
+ self.p_harvester.kill()
+ logger.error("Killed harvester process")
+ else:
+ logger.info("Harvester shutdown successful")
+ except Exception as e:
+ logger.exception(f"Unable to close harvester: {e}")
- self.watcher = Watcher("watcher", self.createStoreInterface("watcher"))
- q_sig = Link("watcher_sig", self.name, "watcher")
- self.watcher.setLinks(q_sig)
- self.sig_queues.update({q_sig.name: q_sig})
+ def set_up_sockets(self):
+ logger.debug("Connecting to output")
+ cfg = self.config.settings # this could be self.settings instead
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ self.out_socket = self.zmq_context.socket(PUB)
+ self.out_socket.bind("tcp://*:%s" % cfg["output_port"])
+ out_port_string = self.out_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ cfg["output_port"] = int(out_port_string.split(":")[-1])
+
+ logger.debug("Connecting to control")
+ self.in_socket = self.zmq_context.socket(REP)
+ self.in_socket.bind("tcp://*:%s" % cfg["control_port"])
+ in_port_string = self.in_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ cfg["control_port"] = int(in_port_string.split(":")[-1])
+
+ logger.debug("Setting up sync server startup socket")
+ self.zmq_sync_context = zmq_sync.Context()
+ self.zmq_sync_context.setsockopt(SocketOption.LINGER, 0)
+
+ self.logger_in_socket = self.zmq_sync_context.socket(REP)
+ self.logger_in_socket.bind("tcp://*:0")
+ logger_in_port_string = self.logger_in_socket.getsockopt_string(
+ SocketOption.LAST_ENDPOINT
+ )
+ self.logger_in_port = int(logger_in_port_string.split(":")[-1])
+
+ self.broker_in_socket = self.zmq_sync_context.socket(REP)
+ self.broker_in_socket.bind("tcp://*:0")
+ broker_in_port_string = self.broker_in_socket.getsockopt_string(
+ SocketOption.LAST_ENDPOINT
+ )
+ self.broker_in_port = int(broker_in_port_string.split(":")[-1])
+
+ def start_improv_services(
+ self, log_server_pub_port, log_server_pull_port, store_size
+ ):
+ logger.debug("Starting logger")
+ self.start_logger(log_server_pub_port, log_server_pull_port)
+ logger.addHandler(
+ log.ZmqLogHandler("localhost", self.logger_pull_port, self.zmq_sync_context)
+ )
+
+ logger.debug("starting broker")
+ self.start_message_broker()
- self.p_watch = Process(target=self.watcher.run, name="watcher_process")
- self.p_watch.daemon = True
- self.p_watch.start()
- self.processes.append(self.p_watch)
+ logger.debug("Parsing redis persistence")
+ self.configure_redis_persistence()
+
+ logger.debug("Starting redis server")
+ # default size should be system-dependent
+ self._start_store_interface(store_size)
+ logger.info("Redis server started")
+
+ if self.config.settings["harvest_data_from_memory"]:
+ logger.debug("starting harvester")
+ self.start_harvester()
+
+ # connect to store and subscribe to notifications
+ logger.info("Create new store object")
+ self.store = StoreInterface(server_port_num=self.store_port)
+ logger.info(f"Redis server connected on port {self.store_port}")
+ logger.info("all services started")
+
+ def apply_cli_config_overrides(
+ self, store_size, control_port, output_port, logging_port, logging_input_port
+ ):
+ if store_size is not None:
+ self.config.settings["store_size"] = store_size
+ if control_port is not None:
+ self.config.settings["control_port"] = control_port
+ if output_port is not None:
+ self.config.settings["output_port"] = output_port
+ if logging_port is not None:
+ self.config.settings["logging_port"] = logging_port
+ if output_port is not None:
+ self.config.settings["logging_input_port"] = logging_input_port
+
+ def start_harvester(self):
+ spawn_context = get_context("spawn")
+ self.p_harvester = spawn_context.Process(
+ target=bootstrap_harvester,
+ args=(
+ "localhost",
+ self.broker_in_port,
+ "localhost",
+ self.store_port,
+ "localhost",
+ self.broker_pub_port,
+ "localhost",
+ self.logger_pull_port,
+ ),
+ )
+ self.p_harvester.start()
+ time.sleep(1)
+ if not self.p_harvester.is_alive():
+ logger.error(
+ "Harvester process failed to start. "
+ "Please see the log file for more information. "
+ "The improv server will now exit."
+ )
+ self.quit()
+ raise Exception("Could not start harvester server.")
+
+ harvester_info: HarvesterInfoMsg = self.broker_in_socket.recv_pyobj()
+ self.broker_in_socket.send_pyobj(
+ HarvesterInfoReplyMsg(
+ harvester_info.name, "OK", "registered harvester information"
+ )
+ )
+ logger.info("Harvester server started")
diff --git a/improv/store.py b/improv/store.py
index 98f54a63..ef4868ef 100644
--- a/improv/store.py
+++ b/improv/store.py
@@ -3,27 +3,20 @@
import pickle
import logging
-import traceback
-
-import numpy as np
-import pyarrow.plasma as plasma
+import zlib
from redis import Redis
from redis.retry import Retry
from redis.backoff import ConstantBackoff
from redis.exceptions import BusyLoadingError, ConnectionError, TimeoutError
-from scipy.sparse import csc_matrix
-from pyarrow.lib import ArrowIOError
-from pyarrow._plasma import PlasmaObjectExists, ObjectNotAvailable
-
-REDIS_GLOBAL_TOPIC = "global_topic"
+ZLIB_COMPRESSION_LEVEL = -1
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-class StoreInterface:
+class AbstractStoreInterface:
"""General interface for a store"""
def get(self):
@@ -42,18 +35,22 @@ def subscribe(self):
raise NotImplementedError
-class RedisStoreInterface(StoreInterface):
- def __init__(self, name="default", server_port_num=6379, hostname="localhost"):
+class RedisStoreInterface(AbstractStoreInterface):
+ def __init__(
+ self,
+ name="default",
+ server_port_num=6379,
+ hostname="localhost",
+ compression_level=ZLIB_COMPRESSION_LEVEL,
+ ):
self.name = name
self.server_port_num = server_port_num
self.hostname = hostname
self.client = self.connect_to_server()
+ self.compression_level = compression_level
def connect_to_server(self):
- # TODO this should scan for available ports, but only if configured to do so.
- # This happens when the config doesn't have Redis settings,
- # so we need to communicate this somehow to the StoreInterface here.
- """Connect to the store at store_loc, max 20 retries to connect
+ """Connect to the store, max 20 retries to connect
Raises exception if can't connect
Returns the Redis client if successful
@@ -107,18 +104,11 @@ def put(self, object):
object: the object that was a
"""
object_key = str(os.getpid()) + str(uuid.uuid4())
- try:
- # buffers would theoretically go here if we need to force out-of-band
- # serialization for single objects
- # TODO this will actually just silently fail if we use an existing
- # TODO key; not sure it's worth the network overhead to check every
- # TODO key twice every time. we still need a better solution for
- # TODO this, but it will work now singlethreaded most of the time.
-
- self.client.set(object_key, pickle.dumps(object, protocol=5), nx=True)
- except Exception:
- logger.error("Could not store object {}".format(object_key))
- logger.error(traceback.format_exc())
+ data = zlib.compress(
+ pickle.dumps(object, protocol=5), level=self.compression_level
+ )
+
+ self.client.set(object_key, data, nx=True)
return object_key
@@ -138,14 +128,10 @@ def get(self, object_key):
object_value = self.client.get(object_key)
if object_value:
# buffers would also go here to force out-of-band deserialization
- return pickle.loads(object_value)
+ return pickle.loads(zlib.decompress(object_value))
logger.warning("Object {} cannot be found.".format(object_key))
- raise ObjectNotFoundError
-
- def subscribe(self, topic=REDIS_GLOBAL_TOPIC):
- p = self.client.pubsub()
- p.subscribe(topic)
+ raise ObjectNotFoundError(object_key)
def get_list(self, ids):
"""Get multiple objects from the store
@@ -156,7 +142,10 @@ def get_list(self, ids):
Returns:
list of the objects
"""
- return self.client.mget(ids)
+ return [
+ pickle.loads(zlib.decompress(object_value))
+ for object_value in self.client.mget(ids)
+ ]
def get_all(self):
"""Get a listing of all objects in the store.
@@ -166,7 +155,7 @@ def get_all(self):
list of all the objects in the store
"""
all_keys = self.client.keys() # defaults to "*" pattern, so will fetch all
- return self.client.mget(all_keys)
+ return self.get_list(all_keys)
def reset(self):
"""Reset client connection"""
@@ -175,226 +164,6 @@ def reset(self):
"Reset local connection to store on port: {0}".format(self.server_port_num)
)
- def notify(self):
- pass # I don't see any call sites for this, so leaving it blank at the moment
-
-
-class PlasmaStoreInterface(StoreInterface):
- """Basic interface for our specific data store implemented with apache arrow plasma
- Objects are stored with object_ids
- References to objects are contained in a dict where key is shortname,
- value is object_id
- """
-
- def __init__(self, name="default", store_loc="/tmp/store"):
- """
- Constructor for the StoreInterface
-
- :param name:
- :param store_loc: Apache Arrow Plasma client location
- """
-
- self.name = name
- self.store_loc = store_loc
- self.client = self.connect_store(store_loc)
- self.stored = {}
-
- def connect_store(self, store_loc):
- """Connect to the store at store_loc, max 20 retries to connect
- Raises exception if can't connect
- Returns the plasmaclient if successful
- Updates the client internal
-
- Args:
- store_loc: store location
- """
- try:
- self.client = plasma.connect(store_loc, 20)
- logger.info("Successfully connected to store: {} ".format(store_loc))
- except Exception:
- logger.exception("Cannot connect to store: {}".format(store_loc))
- raise CannotConnectToStoreInterfaceError(store_loc)
- return self.client
-
- def put(self, object, object_name):
- """
- Put a single object referenced by its string name
- into the store
- Raises PlasmaObjectExists if we are overwriting
- Unknown error
-
- Args:
- object:
- object_name (str):
- flush_this_immediately (bool):
-
- Returns:
- class 'plasma.ObjectID': Plasma object ID
-
- Raises:
- PlasmaObjectExists: if we are overwriting \
- unknown error
- """
- object_id = None
- try:
- # Need to pickle if object is csc_matrix
- if isinstance(object, csc_matrix):
- prot = pickle.HIGHEST_PROTOCOL
- object_id = self.client.put(pickle.dumps(object, protocol=prot))
- else:
- object_id = self.client.put(object)
-
- except PlasmaObjectExists:
- logger.error("Object already exists. Meant to call replace?")
- except ArrowIOError:
- logger.error("Could not store object {}".format(object_name))
- logger.info("Refreshing connection and continuing")
- self.reset()
- except Exception:
- logger.error("Could not store object {}".format(object_name))
- logger.error(traceback.format_exc())
-
- return object_id
-
- def get(self, object_name):
- """Get a single object from the store by object name
- Checks to see if it knows the object first
- Otherwise throw CannotGetObject to request dict update
- TODO: update for lists of objects
- TODO: replace with getID
-
- Returns:
- Stored object
- """
- # print('trying to get ', object_name)
- # if self.stored.get(object_name) is None:
- # logger.error('Never recorded storing this object: '+object_name)
- # # Don't know anything about this object, treat as problematic
- # raise CannotGetObjectError(query = object_name)
- # else:
- return self.getID(object_name)
-
- def getID(self, obj_id):
- """
- Get object by object ID
-
- Args:
- obj_id (class 'plasma.ObjectID'): the id of the object
-
- Returns:
- Stored object
-
- Raises:
- ObjectNotFoundError: If the id is not found
- """
- res = self.client.get(obj_id, 0) # Timeout = 0 ms
- if res is not plasma.ObjectNotAvailable:
- return res if not isinstance(res, bytes) else pickle.loads(res)
-
- logger.warning("Object {} cannot be found.".format(obj_id))
- raise ObjectNotFoundError
-
- def getList(self, ids):
- """Get multiple objects from the store
-
- Args:
- ids (list): of type plasma.ObjectID
-
- Returns:
- list of the objects
- """
- # self._get()
- return self.client.get(ids)
-
- def get_all(self):
- """Get a listing of all objects in the store
-
- Returns:
- list of all the objects in the store
- """
- return self.client.list()
-
- def reset(self):
- """Reset client connection"""
- self.client = self.connect_store(self.store_loc)
- logger.debug("Reset local connection to store: {0}".format(self.store_loc))
-
- def release(self):
- self.client.disconnect()
-
- # Subscribe to notifications about sealed objects?
- def subscribe(self):
- """Subscribe to a section? of the ds for singals
-
- Raises:
- Exception: Unknown error
- """
- try:
- self.client.subscribe()
- except Exception as e:
- logger.error("Unknown error: {}".format(e))
- raise Exception
-
- def notify(self):
- try:
- notification_info = self.client.get_next_notification()
- # recv_objid, recv_dsize, recv_msize = notification_info
- except ArrowIOError:
- notification_info = None
- except Exception as e:
- logger.exception("Notification error: {}".format(e))
- raise Exception
-
- return notification_info
-
- # Necessary? plasma.ObjectID.from_random()
- def random_ObjectID(self, number=1):
- ids = []
- for i in range(number):
- ids.append(plasma.ObjectID(np.random.bytes(20)))
- return ids
-
- def updateStoreInterfaced(self, object_name, object_id):
- """Update local dict with info we need locally
- Report to Nexus that we updated the store
- (did a put or delete/replace)
-
- Args:
- object_name (str): the name of the object to update
- object_id (): the id of the object to update
- """
- self.stored.update({object_name: object_id})
-
- def getStored(self):
- """
- Returns:
- its info about what it has stored
- """
- return self.stored
-
- def _put(self, obj, id):
- """Internal put"""
- return self.client.put(obj, id)
-
- def _get(self, object_name):
- """Get an object from the store using its name
- Assumes we know the id for the object_name
-
- Raises:
- ObjectNotFound: if object_id returns no object from the store
- """
- # Most errors not shown to user.
- # Maintain separation between external and internal function calls.
- res = self.getID(self.stored.get(object_name))
- # Can also use contains() to check
-
- logger.warning("{}".format(object_name))
- if isinstance(res, ObjectNotAvailable):
- logger.warning("Object {} cannot be found.".format(object_name))
- raise ObjectNotFoundError(obj_id_or_name=object_name) # TODO: Don't raise?
- else:
- return res
-
StoreInterface = RedisStoreInterface
@@ -428,12 +197,12 @@ def __str__(self):
class CannotConnectToStoreInterfaceError(Exception):
"""Raised when failing to connect to store."""
- def __init__(self, store_loc):
+ def __init__(self, store_port):
super().__init__()
self.name = "CannotConnectToStoreInterfaceError"
- self.message = "Cannot connect to store at {}".format(str(store_loc))
+ self.message = "Cannot connect to store at {}".format(str(store_port))
def __str__(self):
return self.message
diff --git a/improv/tui.py b/improv/tui.py
index 45481432..60ab72b9 100644
--- a/improv/tui.py
+++ b/improv/tui.py
@@ -17,9 +17,12 @@
from textual.message import Message
import logging
from zmq.log.handlers import PUBHandler
+from improv.messaging import ActorSignalMsg, ActorSignalReplyMsg
+from improv.log import ZmqLogHandler
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
+logger.addHandler(logging.FileHandler("tui.log"))
class SocketLog(TextLog):
@@ -56,16 +59,19 @@ def _simple_formatter(parts):
async def poll(self):
try:
- ready = await self.socket.poll(10)
- if ready:
- parts = await self.socket.recv_multipart()
- msg_type = parts[0].decode("utf-8")
- if msg_type != "DEBUG" or self.print_debug:
- msg = self.format(parts)
- self.write(msg)
- self.post_message(self.Echo(self, msg))
- except asyncio.CancelledError:
- pass
+ if not self.socket.closed:
+ ready = await self.socket.poll(10)
+ if ready:
+ parts = await self.socket.recv_multipart()
+ msg_type = parts[0].decode("utf-8")
+ if msg_type != "DEBUG" or self.print_debug:
+ msg = self.format(parts)
+ self.write(msg)
+ self.post_message(self.Echo(self, msg))
+ except asyncio.CancelledError as e:
+ if not self.socket.closed:
+ self.socket.close(linger=10)
+ raise e
async def on_mount(self) -> None:
"""Event handler called when widget is added to the app."""
@@ -93,9 +99,9 @@ async def on_key(self, event) -> None:
if event.key == "enter":
event.stop()
- def on_button_pressed(self, event: Button.Pressed) -> None:
+ async def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "quit":
- self.app.exit()
+ await self.app.clean_up_and_exit()
else:
self.app.pop_screen()
@@ -128,18 +134,37 @@ class TUI(App, inherit_bindings=False):
View class for the text user interface. Implemented as a Textual app.
"""
- def __init__(self, control_port, output_port, logging_port):
+ def __init__(
+ self,
+ control_port,
+ output_port,
+ logging_pub_port,
+ logging_pull_port,
+ log_host="localhost",
+ testing=False,
+ ):
super().__init__()
self.title = "improv console"
- self.control_port = TUI._sanitize_addr(control_port)
- self.output_port = TUI._sanitize_addr(output_port)
- self.logging_port = TUI._sanitize_addr(logging_port)
+ self.log_host = log_host
+ self.control_port = TUI._sanitize_addr(control_port, log_host)
+ self.output_port = TUI._sanitize_addr(output_port, log_host)
+ self.logging_pub_port = TUI._sanitize_addr(logging_pub_port, log_host)
+ self.logging_pull_port = logging_pull_port
+ self.testing = testing
self.context = zmq.Context()
self.control_socket = self.context.socket(REQ)
- self.control_socket.connect("tcp://%s" % self.control_port)
+ self.control_socket.connect(f"tcp://{self.control_port}")
+
+ self.logger = logging.getLogger(self.name)
+ self.logger.setLevel(logging.INFO)
+ for handler in logger.handlers:
+ self.logger.addHandler(handler)
+ self.logger.addHandler(
+ ZmqLogHandler(self.log_host, self.logging_pull_port, self.context)
+ )
- logger.info("Text interface initialized")
+ self.logger.info("Text interface initialized")
CSS_PATH = "tui.css"
BINDINGS = [
@@ -154,13 +179,13 @@ def action_set_debug(self):
log_window.print_debug = not log_window.print_debug
@staticmethod
- def _sanitize_addr(input):
- if isinstance(input, int):
- return "localhost:%s" % str(input)
- elif ":" in input:
+ def _sanitize_addr(input, host=None):
+ if ":" in str(input):
return input
+ elif isinstance(input, int) and host is not None:
+ return f"{host}:{input}"
else:
- return "localhost:%s" % input
+ return f"localhost:{input}"
@staticmethod
def format_log_messages(parts):
@@ -186,7 +211,7 @@ def compose(self) -> ComposeResult:
Header("improv console"),
Label("[white]Log Messages[/]"),
SocketLog(
- self.logging_port,
+ self.logging_pub_port,
self.context,
formatter=self.format_log_messages,
markup=True,
@@ -211,43 +236,49 @@ async def send_to_controller(self, msg):
retries_left = REQUEST_RETRIES
try:
- logger.info(f"Sending {msg} to controller.")
- await self.control_socket.send_string(msg)
+ self.logger.info(f"TUI Sending {msg} to controller.")
+ msg_obj = ActorSignalMsg(
+ actor_name="TUI", signal=msg, info="Input from TUI"
+ )
+ await self.control_socket.send_pyobj(msg_obj)
reply = None
while True:
ready = await self.control_socket.poll(REQUEST_TIMEOUT)
if ready:
- reply = await self.control_socket.recv_multipart()
- reply = reply[0].decode("utf-8")
- logger.info(f"Received {reply} from controller.")
+ reply = await self.control_socket.recv_pyobj()
+ self.logger.info(f"TUI Received '{reply.info}' from controller.")
break
else:
retries_left -= 1
- logger.warning("No response from server.")
+ self.logger.warning("No response to TUI from server.")
# try to close and reconnect
self.control_socket.setsockopt(LINGER, 0)
self.control_socket.close()
if retries_left == 0:
- logger.error("Server seems to be offline. Giving up.")
+ self.logger.error("Server seems to be offline. Giving up.")
break
- logger.info("Attempting to reconnect to server...")
+ self.logger.info("TUI attempting to reconnect to server...")
self.control_socket = self.context.socket(REQ)
self.control_socket.connect("tcp://%s" % self.control_port)
- logger.info(f"Resending {msg} to controller.")
- await self.control_socket.send_string(msg)
+ self.logger.info(f"TUI resending {msg} to controller.")
+ await self.control_socket.send_pyobj(msg_obj)
- except asyncio.CancelledError:
+ finally:
pass
- return reply
+ if reply is not None:
+ return reply.info
async def on_mount(self):
+ if not self.testing:
+ reply = await self.send_to_controller("ready")
+ self.query_one("#console").write(reply)
self.set_focus(self.query_one(Input))
async def on_input_submitted(self, message):
@@ -255,11 +286,13 @@ async def on_input_submitted(self, message):
self.query_one("#console").write(message.value)
reply = await self.send_to_controller(message.value)
self.query_one("#console").write(reply)
+ if reply and "QUIT" in reply:
+ await self.clean_up_and_exit()
async def on_socket_log_echo(self, message):
- if message.sender.id == "console" and message.value == "QUIT":
- logger.info("Got QUIT; will try to exit")
- self.exit()
+ if message.sender.id == "console" and "QUIT" in message.value:
+ self.logger.info("TUI got QUIT; will try to exit")
+ await self.clean_up_and_exit()
def action_request_quit(self):
self.push_screen(QuitScreen())
@@ -267,11 +300,15 @@ def action_request_quit(self):
def action_help(self):
self.push_screen(HelpScreen())
+ async def clean_up_and_exit(self):
+ self.exit()
+
if __name__ == "__main__":
CONTROL_PORT = "5555"
OUTPUT_PORT = "5556"
LOGGING_PORT = "5557"
+ LOGGING_PULL_PORT = "5558"
import random
@@ -290,11 +327,20 @@ async def backend():
Fake program to be controlled by TUI.
"""
while True:
- msg = await socket.recv_multipart()
- if msg[0].decode("utf-8") == "quit":
- await socket.send_string("QUIT")
+ msg = await socket.recv_pyobj()
+ if msg.signal == "quit":
+ reply_str = "QUIT"
+ elif msg.signal == "ready":
+ reply_str = "Awaiting input:"
else:
- await socket.send_string("Awaiting input:")
+ reply_str = "Awaiting input:"
+ await socket.send_pyobj(
+ ActorSignalReplyMsg(
+ msg.actor_name,
+ msg.signal,
+ f"Signal {msg.signal} received.\n" + reply_str,
+ )
+ )
async def publish():
"""
@@ -320,12 +366,15 @@ async def log():
counter += 1
async def main_loop():
- app = TUI(CONTROL_PORT, OUTPUT_PORT, LOGGING_PORT)
+ app = TUI(CONTROL_PORT, OUTPUT_PORT, LOGGING_PORT, LOGGING_PULL_PORT)
# the following construct ensures both the
# (infinite) fake servers are killed once the tui finishes
finished, unfinished = await asyncio.wait(
- [app.run_async(), publish(), backend(), log()],
+ [
+ asyncio.create_task(c)
+ for c in (app.run_async(), publish(), backend(), log())
+ ],
return_when=asyncio.FIRST_COMPLETED,
)
diff --git a/improv/utils/checks.py b/improv/utils/checks.py
index a72f4657..97c06f24 100644
--- a/improv/utils/checks.py
+++ b/improv/utils/checks.py
@@ -32,10 +32,17 @@ def check_if_connections_acyclic(path_to_yaml):
# Need to keep only module names
connections = {}
- for key, values in raw.items():
- new_key = key.split(".")[0]
- new_values = [value.split(".")[0] for value in values]
- connections[new_key] = new_values
+ for connection, values in raw.items():
+ for source in values["sources"]:
+ source_name = source.split(".")[0]
+ if source_name not in connections.keys():
+ connections[source_name] = [
+ sink_name.split(".")[0] for sink_name in values["sinks"]
+ ]
+ else:
+ connections[source_name] = connections[source_name] + [
+ sink_name.split(".")[0] for sink_name in values["sinks"]
+ ]
g = nx.DiGraph(connections)
dag = nx.is_directed_acyclic_graph(g)
diff --git a/improv/watcher.py b/improv/watcher.py
deleted file mode 100644
index ad2dbf1f..00000000
--- a/improv/watcher.py
+++ /dev/null
@@ -1,175 +0,0 @@
-import numpy as np
-import asyncio
-
-# import pyarrow.plasma as plasma
-# from multiprocessing import Process, Queue, Manager, cpu_count, set_start_method
-# import subprocess
-# import signal
-# import time
-from queue import Empty
-import logging
-
-import concurrent
-
-# from pyarrow.plasma import ObjectNotAvailable
-from improv.actor import Actor, Signal, RunManager
-from improv.store import ObjectNotFoundError
-import pickle
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class BasicWatcher(Actor):
- """
- Actor that monitors stored objects from the other actors
- and saves objects that have been flagged by those actors
- """
-
- def __init__(self, *args, inputs=None):
- super().__init__(*args)
-
- self.watchin = inputs
-
- def setup(self):
- """
- set up tasks and polling based on inputs which will
- be used for asynchronous polling of input queues
- """
- self.numSaved = 0
- self.tasks = []
- self.polling = self.watchin
- self.setUp = False
-
- def run(self):
- """
- continually run the watcher to check all of the
- input queues for objects to save
- """
-
- with RunManager(
- self.name, self.watchrun, self.setup, self.q_sig, self.q_comm
- ) as rm:
- logger.info(rm)
-
- print("watcher saved " + str(self.numSaved) + " objects")
-
- def watchrun(self):
- """
- set up async loop for polling
- """
- loop = asyncio.get_event_loop()
- loop.run_until_complete(self.watch())
-
- async def watch(self):
- """
- function for asynchronous polling of input queues
- loops through each of the queues in watchin and checks
- if an object is present and then saves the object if found
- """
-
- if self.setUp is False:
- for q in self.polling:
- self.tasks.append(asyncio.create_task(q.get_async()))
- self.setUp = True
-
- done, pending = await asyncio.wait(
- self.tasks, return_when=concurrent.futures.FIRST_COMPLETED
- )
-
- for i, t in enumerate(self.tasks):
- if t in done or self.polling[i].status == "done":
- r = self.polling[i].result # r is array with id and name of object
- actorID = self.polling[
- i
- ].getStart() # name of actor asking watcher to save the object
- try:
- obj = self.client.getID(r[0])
- np.save("output/saved/" + actorID + r[1], obj)
- except ObjectNotFoundError as e:
- logger.info(e.message)
- pass
- self.tasks[i] = asyncio.create_task(self.polling[i].get_async())
-
-
-class Watcher:
- """Monitors the store as separate process
- TODO: Facilitate Watcher being used in multiple processes (shared list)
- """
-
- # Related to subscribe - could be private, i.e., _subscribe
- def __init__(self, name, client):
- self.name = name
- self.client = client
- self.flag = False
- self.saved_ids = []
-
- self.client.subscribe()
- self.n = 0
-
- def setLinks(self, links):
- self.q_sig = links
-
- def run(self):
- while True:
- if self.flag:
- try:
- self.checkStoreInterface2()
- except Exception as e:
- logger.error("Watcher exception during run: {}".format(e))
- # break
- try:
- signal = self.q_sig.get(timeout=0.005)
- if signal == Signal.run():
- self.flag = True
- logger.warning("Received run signal, begin running")
- elif signal == Signal.quit():
- logger.warning("Received quit signal, aborting")
- break
- elif signal == Signal.pause():
- logger.warning("Received pause signal, pending...")
- self.flag = False
- elif signal == Signal.resume(): # currently treat as same as run
- logger.warning("Received resume signal, resuming")
- self.flag = True
- except Empty:
- pass # no signal from Nexus
-
- # def checkStoreInterface(self):
- # notification_info = self.client.notify()
- # recv_objid, recv_dsize, recv_msize = notification_info
- # obj = self.client.getID(recv_objid)
- # try:
- # self.saveObj(obj)
- # self.n += 1
- # except Exception as e:
- # logger.error('Watcher error: {}'.format(e))
-
- def saveObj(self, obj, name):
- with open(
- "/media/hawkwings/Ext Hard Drive/dump/dump" + name + ".pkl", "wb"
- ) as output:
- pickle.dump(obj, output)
-
- def checkStoreInterface2(self):
- objs = list(self.client.get_all().keys())
- ids_to_save = list(set(objs) - set(self.saved_ids))
-
- # with Pool() as pool:
- # saved_ids = pool.map(saveObjbyID, ids_to_save)
- # print('Saved :', len(saved_ids))
- # self.saved_ids.extend(saved_ids)
-
- for id in ids_to_save:
- self.saveObj(self.client.getID(id), str(id))
- self.saved_ids.append(id)
-
-
-# def saveObjbyID(id):
-# client = plasma.connect('/tmp/store')
-# obj = client.get(id)
-# with open(
-# '/media/hawkwings/Ext\ Hard\ Drive/dump/dump'+str(id)+'.pkl', 'wb'
-# ) as output:
-# pickle.dump(obj, output)
-# return id
diff --git a/pyproject.toml b/pyproject.toml
index bf5a42b7..fd185aad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,6 @@ dependencies = [
"numpy<=1.26",
"scipy",
"matplotlib",
- "pyarrow==9.0.0",
"PyQt5",
"pyyaml",
"textual==0.15.0",
@@ -56,6 +55,7 @@ exclude = ["test", "pytest", "env", "demos", "figures"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
filterwarnings = [ ]
+asyncio_default_fixture_loop_scope = "function"
log_cli = true
log_cli_level = "INFO"
diff --git a/test/actors/actor_for_bad_args.py b/test/actors/actor_for_bad_args.py
new file mode 100644
index 00000000..bac95922
--- /dev/null
+++ b/test/actors/actor_for_bad_args.py
@@ -0,0 +1,3 @@
+class BadActorArgs:
+ def __init__(self, existent_parameter, another_required_param):
+ pass
diff --git a/test/actors/sample_generator.py b/test/actors/sample_generator.py
index 9d1c3891..5afc09fc 100644
--- a/test/actors/sample_generator.py
+++ b/test/actors/sample_generator.py
@@ -1,4 +1,4 @@
-from improv.actor import Actor
+from improv.actor import ZmqActor
from datetime import date # used for saving
import numpy as np
import logging
@@ -7,7 +7,7 @@
logger.setLevel(logging.INFO)
-class Generator(Actor):
+class Generator(ZmqActor):
"""Sample actor to generate data to pass into a sample processor.
Intended for use along with sample_processor.py.
@@ -37,7 +37,7 @@ def setup(self):
# """
# self.fcns = {}
# self.fcns['setup'] = self.setup
- # self.fcns['run'] = self.runStep
+ # self.fcns['run'] = self.run_step
# self.fcns['stop'] = self.stop
# with RunManager(self.name, self.fcns, self.links) as rm:
@@ -52,7 +52,7 @@ def stop(self):
# will overwrite previous files with the same name.
return 0
- def runStep(self):
+ def run_step(self):
"""Generates additional data after initial setup data is exhausted.
Data is of a different form as the setup data in that although it is
diff --git a/test/actors/sample_generator_wrong_import.py b/test/actors/sample_generator_wrong_import.py
index d509868e..0e8cec84 100644
--- a/test/actors/sample_generator_wrong_import.py
+++ b/test/actors/sample_generator_wrong_import.py
@@ -39,7 +39,7 @@ def stop(self):
np.save("sample_generator_data.npy", self.data)
return 0
- def runStep(self):
+ def run_step(self):
"""Generates additional data after initial setup data is exhausted.
Data is of a different form as the setup data in that although it is
diff --git a/test/actors/sample_generator_wrong_init.py b/test/actors/sample_generator_wrong_init.py
index 9d1c3891..405ab92f 100644
--- a/test/actors/sample_generator_wrong_init.py
+++ b/test/actors/sample_generator_wrong_init.py
@@ -1,76 +1,2 @@
-from improv.actor import Actor
-from datetime import date # used for saving
-import numpy as np
-import logging
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class Generator(Actor):
- """Sample actor to generate data to pass into a sample processor.
-
- Intended for use along with sample_processor.py.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.data = None
- self.name = "Generator"
- self.frame_num = 0
-
- def __str__(self):
- return f"Name: {self.name}, Data: {self.data}"
-
- def setup(self):
- """Generates an array that serves as an initial source of data.
-
- Initial array is a 100 row, 5 column numpy matrix that contains
- integers from 1-99, inclusive.
- """
-
- self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
- logger.info("Completed setup for Generator")
-
- # def run(self):
- # """ Send array into the store.
- # """
- # self.fcns = {}
- # self.fcns['setup'] = self.setup
- # self.fcns['run'] = self.runStep
- # self.fcns['stop'] = self.stop
-
- # with RunManager(self.name, self.fcns, self.links) as rm:
- # logger.info(rm)
-
- def stop(self):
- """Save current randint vector to a file."""
-
- print("Generator stopping")
- np.save(f"sample_generator_data_{date.today()}", self.data)
- # This is not the best example of a save function,
- # will overwrite previous files with the same name.
- return 0
-
- def runStep(self):
- """Generates additional data after initial setup data is exhausted.
-
- Data is of a different form as the setup data in that although it is
- the same size (5x1 vector), it is uniformly distributed in [1, 10]
- instead of in [1, 100]. Therefore, the average over time should
- converge to 5.5.
- """
-
- if self.frame_num < np.shape(self.data)[0]:
- data_id = self.client.put(
- self.data[self.frame_num], str(f"Gen_raw: {self.frame_num}")
- )
- try:
- self.q_out.put([[data_id, str(self.frame_num)]])
- self.frame_num += 1
- except Exception as e:
- logger.error(f"Generator Exception: {e}")
- else:
- self.data = np.concatenate(
- (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0
- )
+class GeneratorActor:
+ pass
diff --git a/demos/minimal/actors/sample_generator.py b/test/actors/sample_generator_zmq.py
similarity index 62%
rename from demos/minimal/actors/sample_generator.py
rename to test/actors/sample_generator_zmq.py
index 31335833..56fd14e3 100644
--- a/demos/minimal/actors/sample_generator.py
+++ b/test/actors/sample_generator_zmq.py
@@ -1,4 +1,5 @@
-from improv.actor import Actor
+from improv.actor import ZmqActor
+from datetime import date # used for saving
import numpy as np
import logging
@@ -6,7 +7,7 @@
logger.setLevel(logging.INFO)
-class Generator(Actor):
+class Generator(ZmqActor):
"""Sample actor to generate data to pass into a sample processor.
Intended for use along with sample_processor.py.
@@ -28,18 +29,30 @@ def setup(self):
integers from 1-99, inclusive.
"""
- logger.info("Beginning setup for Generator")
self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
logger.info("Completed setup for Generator")
+ # def run(self):
+ # """ Send array into the store.
+ # """
+ # self.fcns = {}
+ # self.fcns['setup'] = self.setup
+ # self.fcns['run'] = self.run_step
+ # self.fcns['stop'] = self.stop
+
+ # with RunManager(self.name, self.fcns, self.links) as rm:
+ # logger.info(rm)
+
def stop(self):
"""Save current randint vector to a file."""
- logger.info("Generator stopping")
- np.save("sample_generator_data.npy", self.data)
+ print("Generator stopping")
+ np.save(f"sample_generator_data_{date.today()}", self.data)
+ # This is not the best example of a save function,
+ # will overwrite previous files with the same name.
return 0
- def runStep(self):
+ def run_step(self):
"""Generates additional data after initial setup data is exhausted.
Data is of a different form as the setup data in that although it is
@@ -49,25 +62,14 @@ def runStep(self):
"""
if self.frame_num < np.shape(self.data)[0]:
- if self.store_loc:
- data_id = self.client.put(
- self.data[self.frame_num], str(f"Gen_raw: {self.frame_num}")
- )
- else:
- data_id = self.client.put(self.data[self.frame_num])
- # logger.info('Put data in store')
+ data_id = self.client.put(self.data[self.frame_num])
try:
- if self.store_loc:
- self.q_out.put([[data_id, str(self.frame_num)]])
- else:
- self.q_out.put(data_id)
- # logger.info("Sent message on")
-
+ self.q_out.put(data_id)
+ # logger.info(f"Sent {self.data[self.frame_num]} with key {data_id}")
self.frame_num += 1
+
except Exception as e:
- logger.error(
- f"--------------------------------Generator Exception: {e}"
- )
+ logger.error(f"Generator Exception: {e}")
else:
self.data = np.concatenate(
(self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0
diff --git a/test/actors/sample_processor.py b/test/actors/sample_processor.py
index cd706465..cc4c2a6f 100644
--- a/test/actors/sample_processor.py
+++ b/test/actors/sample_processor.py
@@ -41,13 +41,13 @@ def stop(self):
# """
# self.fcns = {}
# self.fcns['setup'] = self.setup
- # self.fcns['run'] = self.runStep
+ # self.fcns['run'] = self.run_step
# self.fcns['stop'] = self.stop
# with RunManager(self.name, self.fcns, self.links) as rm:
# logger.info(rm)
- def runStep(self):
+ def run_step(self):
"""Gets from the input queue and calculates the average.
Receives an ObjectID, references data in the store using that
diff --git a/demos/zmq/actors/zmq_rr_sample_processor.py b/test/actors/sample_processor_zmq.py
similarity index 58%
rename from demos/zmq/actors/zmq_rr_sample_processor.py
rename to test/actors/sample_processor_zmq.py
index e4ed9cd3..36eec5f8 100644
--- a/demos/zmq/actors/zmq_rr_sample_processor.py
+++ b/test/actors/sample_processor_zmq.py
@@ -1,34 +1,33 @@
-from improv.actor import Actor, AsyncActor
+from improv.actor import ZmqActor
import numpy as np
import logging
-from demos.sample_actors.zmqActor import ZmqActor
-
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Processor(ZmqActor):
- """Sample processor used to calculate the average of an array of integers
- using async ZMQ to communicate.
+ """Sample processor used to calculate the average of an array of integers.
Intended for use with sample_generator.py.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ if "name" in kwargs:
+ self.name = kwargs["name"]
def setup(self):
"""Initializes all class variables.
- Sets up a ZmqRRActor to receive data from the generator.
self.name (string): name of the actor.
- self.frame (ObjectID): Store object id referencing data from the store.
+ self.frame (ObjectID): StoreInterface object id referencing data from the store.
self.avg_list (list): list that contains averages of individual vectors.
self.frame_num (int): index of current frame.
"""
- self.name = "Processor"
+ if not hasattr(self, "name"):
+ self.name = "Processor"
self.frame = None
self.avg_list = []
self.frame_num = 1
@@ -38,10 +37,21 @@ def stop(self):
"""Trivial stop function for testing purposes."""
logger.info("Processor stopping")
+ return 0
+
+ # def run(self):
+ # """ Send array into the store.
+ # """
+ # self.fcns = {}
+ # self.fcns['setup'] = self.setup
+ # self.fcns['run'] = self.run_step
+ # self.fcns['stop'] = self.stop
+
+ # with RunManager(self.name, self.fcns, self.links) as rm:
+ # logger.info(rm)
- def runStep(self):
+ def run_step(self):
"""Gets from the input queue and calculates the average.
- Receives data from the generator using a ZmqRRActor.
Receives an ObjectID, references data in the store using that
ObjectID, calculates the average of that data, and finally prints
@@ -49,21 +59,19 @@ def runStep(self):
"""
frame = None
-
try:
- frame = self.get(reply='received')
-
- except:
- logger.error("Could not get frame!")
+ frame = self.q_in.get(timeout=0.05)
+ except Exception as e:
+ logger.error(f"{self.name} could not get frame! At {self.frame_num}: {e}")
pass
if frame is not None and self.frame_num is not None:
self.done = False
- self.frame = self.client.getID(frame)
+ self.frame = self.client.get(frame)
avg = np.mean(self.frame[0])
-
+ # logger.info(f"{self.name} got frame {frame} with value {self.frame}")
# logger.info(f"Average: {avg}")
self.avg_list.append(avg)
- logger.info(f"Overall Average: {np.mean(self.avg_list)}")
+ # logger.info(f"Overall Average: {np.mean(self.avg_list)}")
# logger.info(f"Frame number: {self.frame_num}")
self.frame_num += 1
diff --git a/test/configs/bad_args.yaml b/test/configs/bad_args.yaml
index be94212c..ce96b6b6 100644
--- a/test/configs/bad_args.yaml
+++ b/test/configs/bad_args.yaml
@@ -1,20 +1,7 @@
actors:
Acquirer:
- package: demos.sample_actors.acquire
- class: FileAcquirer
- fiasdfe: data/Tolias_mesoscope_2.hdf5
- fraefawe: 30
-
- Processor:
- package: demos.sample_actors.process
- class: CaimanProcessor
- init_filename: data/tbif_ex_crop.h5
- config_file: eva_caiman_params.txt
-
- Analysis:
- package: demos.sample_actors.analysis
- class: MeanAnalysis
+ package: actors.actor_for_bad_args
+ class: BadActorArgs
+ existent_parameter: 42
connections:
- Acquirer.q_out: [Processor.q_in]
- Processor.q_out: [Analysis.q_in]
diff --git a/test/configs/complex_graph.yaml b/test/configs/complex_graph.yaml
index 3c82029a..16961825 100644
--- a/test/configs/complex_graph.yaml
+++ b/test/configs/complex_graph.yaml
@@ -14,8 +14,14 @@ actors:
class: BehaviorAcquirer
connections:
- Acquirer.q_out: [Analysis.q_in, InputStim.q_in]
- Analysis.q_out: [InputStim.q_in]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
+ Acquirer-out:
+ sources:
+ - Acquirer.q_out
+ sinks:
+ - Analysis.q_in
+ - InputStim.q_in
+ Analysis-InputStim:
+ sources:
+ - Analysis.q_out
+ sinks:
+ - InputStim.q_in
diff --git a/test/configs/cyclic_config.yaml b/test/configs/cyclic_config.yaml
index 61c91303..8e87b880 100644
--- a/test/configs/cyclic_config.yaml
+++ b/test/configs/cyclic_config.yaml
@@ -14,5 +14,14 @@ actors:
class: BehaviorAcquirer
connections:
- Acquirer.q_out: [Analysis.q_in, InputStim.q_in]
- Analysis.q_out: [Acquirer.q_in]
+ Acquirer-out:
+ sources:
+ - Acquirer.q_out
+ sinks:
+ - Analysis.q_in
+ - InputStim.q_in
+ Analysis-Acquirer:
+ sources:
+ - Analysis.q_out
+ sinks:
+ - Acquirer.q_in
diff --git a/test/configs/good_config.yaml b/test/configs/good_config.yaml
index b0c70b70..6f72da5c 100644
--- a/test/configs/good_config.yaml
+++ b/test/configs/good_config.yaml
@@ -10,4 +10,8 @@ actors:
class: SimpleAnalysis
connections:
- Acquirer.q_out: [Analysis.q_in]
+ Acquirer-Analysis:
+ sources:
+ - Acquirer.q_out
+ sinks:
+ - Analysis.q_in
diff --git a/test/configs/good_config_actors.yaml b/test/configs/good_config_actors.yaml
deleted file mode 100644
index e84fd262..00000000
--- a/test/configs/good_config_actors.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-Acquirer:
- class: FileAcquirer
- filename: data/Tolias_mesoscope_2.hdf5
- framerate: 30
- package: demos.sample_actors.acquire
-Analysis:
- class: SimpleAnalysis
- package: demos.sample_actors.simple_analysis
diff --git a/test/configs/good_config_plasma.yaml b/test/configs/good_config_plasma.yaml
deleted file mode 100644
index 8323017b..00000000
--- a/test/configs/good_config_plasma.yaml
+++ /dev/null
@@ -1,15 +0,0 @@
-actors:
- Acquirer:
- package: demos.sample_actors.acquire
- class: FileAcquirer
- filename: data/Tolias_mesoscope_2.hdf5
- framerate: 30
-
- Analysis:
- package: demos.sample_actors.simple_analysis
- class: SimpleAnalysis
-
-connections:
- Acquirer.q_out: [Analysis.q_in]
-
-plasma_config:
\ No newline at end of file
diff --git a/test/configs/minimal.yaml b/test/configs/minimal.yaml
index 230282d1..d752b415 100644
--- a/test/configs/minimal.yaml
+++ b/test/configs/minimal.yaml
@@ -1,11 +1,15 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
diff --git a/test/configs/minimal_harvester.yaml b/test/configs/minimal_harvester.yaml
new file mode 100644
index 00000000..74847433
--- /dev/null
+++ b/test/configs/minimal_harvester.yaml
@@ -0,0 +1,18 @@
+actors:
+ Generator:
+ package: actors.sample_generator_zmq
+ class: Generator
+
+ Processor:
+ package: actors.sample_processor_zmq
+ class: Processor
+
+connections:
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
+
+settings:
+ harvest_data_from_memory: True
\ No newline at end of file
diff --git a/test/configs/minimal_with_custom_aof_dirname.yaml b/test/configs/minimal_with_custom_aof_dirname.yaml
index 1bac5863..f07c0c44 100644
--- a/test/configs/minimal_with_custom_aof_dirname.yaml
+++ b/test/configs/minimal_with_custom_aof_dirname.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
aof_dirname: custom_aof_dirname
\ No newline at end of file
diff --git a/test/configs/minimal_with_ephemeral_aof_dirname.yaml b/test/configs/minimal_with_ephemeral_aof_dirname.yaml
index 9933896e..d4f5a635 100644
--- a/test/configs/minimal_with_ephemeral_aof_dirname.yaml
+++ b/test/configs/minimal_with_ephemeral_aof_dirname.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
generate_ephemeral_aof_dirname: True
\ No newline at end of file
diff --git a/test/configs/minimal_with_every_second_saving.yaml b/test/configs/minimal_with_every_second_saving.yaml
index 74d4b314..6d381f72 100644
--- a/test/configs/minimal_with_every_second_saving.yaml
+++ b/test/configs/minimal_with_every_second_saving.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
enable_saving: True
diff --git a/test/configs/minimal_with_every_write_saving.yaml b/test/configs/minimal_with_every_write_saving.yaml
index add0c74f..4f1e3af4 100644
--- a/test/configs/minimal_with_every_write_saving.yaml
+++ b/test/configs/minimal_with_every_write_saving.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
enable_saving: True
diff --git a/test/configs/minimal_with_fixed_default_redis_port.yaml b/test/configs/minimal_with_fixed_default_redis_port.yaml
index c8e0b24e..8d33ded9 100644
--- a/test/configs/minimal_with_fixed_default_redis_port.yaml
+++ b/test/configs/minimal_with_fixed_default_redis_port.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
port: 6379
\ No newline at end of file
diff --git a/test/configs/minimal_with_fixed_redis_port.yaml b/test/configs/minimal_with_fixed_redis_port.yaml
index f6f98253..b63471f4 100644
--- a/test/configs/minimal_with_fixed_redis_port.yaml
+++ b/test/configs/minimal_with_fixed_redis_port.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
port: 6378
\ No newline at end of file
diff --git a/test/configs/minimal_with_no_schedule_saving.yaml b/test/configs/minimal_with_no_schedule_saving.yaml
index 2e5f62f3..20731f3d 100644
--- a/test/configs/minimal_with_no_schedule_saving.yaml
+++ b/test/configs/minimal_with_no_schedule_saving.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
enable_saving: True
diff --git a/test/configs/minimal_with_redis_saving.yaml b/test/configs/minimal_with_redis_saving.yaml
index 15e95b6c..087f685d 100644
--- a/test/configs/minimal_with_redis_saving.yaml
+++ b/test/configs/minimal_with_redis_saving.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
enable_saving: True
\ No newline at end of file
diff --git a/test/configs/minimal_with_settings.yaml b/test/configs/minimal_with_settings.yaml
index a5aec2d8..daff183d 100644
--- a/test/configs/minimal_with_settings.yaml
+++ b/test/configs/minimal_with_settings.yaml
@@ -4,16 +4,19 @@ settings:
control_port: 6000
output_port: 6001
logging_port: 6002
- use_watcher: false
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
diff --git a/test/configs/minimal_wrong_import.yaml b/test/configs/minimal_wrong_import.yaml
index 6c5e51c5..3fbc7822 100644
--- a/test/configs/minimal_wrong_import.yaml
+++ b/test/configs/minimal_wrong_import.yaml
@@ -4,8 +4,12 @@ actors:
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
\ No newline at end of file
diff --git a/test/configs/minimal_wrong_init.yaml b/test/configs/minimal_wrong_init.yaml
index a59c954a..edb45b51 100644
--- a/test/configs/minimal_wrong_init.yaml
+++ b/test/configs/minimal_wrong_init.yaml
@@ -4,8 +4,12 @@ actors:
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
\ No newline at end of file
diff --git a/test/configs/simple_graph.yaml b/test/configs/simple_graph.yaml
index 67086def..9238739c 100644
--- a/test/configs/simple_graph.yaml
+++ b/test/configs/simple_graph.yaml
@@ -10,7 +10,8 @@ actors:
class: SimpleAnalysis
connections:
- Acquirer.q_out: [Analysis.q_in]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
+ Acquirer-Analysis:
+ sources:
+ - Acquirer.q_out
+ sinks:
+ - Analysis.q_in
diff --git a/test/configs/single_actor.yaml b/test/configs/single_actor.yaml
index 812480b1..c9c6b0e2 100644
--- a/test/configs/single_actor.yaml
+++ b/test/configs/single_actor.yaml
@@ -6,5 +6,3 @@ actors:
framerate: 15
connections:
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/configs/single_actor_plasma.yaml b/test/configs/single_actor_plasma.yaml
deleted file mode 100644
index 812480b1..00000000
--- a/test/configs/single_actor_plasma.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-actors:
- Acquirer:
- package: demos.sample_actors.acquire
- class: FileAcquirer
- filename: data/Tolias_mesoscope_2.hdf5
- framerate: 15
-
-connections:
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/conftest.py b/test/conftest.py
index f9b80fa5..3d898524 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,22 +1,86 @@
+import logging
+import multiprocessing
import os
import signal
-import uuid
+import time
+
import pytest
import subprocess
-store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
-redis_port_num = 6379
-WAIT_TIMEOUT = 120
+import zmq
+
+from improv.actor import ZmqActor
+from improv.harvester import bootstrap_harvester
+from improv.nexus import Nexus
+
+REDIS_PORT_NUM = 6379
+WAIT_TIMEOUT = 20
+
+SERVER_COUNTER = 0
+
+signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+
+
+@pytest.fixture
+def ports():
+ global SERVER_COUNTER
+ CONTROL_PORT = 30000
+ OUTPUT_PORT = 30001
+ LOGGING_PORT = 30002
+ LOGGING_INPUT_PORT = 30003
+ yield (
+ CONTROL_PORT + SERVER_COUNTER,
+ OUTPUT_PORT + SERVER_COUNTER,
+ LOGGING_PORT + SERVER_COUNTER,
+ LOGGING_INPUT_PORT + SERVER_COUNTER,
+ )
+ SERVER_COUNTER += 4
@pytest.fixture
-def set_store_loc():
- return store_loc
+def setdir():
+ prev = os.getcwd()
+ os.chdir(os.path.dirname(__file__) + "/configs")
+ yield None
+ os.chdir(prev)
+
+
+@pytest.fixture
+def set_dir_config_parent():
+ prev = os.getcwd()
+ os.chdir(os.path.dirname(__file__))
+ yield None
+ os.chdir(prev)
+
+
+@pytest.fixture
+def sample_nex(setdir, ports):
+ nex = Nexus("test")
+ try:
+ nex.create_nexus(
+ file="good_config.yaml",
+ store_size=40000000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
+ except Exception as e:
+ print(f"error caught in test harness during create_nexus step: {e}")
+ logging.error(f"error caught in test harness during create_nexus step: {e}")
+ raise e
+ yield nex
+ try:
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness during destroy_nexus step: {e}")
+ logging.error(f"error caught in test harness during destroy_nexus step: {e}")
+ raise e
@pytest.fixture
def server_port_num():
- return redis_port_num
+ return REDIS_PORT_NUM
@pytest.fixture
@@ -37,21 +101,70 @@ def setup_store(server_port_num):
stderr=subprocess.DEVNULL,
)
+ time.sleep(3)
+ if p.poll() is not None:
+ raise Exception("redis-server failed to start")
+
yield p
# kill the subprocess when the caller is done with it
- p.send_signal(signal.SIGINT)
- p.wait(WAIT_TIMEOUT)
+ if p.poll() is None:
+ p.send_signal(signal.SIGINT)
+ p.wait(timeout=WAIT_TIMEOUT)
@pytest.fixture
-def setup_plasma_store(set_store_loc, scope="module"):
- """Fixture to set up the store subprocess with 10 mb."""
- p = subprocess.Popen(
- ["plasma_store", "-s", set_store_loc, "-m", str(10000000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
+def zmq_actor(ports):
+ actor = ZmqActor(ports[0], None, None, ports[2], None, None, name="test")
+
+ p = multiprocessing.Process(target=actor_startup, args=(actor,))
+
yield p
- p.send_signal(signal.SIGINT)
- p.wait(WAIT_TIMEOUT)
+
+ p.terminate()
+ p.join(WAIT_TIMEOUT)
+ if p.exitcode is None:
+ p.kill()
+
+
+def actor_startup(actor):
+ actor.setup_logging()
+ actor.register_with_nexus()
+
+
+@pytest.fixture
+def harvester(ports):
+ ctx = zmq.Context()
+ socket = ctx.socket(zmq.PUB)
+ socket.bind("tcp://*:1234")
+ p = multiprocessing.Process(
+ target=bootstrap_harvester,
+ args=(
+ "localhost",
+ ports[0],
+ "localhost",
+ 6379,
+ "localhost",
+ 1234,
+ "localhost",
+ 12345,
+ ),
+ )
+ p.start()
+ time.sleep(1)
+ yield ports, socket, p
+ socket.close(linger=0)
+ ctx.destroy(linger=0)
+
+
+class SignalManager:
+ def __init__(self):
+ self.signal_handlers = dict()
+
+ def __enter__(self):
+ for sig in signals:
+ self.signal_handlers[sig] = signal.getsignal(sig)
+
+ def __exit__(self, type, value, traceback):
+ for sig, handler in self.signal_handlers.items():
+ signal.signal(sig, handler)
diff --git a/test/conftest_with_errors.py b/test/conftest_with_errors.py
deleted file mode 100644
index c299959d..00000000
--- a/test/conftest_with_errors.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# import pytest
-import subprocess
-import asyncio
-from improv.actor import RunManager, AsyncRunManager
-import os
-import uuid
-
-
-class StoreInterfaceDependentTestCase:
- def set_up(self):
- """Start the server"""
- print("Setting up Plasma store.")
- store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
- self.p = subprocess.Popen(
- ["plasma_store", "-s", store_loc, "-m", str(10000000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
-
- def tear_down(self):
- """Kill the server"""
- print("Tearing down Plasma store.")
- self.p.kill()
- self.p.wait()
-
-
-class ActorDependentTestCase:
- def set_up(self):
- """Start the server"""
- print("Setting up Plasma store.")
- store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
- self.p = subprocess.Popen(
- ["plasma_store", "-s", store_loc, "-m", str(10000000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
-
- def tear_down(self):
- """Kill the server"""
- print("Tearing down Plasma store.")
- self.p.kill()
- self.p.wait()
-
- def run_setup(self):
- print("Set up = True.")
- self.is_set_up = True
-
- def run_method(self):
- # Accurate print statement?
- print("Running method.")
- self.run_num += 1
-
- def process_setup(self):
- print("Processing setup.")
- pass
-
- def process_run(self):
- print("Processing run - ran.")
- self.q_comm.put("ran")
-
- def create_process(self, q_sig, q_comm):
- print("Creating process.")
- with RunManager(
- "test", self.process_run, self.process_setup, q_sig, q_comm
- ) as rm:
- print(rm)
-
- async def createAsyncProcess(self, q_sig, q_comm):
- print("Creating asyn process.")
- async with AsyncRunManager(
- "test", self.process_run, self.process_setup, q_sig, q_comm
- ) as rm:
- print(rm)
-
- async def a_put(self, signal, time):
- print("Async put.")
- await asyncio.sleep(time)
- self.q_sig.put_async(signal)
diff --git a/test/nexus_analog.py b/test/nexus_analog.py
deleted file mode 100644
index 0fc698ec..00000000
--- a/test/nexus_analog.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import concurrent
-import time
-import asyncio
-import math
-import uuid
-import os
-from improv.link import Link
-from improv.store import StoreInterface
-import subprocess
-
-
-def clean_list_print(lst):
- print("\n=======================\n")
- for el in lst:
- print(el)
- print("\n")
- print("\n=======================\n")
-
-
-def setup_store():
- """Fixture to set up the store subprocess with 10 mb.
-
- This fixture runs a subprocess that instantiates the store with a
- memory of 10 megabytes. It specifies that "/tmp/store/" is the
- location of the store socket.
-
- Yields:
- StoreInterface: An instance of the store.
-
- TODO:
- Figure out the scope.
- """
- store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
- subprocess.Popen(
- ["plasma_store", "-s", store_loc, "-m", str(10000000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
- store = StoreInterface(store_loc=store_loc)
- return store
-
-
-async def pollQueues(links):
- tasks = []
- for link in links:
- tasks.append(asyncio.create_task(link.get_async()))
-
- links_cpy = links
- t_0 = time.perf_counter()
- t_1 = time.perf_counter()
- print("time get")
- cur_time = 0
- while t_1 - t_0 < 5:
- # need to add something to the queue such that asyncio.wait returns
-
- links[0].put("Message")
- done, pending = await asyncio.wait(
- tasks, return_when=concurrent.futures.FIRST_COMPLETED
- )
- for i, t in enumerate(tasks):
- if t in done:
- pass
- tasks[i] = asyncio.create_task(links_cpy[i].get_async())
-
- t_1 = time.perf_counter()
-
- if math.floor(t_1 - t_0) != cur_time:
- print(math.floor(t_1 - t_0))
- cur_time = math.floor(t_1 - t_0)
-
- print("All tasks prior to stop polling: \n")
- clean_list_print([task for task in tasks])
-
- asyncio.get_running_loop()
- return stop_polling(tasks, links)
-
-
-def start():
- links = [
- Link(f"Link {i}", f"start {i}", f"end {i}", setup_store()) for i in range(4)
- ]
- loop = asyncio.get_event_loop()
- print("RUC loop")
- res = loop.run_until_complete(pollQueues(links))
- print(f"RES: {res}")
-
- print("**********************\nAll tasks at the end of execution:")
- clean_list_print(res)
- print("**********************")
- print(f"Loop: {loop}")
- loop.close()
- print(f"Loop: {loop}")
-
-
-def stop_polling(tasks, links):
- # asyncio.gather(*tasks)
- print("Cancelling")
-
- [lnk.put("msg") for lnk in links]
-
- # [task.cancel() for task in tasks]
- # [task.cancel() for task in tasks]
-
- # [lnk.put("msg") for lnk in links]
-
- print("All tasks: \n")
- clean_list_print([task for task in tasks])
- print("Pending:\n")
- clean_list_print([task for task in tasks if not task.done()])
- print("Cancelled: \n")
- clean_list_print([task for task in tasks if task.cancelled()])
- print("Pending and cancelled: \n")
- clean_list_print([task for task in tasks if not task.done() and task.cancelled()])
-
- return [task for task in tasks]
-
-
-if __name__ == "__main__":
- start()
diff --git a/test/special_configs/basic_demo.yaml b/test/special_configs/basic_demo.yaml
index 1b84f7d8..6f113765 100644
--- a/test/special_configs/basic_demo.yaml
+++ b/test/special_configs/basic_demo.yaml
@@ -16,6 +16,3 @@ actors:
connections:
Acquirer.q_out: [Analysis.q_in]
InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/special_configs/basic_demo_with_GUI.yaml b/test/special_configs/basic_demo_with_GUI.yaml
index edfb09e1..0e5dfc27 100644
--- a/test/special_configs/basic_demo_with_GUI.yaml
+++ b/test/special_configs/basic_demo_with_GUI.yaml
@@ -21,6 +21,3 @@ actors:
connections:
Acquirer.q_out: [Analysis.q_in]
InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/test_actor.py b/test/test_actor.py
index 448fd34e..da7eb28e 100644
--- a/test/test_actor.py
+++ b/test/test_actor.py
@@ -1,9 +1,12 @@
import os
import psutil
import pytest
-from improv.link import Link # , AsyncQueue
+import zmq
+
from improv.actor import AbstractActor as Actor
-from improv.store import StoreInterface, PlasmaStoreInterface
+from improv.messaging import ActorStateMsg, ActorStateReplyMsg
+from improv.store import StoreInterface
+from improv.link import ZmqLink
# set global_variables
@@ -12,10 +15,10 @@
@pytest.fixture
-def init_actor(set_store_loc):
+def init_actor():
"""Fixture to initialize and teardown an instance of actor."""
- act = Actor("Test", set_store_loc)
+ act = Actor("Test")
yield act
act = None
@@ -33,33 +36,15 @@ def example_links(setup_store, server_port_num):
"""Fixture to provide link objects as test input and setup store."""
StoreInterface(server_port_num=server_port_num)
- acts = [
- Actor("act" + str(i), server_port_num) for i in range(1, 5)
- ] # range must be even
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.PUSH)
- links = [
- Link("L" + str(i + 1), acts[i], acts[i + 1]) for i in range(len(acts) // 2)
- ]
+ links = [ZmqLink(s, "test") for i in range(2)]
link_dict = {links[i].name: links[i] for i, l in enumerate(links)}
pytest.example_links = link_dict
- return pytest.example_links
-
-
-@pytest.fixture
-def example_links_plasma(setup_store, set_store_loc):
- """Fixture to provide link objects as test input and setup store."""
- PlasmaStoreInterface(store_loc=set_store_loc)
-
- acts = [
- Actor("act" + str(i), set_store_loc) for i in range(1, 5)
- ] # range must be even
-
- links = [
- Link("L" + str(i + 1), acts[i], acts[i + 1]) for i in range(len(acts) // 2)
- ]
- link_dict = {links[i].name: links[i] for i, l in enumerate(links)}
- pytest.example_links = link_dict
- return pytest.example_links
+ yield pytest.example_links
+ s.close(linger=0)
+ ctx.destroy(linger=0)
@pytest.mark.parametrize(
@@ -89,66 +74,34 @@ def test_repr_default_initialization(init_actor):
assert rep == "Test: dict_keys([])"
-def test_repr(example_string_links, set_store_loc):
+def test_repr(example_string_links):
"""Test if the actor representation has the right, nonempty, dict."""
- act = Actor("Test", set_store_loc)
- act.setLinks(example_string_links)
+ act = Actor("Test")
+ act.set_links(example_string_links)
assert act.__repr__() == "Test: dict_keys(['1', '2', '3'])"
-def test_setStoreInterface(setup_store, server_port_num):
+def test_set_store_interface(setup_store, server_port_num):
"""Tests if the store is started and linked with the actor."""
act = Actor("Acquirer", server_port_num)
store = StoreInterface(server_port_num=server_port_num)
- act.setStoreInterface(store.client)
- assert act.client is store.client
-
-
-def test_plasma_setStoreInterface(setup_plasma_store, set_store_loc):
- """Tests if the store is started and linked with the actor."""
-
- act = Actor("Acquirer", set_store_loc)
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- act.setStoreInterface(store.client)
+ act.set_store_interface(store.client)
assert act.client is store.client
@pytest.mark.parametrize(
- "links", [(pytest.example_string_links), ({}), (pytest.example_links), (None)]
+ "links", [pytest.example_string_links, ({}), pytest.example_links, None]
)
-def test_setLinks(links, set_store_loc):
+def test_set_links(links):
"""Tests if the actors links can be set to certain values."""
- act = Actor("test", set_store_loc)
- act.setLinks(links)
+ act = Actor("test")
+ act.set_links(links)
assert act.links == links
-@pytest.mark.parametrize(
- ("qc", "qs"),
- [
- ("comm", "sig"),
- (None, None),
- ("", ""),
- ("LINK", "LINK"), # these are placeholder names (store is not setup)
- ],
-)
-def test_setCommLinks(example_links, qc, qs, init_actor, setup_store, set_store_loc):
- """Tests if commLinks can be added to the actor"s links."""
-
- if qc == "LINK" and qs == "LINK":
- qc = Link("L1", Actor("1", set_store_loc), Actor("2", set_store_loc))
- qs = Link("L2", Actor("3", set_store_loc), Actor("4", set_store_loc))
- act = init_actor
- act.setLinks(example_links)
- act.setCommLinks(qc, qs)
-
- example_links.update({"q_comm": qc, "q_sig": qs})
- assert act.links == example_links
-
-
@pytest.mark.parametrize(
("links", "expected"),
[
@@ -158,18 +111,18 @@ def test_setCommLinks(example_links, qc, qs, init_actor, setup_store, set_store_
(None, TypeError),
],
)
-def test_setLinkIn(init_actor, example_string_links, example_links, links, expected):
+def test_set_link_in(init_actor, example_string_links, example_links, links, expected):
"""Tests if we can set the input queue."""
act = init_actor
- act.setLinks(links)
+ act.set_links(links)
if links is not None:
- act.setLinkIn("input_q")
+ act.set_link_in("input_q")
expected.update({"q_in": "input_q"})
assert act.links == expected
else:
with pytest.raises(AttributeError):
- act.setLinkIn("input_queue")
+ act.set_link_in("input_queue")
@pytest.mark.parametrize(
@@ -181,18 +134,18 @@ def test_setLinkIn(init_actor, example_string_links, example_links, links, expec
(None, TypeError),
],
)
-def test_setLinkOut(init_actor, example_string_links, example_links, links, expected):
+def test_set_link_out(init_actor, example_string_links, example_links, links, expected):
"""Tests if we can set the output queue."""
act = init_actor
- act.setLinks(links)
+ act.set_links(links)
if links is not None:
- act.setLinkOut("output_q")
+ act.set_link_out("output_q")
expected.update({"q_out": "output_q"})
assert act.links == expected
else:
with pytest.raises(AttributeError):
- act.setLinkIn("output_queue")
+ act.set_link_in("output_queue")
@pytest.mark.parametrize(
@@ -204,29 +157,31 @@ def test_setLinkOut(init_actor, example_string_links, example_links, links, expe
(None, TypeError),
],
)
-def test_setLinkWatch(init_actor, example_string_links, example_links, links, expected):
+def test_set_link_watch(
+ init_actor, example_string_links, example_links, links, expected
+):
"""Tests if we can set the watch queue."""
act = init_actor
- act.setLinks(links)
+ act.set_links(links)
if links is not None:
- act.setLinkWatch("watch_q")
+ act.set_link_watch("watch_q")
expected.update({"q_watchout": "watch_q"})
assert act.links == expected
else:
with pytest.raises(AttributeError):
- act.setLinkIn("input_queue")
+ act.set_link_in("input_queue")
-def test_addLink(setup_store, set_store_loc):
+def test_add_link(setup_store):
"""Tests if a link can be added to the dictionary of links."""
- act = Actor("test", set_store_loc)
+ act = Actor("test")
links = {"1": "one", "2": "two"}
- act.setLinks(links)
+ act.set_links(links)
newName = "3"
newLink = "three"
- act.addLink(newName, newLink)
+ act.add_link(newName, newLink)
links.update({"3": "three"})
# trying to check for two separate conditions while being able to
@@ -234,7 +189,7 @@ def test_addLink(setup_store, set_store_loc):
passes = []
err_messages = []
- if act.getLinks()["3"] == "three":
+ if act.get_links()["3"] == "three":
passes.append(True)
else:
passes.append(False)
@@ -243,7 +198,7 @@ def test_addLink(setup_store, set_store_loc):
actor.getLinks()['3'] is not equal to \"three\""
)
- if act.getLinks() == links:
+ if act.get_links() == links:
passes.append(True)
else:
passes.append("False")
@@ -256,7 +211,7 @@ def test_addLink(setup_store, set_store_loc):
assert all(passes), f"The following errors occurred: {err_out}"
-def test_getLinks(init_actor, example_string_links):
+def test_get_links(init_actor, example_string_links):
"""Tests if we can access the dictionary of links.
TODO:
@@ -265,25 +220,9 @@ def test_getLinks(init_actor, example_string_links):
act = init_actor
links = example_string_links
- act.setLinks(links)
-
- assert act.getLinks() == {"1": "one", "2": "two", "3": "three"}
-
+ act.set_links(links)
-@pytest.mark.skip(
- reason="this is something we'll do later because\
- we will subclass actor w/ watcher later"
-)
-def test_put(init_actor):
- """Tests if data keys can be put to output links.
-
- TODO:
- Ask Anne to take a look.
- """
-
- act = init_actor
- act.put()
- assert True
+ assert act.get_links() == {"1": "one", "2": "two", "3": "three"}
def test_run(init_actor):
@@ -293,12 +232,12 @@ def test_run(init_actor):
act.run()
-def test_changePriority(init_actor):
+def test_change_priority(init_actor):
"""Tests if we are able to change the priority of an actor."""
act = init_actor
act.lower_priority = True
- act.changePriority()
+ act.change_priority()
assert psutil.Process(os.getpid()).nice() == 19
@@ -311,39 +250,23 @@ def test_actor_connection(setup_store, server_port_num):
one actor. Then, in the other actor, it is removed from the queue, and
checked to verify it matches the original message.
"""
- act1 = Actor("a1", server_port_num)
- act2 = Actor("a2", server_port_num)
-
- StoreInterface(server_port_num=server_port_num)
- link = Link("L12", act1, act2)
- act1.setLinkIn(link)
- act2.setLinkOut(link)
-
- msg = "message"
-
- act1.q_in.put(msg)
+ assert True
- assert act2.q_out.get() == msg
+def test_actor_registration_with_nexus(ports, zmq_actor):
+ context = zmq.Context()
+ nex_socket = context.socket(zmq.REP)
+ nex_socket.bind(f"tcp://*:{ports[0]}") # actor in port
-def test_plasma_actor_connection(setup_plasma_store, set_store_loc):
- """Test if the links between actors are established correctly.
+ zmq_actor.start()
- This test instantiates two actors with different names, then instantiates
- a Link object linking the two actors. A string is put to the input queue of
- one actor. Then, in the other actor, it is removed from the queue, and
- checked to verify it matches the original message.
- """
- act1 = Actor("a1", set_store_loc)
- act2 = Actor("a2", set_store_loc)
+ res = nex_socket.recv_pyobj()
+ assert isinstance(res, ActorStateMsg)
- PlasmaStoreInterface(store_loc=set_store_loc)
- link = Link("L12", act1, act2)
- act1.setLinkIn(link)
- act2.setLinkOut(link)
+ nex_socket.send_pyobj(ActorStateReplyMsg("test", "OK", ""))
- msg = "message"
+ zmq_actor.terminate()
+ zmq_actor.join(10)
- act1.q_in.put(msg)
- assert act2.q_out.get() == msg
+# TODO: register with broker test
diff --git a/test/test_cli.py b/test/test_cli.py
index c4b34bff..3dd2f289 100644
--- a/test/test_cli.py
+++ b/test/test_cli.py
@@ -1,36 +1,29 @@
+import time
+
import pytest
import os
import datetime
from collections import namedtuple
import subprocess
-import asyncio
import signal
-import improv.cli as cli
-from test_nexus import ports
-
-SERVER_WARMUP = 16
-SERVER_TIMEOUT = 16
+import improv.cli as cli
+from conftest import ports
-@pytest.fixture
-def setdir():
- prev = os.getcwd()
- os.chdir(os.path.dirname(__file__))
- yield None
- os.chdir(prev)
+SERVER_WARMUP = 10
+SERVER_TIMEOUT = 15
@pytest.fixture
-async def server(setdir, ports):
+def server(setdir, ports):
"""
Sets up a server using minimal.yaml in the configs folder.
Requires the actor path command line argument and so implicitly
tests that as well.
"""
- os.chdir("configs")
- control_port, output_port, logging_port = ports
+ control_port, output_port, logging_port, logging_input_port = ports
# start server
server_opts = [
@@ -42,6 +35,8 @@ async def server(setdir, ports):
str(output_port),
"-l",
str(logging_port),
+ "-i",
+ str(logging_input_port),
"-a",
"..",
"-f",
@@ -49,29 +44,33 @@ async def server(setdir, ports):
"minimal.yaml",
]
- server = subprocess.Popen(
- server_opts, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
- )
- await asyncio.sleep(SERVER_WARMUP)
+ server = subprocess.Popen(server_opts)
+ time.sleep(SERVER_WARMUP)
yield server
- server.wait(SERVER_TIMEOUT)
- try:
- os.remove("testlog")
- except FileNotFoundError:
- pass
+ if server.poll() is None:
+ pytest.fail("Server did not shut down correctly.")
@pytest.fixture
-async def cli_args(setdir, ports):
+def cli_args(setdir, ports):
logfile = "tmp.log"
- control_port, output_port, logging_port = ports
- config_file = "configs/minimal.yaml"
+ control_port, output_port, logging_port, logging_input_port = ports
+ config_file = "minimal.yaml"
Args = namedtuple(
"cli_args",
- "control_port output_port logging_port logfile configfile actor_path",
+ "control_port output_port logging_port logging_input_port logfile "
+ "configfile actor_path",
)
- args = Args(control_port, output_port, logging_port, logfile, config_file, [])
+ args = Args(
+ control_port,
+ output_port,
+ logging_port,
+ logging_input_port,
+ logfile,
+ config_file,
+ [],
+ )
return args
@@ -86,7 +85,7 @@ def test_configfile_required(setdir):
cli.parse_cli_args(["server", "does_not_exist.yaml"])
-def test_multiple_actor_path(setdir):
+def test_multiple_actor_path(set_dir_config_parent):
args = cli.parse_cli_args(
["run", "-a", "actors", "-a", "configs", "configs/blank_file.yaml"]
)
@@ -104,22 +103,26 @@ def test_multiple_actor_path(setdir):
("run", "-c", "6000"),
("run", "-o", "6000"),
("run", "-l", "6000"),
+ ("run", "-i", "6000"),
("server", "-c", "6000"),
("server", "-o", "6000"),
("server", "-l", "6000"),
+ ("server", "-i", "6000"),
("client", "-c", "6000"),
("client", "-s", "6000"),
("client", "-l", "6000"),
+ ("client", "-i", "6000"),
],
)
def test_can_override_ports(mode, flag, expected, setdir):
- file = "configs/blank_file.yaml"
+ file = "blank_file.yaml"
localhost = "127.0.0.1:"
params = {
"-c": "control_port",
"-o": "output_port",
"-s": "server_port",
"-l": "logging_port",
+ "-i": "logging_input_port",
}
if mode in ["run", "server"]:
@@ -142,7 +145,7 @@ def test_can_override_ports(mode, flag, expected, setdir):
],
)
def test_non_port_is_error(mode, flag, expected):
- file = "configs/blank_file.yaml"
+ file = "blank_file.yaml"
with pytest.raises(SystemExit):
cli.parse_cli_args([mode, flag, expected, file])
@@ -182,27 +185,29 @@ def test_can_override_ip(mode, flag, expected):
assert vars(args)[params[flag]] == expected
-async def test_sigint_kills_server(server):
+def test_sigint_kills_server(server):
server.send_signal(signal.SIGINT)
+ server.wait(SERVER_TIMEOUT)
-async def test_improv_list_nonempty(server):
+def test_improv_list_nonempty(server):
proc_list = cli.run_list("", printit=False)
assert len(proc_list) > 0
server.send_signal(signal.SIGINT)
+ server.wait(SERVER_TIMEOUT)
-async def test_improv_kill_empties_list(server):
+def test_improv_kill_empties_list(server):
proc_list = cli.run_list("", printit=False)
assert len(proc_list) > 0
cli.run_cleanup("", headless=True)
proc_list = cli.run_list("", printit=False)
assert len(proc_list) == 0
+ server.wait(SERVER_TIMEOUT)
-async def test_improv_run_writes_stderr_to_log(setdir, ports):
- os.chdir("configs")
- control_port, output_port, logging_port = ports
+def test_improv_run_writes_stderr_to_log(setdir, ports):
+ control_port, output_port, logging_port, logging_input_port = ports
# start server
server_opts = [
@@ -214,6 +219,8 @@ async def test_improv_run_writes_stderr_to_log(setdir, ports):
str(output_port),
"-l",
str(logging_port),
+ "-i",
+ str(logging_input_port),
"-a",
"..",
"-f",
@@ -223,62 +230,73 @@ async def test_improv_run_writes_stderr_to_log(setdir, ports):
server = subprocess.Popen(
server_opts, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
- await asyncio.sleep(SERVER_WARMUP)
+ time.sleep(SERVER_WARMUP)
server.kill()
server.wait(SERVER_TIMEOUT)
- with open("testlog") as log:
+ with open("improv-debug.log") as log:
contents = log.read()
+ print(contents)
assert "Traceback" in contents
+ os.remove("improv-debug.log")
os.remove("testlog")
cli.run_cleanup("", headless=True)
-async def test_get_ports_from_logfile(setdir):
+def test_get_ports_from_logfile(setdir):
test_control_port = 53349
test_output_port = 53350
test_logging_port = 53351
+ test_logging_input_port = 53352
logfile = "tmp.log"
with open(logfile, "w") as log:
log.write(
- "Server running on (control, output, log) ports (53345, 53344, 53343)."
+ "Server running on (control, output, log, log input) ports "
+ "(53345, 53344, 53343, 53352).\n"
)
log.write(
- f"Server running on (control, output, log) ports ({test_control_port}, "
- f"{test_output_port}, {test_logging_port})."
+ f"Server running on (control, output, log, log input) ports "
+ f"({test_control_port}, "
+ f"{test_output_port}, {test_logging_port}, {test_logging_input_port})."
)
- control_port, output_port, logging_port = cli._get_ports(logfile)
+ control_port, output_port, logging_port, logging_input_port = cli._get_ports(
+ logfile
+ )
os.remove(logfile)
assert control_port == test_control_port
assert output_port == test_output_port
assert logging_port == test_logging_port
+ assert logging_input_port == test_logging_input_port
-async def test_no_server_start_in_logfile_raises_error(setdir, cli_args, capsys):
- with open(cli_args.logfile, mode="w") as f:
+def test_no_server_start_in_logfile_raises_error(setdir, cli_args, capsys):
+ with open("improv-debug.log", mode="w") as f:
f.write("this is some placeholder text")
- cli.get_server_ports(cli_args, timeout=1)
+ curr_dt = datetime.datetime.now().replace(microsecond=0)
+ timeout = 1
+ cli.get_server_ports(cli_args, timeout, curr_dt)
captured = capsys.readouterr()
assert "Unable to read server start time" in captured.out
- os.remove(cli_args.logfile)
+ os.remove("improv-debug.log")
cli.run_cleanup("", headless=True)
-async def test_no_ports_in_logfile_raises_error(setdir, cli_args, capsys):
+def test_no_ports_in_logfile_raises_error(setdir, cli_args, capsys):
curr_dt = datetime.datetime.now().replace(microsecond=0)
- with open(cli_args.logfile, mode="w") as f:
+ with open("improv-debug.log", mode="w") as f:
f.write(f"{curr_dt} Server running on (control, output, log) ports XXX\n")
- cli.get_server_ports(cli_args, timeout=1)
+ timeout = 1
+ cli.get_server_ports(cli_args, timeout, curr_dt)
captured = capsys.readouterr()
- assert f"Unable to read ports from {cli_args.logfile}." in captured.out
+ assert f"Unable to read ports from {'improv-debug.log'}." in captured.out
- os.remove(cli_args.logfile)
+ os.remove("improv-debug.log")
cli.run_cleanup("", headless=True)
diff --git a/test/test_config.py b/test/test_config.py
index 12cc79bf..1db17a17 100644
--- a/test/test_config.py
+++ b/test/test_config.py
@@ -6,13 +6,14 @@
# from importlib import import_module
# from improv.config import RepeatedActorError
-from improv.config import Config
+from improv.config import Config, CannotCreateConfigException
from improv.utils import checks
import logging
logger = logging.getLogger(__name__)
+
# set global variables
@@ -25,7 +26,7 @@ def set_configdir():
os.chdir(prev)
-@pytest.mark.parametrize("test_input", [("good_config.yaml")])
+@pytest.mark.parametrize("test_input", ["good_config.yaml"])
def test_init(test_input, set_configdir):
"""Checks if cfg.configFile matches the provided configFile.
@@ -34,61 +35,43 @@ def test_init(test_input, set_configdir):
"""
cfg = Config(test_input)
- assert cfg.configFile == test_input
-
-
-# def test_init_attributes():
-# """ Tests if config has correct default attributes on initialization.
-
-# Checks if actors, connection, and hasGUI are all empty or
-# nonexistent. Detects errors by maintaining a list of errors, and
-# then adding to it every time an unexpected behavior is encountered.
-
-# Asserts:
-# If the default attributes are empty or nonexistent.
+ assert cfg.config_file == test_input
-# """
-# cfg = config()
-# errors = []
-
-# if(cfg.actors != {}):
-# errors.append("config.actors is not empty! ")
-# if(cfg.connections != {}):
-# errors.append("config.connections is not empty! ")
-# if(cfg.hasGUI):
-# errors.append("config.hasGUI already exists! ")
-
-# assert not errors, "The following errors occurred:\n{}".format(
-# "\n".join(errors))
-
-
-def test_createConfig_settings(set_configdir):
+def test_create_config_settings(set_configdir):
"""Check if the default way config creates config.settings is correct.
Asserts:
- If the default setting is the dictionary {"use_watcher": "None"}
+ If the default setting is the config as a dict.
"""
cfg = Config("good_config.yaml")
- cfg.createConfig()
- assert cfg.settings == {"use_watcher": False}
+ cfg.parse_config()
+ cfg.create_config()
+ assert cfg.settings == {
+ "control_port": 5555,
+ "output_port": 5556,
+ "logging_port": 5557,
+ "logging_input_port": 5558,
+ "store_size": 250_000_000,
+ "harvest_data_from_memory": None,
+ }
-# File with syntax error cannot pass the format check
-# def test_createConfig_init_typo(set_configdir):
-# """Tests if createConfig can catch actors with errors in init function.
+def test_create_config_init_typo(set_configdir):
+ """Tests if createConfig can catch actors with errors in init function.
-# Asserts:
-# If createConfig raise any errors.
-# """
+ Asserts:
+ If createConfig raise any errors.
+ """
-# cfg = config("minimal_wrong_init.yaml")
-# res = cfg.createConfig()
-# assert res == -1
+ cfg = Config("minimal_wrong_init.yaml")
+ cfg.parse_config()
+ res = cfg.create_config()
+ assert res == -1
-def test_createConfig_wrong_import(set_configdir):
+def test_create_config_wrong_import(set_configdir):
"""Tests if createConfig can catch actors with errors during import.
Asserts:
@@ -96,11 +79,12 @@ def test_createConfig_wrong_import(set_configdir):
"""
cfg = Config("minimal_wrong_import.yaml")
- res = cfg.createConfig()
+ cfg.parse_config()
+ res = cfg.create_config()
assert res == -1
-def test_createConfig_clean(set_configdir):
+def test_create_config_clean(set_configdir):
"""Tests if createConfig runs without error given a good config.
Asserts:
@@ -108,52 +92,57 @@ def test_createConfig_clean(set_configdir):
"""
cfg = Config("good_config.yaml")
+ cfg.parse_config()
try:
- cfg.createConfig()
+ cfg.create_config()
except Exception as exc:
pytest.fail(f"createConfig() raised an exception {exc}")
-def test_createConfig_noActor(set_configdir):
+def test_create_config_no_actor(set_configdir):
"""Tests if AttributeError is raised when there are no actors."""
cfg = Config("no_actor.yaml")
+ cfg.parse_config()
with pytest.raises(AttributeError):
- cfg.createConfig()
+ cfg.create_config()
-def test_createConfig_ModuleNotFound(set_configdir):
+def test_create_config_module_not_found(set_configdir):
"""Tests if an error is raised when the package can"t be found."""
cfg = Config("bad_package.yaml")
- res = cfg.createConfig()
+ cfg.parse_config()
+ res = cfg.create_config()
assert res == -1
-def test_createConfig_class_ImportError(set_configdir):
+def test_create_config_class_import_error(set_configdir):
"""Tests if an error is raised when the class name is invalid."""
cfg = Config("bad_class.yaml")
- res = cfg.createConfig()
+ cfg.parse_config()
+ res = cfg.create_config()
assert res == -1
-def test_createConfig_AttributeError(set_configdir):
+def test_create_config_attribute_error(set_configdir):
"""Tests if AttributeError is raised."""
cfg = Config("bad_class.yaml")
- res = cfg.createConfig()
+ cfg.parse_config()
+ res = cfg.create_config()
assert res == -1
-def test_createConfig_blank_file(set_configdir):
+def test_create_config_blank_file(set_configdir):
"""Tests if a blank config file raises an error."""
- with pytest.raises(TypeError):
+ with pytest.raises(CannotCreateConfigException):
Config("blank_file.yaml")
-def test_createConfig_nonsense_file(set_configdir, caplog):
+def test_create_config_nonsense_file(set_configdir, caplog):
"""Tests if an improperly formatted config raises an error."""
with pytest.raises(TypeError):
@@ -170,12 +159,13 @@ def test_cyclic_graph(set_configdir):
assert not checks.check_if_connections_acyclic(path)
-def test_saveActors_clean(set_configdir):
+def test_save_actors_clean(set_configdir):
"""Compares internal actor representation to what was saved in the file."""
cfg = Config("good_config.yaml")
- cfg.createConfig()
- cfg.saveActors()
+ cfg.parse_config()
+ cfg.create_config()
+ cfg.save_actors()
with open("good_config_actors.yaml") as savedConfig:
data = yaml.safe_load(savedConfig)
@@ -185,9 +175,107 @@ def test_saveActors_clean(set_configdir):
assert savedKeys == originalKeys
+ os.remove("good_config_actors.yaml")
+
def test_config_settings_read(set_configdir):
cfg = Config("minimal_with_settings.yaml")
- cfg.createConfig()
+ cfg.parse_config()
+ cfg.create_config()
assert "store_size" in cfg.settings
+
+
+def test_config_bad_actor_args(set_configdir):
+ cfg = Config("bad_args.yaml")
+ cfg.parse_config()
+ res = cfg.create_config()
+ assert res == -1
+
+
+def test_config_harvester_disabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["settings"] = dict()
+ cfg.parse_config()
+ assert cfg.settings["harvest_data_from_memory"] is None
+
+
+def test_config_harvester_enabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["settings"] = dict()
+ cfg.config["settings"]["harvest_data_from_memory"] = True
+ cfg.parse_config()
+ assert cfg.settings["harvest_data_from_memory"]
+
+
+def test_config_redis_aof_enabled_saving_not_specified(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["aof_dirname"] = "test"
+ cfg.parse_config()
+ assert cfg.redis_config["aof_dirname"] == "test"
+ assert cfg.redis_config["enable_saving"] is True
+
+
+def test_config_redis_ephemeral_dirname_enabled_saving_not_specified(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["generate_ephemeral_aof_dirname"] = True
+ cfg.parse_config()
+ assert cfg.redis_config["generate_ephemeral_aof_dirname"] is True
+ assert cfg.redis_config["enable_saving"] is True
+
+
+def test_config_redis_ephemeral_dirname_and_aof_dirname_specified(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["generate_ephemeral_aof_dirname"] = True
+ cfg.config["redis_config"]["aof_dirname"] = "test"
+ with pytest.raises(
+ Exception, match="Cannot use unique dirname and use the one provided."
+ ):
+ cfg.parse_config()
+
+
+def test_config_redis_ephemeral_dirname_enabled_saving_disabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["generate_ephemeral_aof_dirname"] = True
+ cfg.config["redis_config"]["enable_saving"] = False
+ with pytest.raises(Exception, match="Cannot persist to disk with saving disabled."):
+ cfg.parse_config()
+
+
+def test_config_redis_aof_dirname_enabled_saving_disabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["aof_dirname"] = "test"
+ cfg.config["redis_config"]["enable_saving"] = False
+ with pytest.raises(Exception, match="Cannot persist to disk with saving disabled."):
+ cfg.parse_config()
+
+
+def test_config_redis_fsync_enabled_saving_disabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["fsync_frequency"] = "always"
+ cfg.config["redis_config"]["enable_saving"] = False
+ with pytest.raises(Exception, match="Cannot persist to disk with saving disabled."):
+ cfg.parse_config()
+
+
+def test_config_redis_unknown_fsync_freq(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["fsync_frequency"] = "unknown"
+ with pytest.raises(Exception, match="Cannot use unknown fsync frequency unknown"):
+ cfg.parse_config()
diff --git a/test/test_demos.py b/test/test_demos.py
index bc6e96c6..04d498e6 100644
--- a/test/test_demos.py
+++ b/test/test_demos.py
@@ -4,18 +4,15 @@
import os
import asyncio
import subprocess
+
import improv.tui as tui
-import concurrent.futures
import logging
-from demos.sample_actors.zmqActor import ZmqActor
-from test_nexus import ports
-
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)
-
SERVER_WARMUP = 10
+UNUSED_TCP_PORT = 10567
@pytest.fixture
@@ -35,20 +32,19 @@ def ip():
return pytest.ip
-@pytest.mark.parametrize(
- ("dir", "configfile", "logfile"),
- [
- ("minimal", "minimal.yaml", "testlog"),
- ("minimal", "minimal_plasma.yaml", "testlog"),
- ],
-)
-async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports):
- os.chdir(dir)
+@pytest.fixture
+def unused_tcp_port():
+ global UNUSED_TCP_PORT
+ pytest.unused_tcp_port = UNUSED_TCP_PORT
+ yield pytest.unused_tcp_port
+ UNUSED_TCP_PORT += 1
+
- control_port, output_port, logging_port = ports
+@pytest.fixture
+def server_opts(ports, logfile, configfile):
+ control_port, output_port, logging_port, logging_input_port = ports
- # start server
- server_opts = [
+ opts = [
"improv",
"server",
"-c",
@@ -57,11 +53,28 @@ async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports):
str(output_port),
"-l",
str(logging_port),
+ "-i",
+ str(logging_input_port),
"-f",
logfile,
configfile,
]
+ return opts
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("dir", "configfile", "logfile"),
+ [
+ ("minimal", "minimal.yaml", "testlog"),
+ ],
+)
+async def test_simple_boot_and_quit(dir, logfile, setdir, ports, server_opts):
+ os.chdir(dir)
+
+ control_port, output_port, logging_port, logging_input_port = ports
+
with open(logfile, mode="a+") as log:
server = subprocess.Popen(server_opts, stdout=log, stderr=log)
time.sleep(5)
@@ -69,7 +82,7 @@ async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports):
await asyncio.sleep(SERVER_WARMUP)
# initialize client
- app = tui.TUI(control_port, output_port, logging_port)
+ app = tui.TUI(control_port, output_port, logging_port, logging_input_port)
# run client
async with app.run_test() as pilot:
@@ -80,44 +93,76 @@ async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports):
await pilot.pause(2)
assert not pilot.app._running
+ # wait on server to fully shut down
+ server.wait(15)
+ # os.remove(logfile) # later, might want to read this file and check for messages
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("dir", "configfile", "logfile", "datafile"),
+ [
+ ("minimal", "minimal.yaml", "testlog", "sample_generator_data.npy"),
+ ("minimal", "minimal_persistence.yaml", "testlog", "test_persistence.csv"),
+ ],
+)
+async def test_stop_output(dir, logfile, datafile, setdir, ports, server_opts):
+ os.chdir(dir)
+
+ control_port, output_port, logging_port, logging_input_port = ports
+
+ with open(logfile, mode="a+") as log:
+ server = subprocess.Popen(server_opts, stdout=log, stderr=log)
+ await asyncio.sleep(SERVER_WARMUP)
+
+ # initialize client
+ app = tui.TUI(control_port, output_port, logging_port, logging_input_port)
+
+ # run client
+ async with app.run_test() as pilot:
+ print("running pilot")
+ await pilot.press(*"setup", "enter")
+ await pilot.pause(0.5)
+ await pilot.press(*"run", "enter")
+ await pilot.pause(1)
+ await pilot.press(*"stop", "enter")
+ await pilot.pause(2)
+ await pilot.press(*"quit", "enter")
+ await pilot.pause(3)
+ assert not pilot.app._running
+
# wait on server to fully shut down
server.wait(10)
+
+ # check that the file written by Generator's stop function got written
+ if os.path.isfile(datafile):
+ # then remove that file and logile
+ os.remove(datafile)
+
os.remove(logfile) # later, might want to read this file and check for messages
+@pytest.mark.skip
+@pytest.mark.asyncio
@pytest.mark.parametrize(
("dir", "configfile", "logfile", "datafile"),
[
- ("minimal", "minimal.yaml", "testlog", "sample_generator_data.npy"),
("minimal", "minimal_spawn.yaml", "testlog", "sample_generator_data.npy"),
],
)
-async def test_stop_output(dir, configfile, logfile, datafile, setdir, ports):
+async def test_stop_output_spawn(
+ dir, configfile, logfile, datafile, setdir, ports, server_opts
+):
os.chdir(dir)
- control_port, output_port, logging_port = ports
-
- # start server
- server_opts = [
- "improv",
- "server",
- "-c",
- str(control_port),
- "-o",
- str(output_port),
- "-l",
- str(logging_port),
- "-f",
- logfile,
- configfile,
- ]
+ control_port, output_port, logging_port, logging_input_port = ports
with open(logfile, mode="a+") as log:
server = subprocess.Popen(server_opts, stdout=log, stderr=log)
await asyncio.sleep(SERVER_WARMUP)
# initialize client
- app = tui.TUI(control_port, output_port, logging_port)
+ app = tui.TUI(control_port, output_port, logging_port, logging_input_port)
# run client
async with app.run_test() as pilot:
@@ -141,63 +186,3 @@ async def test_stop_output(dir, configfile, logfile, datafile, setdir, ports):
# then remove that file and logile
os.remove(datafile)
os.remove(logfile) # later, might want to read this file and check for messages
-
-
-def test_zmq_ps(ip, unused_tcp_port):
- """Tests if we can set the zmq PUB/SUB socket and send message."""
- port = unused_tcp_port
- LOGGER.info("beginning test")
- act1 = ZmqActor("act1", type="PUB", ip=ip, port=port)
- act2 = ZmqActor("act2", type="SUB", ip=ip, port=port)
- LOGGER.info("ZMQ Actors constructed")
- # Note these sockets must be set up for testing
- # this is not needed for running in improv
- act1.setSendSocket()
- act2.setRecvSocket()
-
- msg = "hello"
- act1.put(msg)
- LOGGER.info("sent message")
- recvmsg = act2.get()
- LOGGER.info("received message")
- assert recvmsg == msg
-
-
-def test_zmq_rr(ip, unused_tcp_port):
- """Tests if we can set the zmq REQ/REP socket and send message."""
- port = unused_tcp_port
- act1 = ZmqActor("act1", "/tmp/store", type="REQ", ip=ip, port=port)
- act2 = ZmqActor("act2", "/tmp/store", type="REP", ip=ip, port=port)
- msg = "hello"
- reply = "world"
-
- def handle_request():
- return act1.put(msg)
-
- def handle_reply():
- return act2.get(reply)
-
- # Use a ThreadPoolExecutor to run handle_request()
- # and handle_reply() in separate threads.
-
- with concurrent.futures.ThreadPoolExecutor() as executor:
- future1 = executor.submit(handle_request)
- future2 = executor.submit(handle_reply)
-
- # Ensure the request is sent before the reply.
- request_result = future1.result()
- reply_result = future2.result()
-
- # Check if the received message is equal to the original message.
- assert reply_result == msg
- # Check if the reply is correct.
- assert request_result == reply
-
-
-def test_zmq_rr_timeout(ip, unused_tcp_port):
- """Test for requestMsg where we timeout or fail to send"""
- port = unused_tcp_port
- act1 = ZmqActor("act1", "/tmp/store", type="REQ", ip=ip, port=port)
- msg = "hello"
- replymsg = act1.put(msg)
- assert replymsg is None
diff --git a/test/test_harvester.py b/test/test_harvester.py
new file mode 100644
index 00000000..ceec8ef6
--- /dev/null
+++ b/test/test_harvester.py
@@ -0,0 +1,222 @@
+import logging
+import signal
+import time
+
+import pytest
+import zmq
+from zmq.backend.cython._zmq import SocketOption
+
+from improv.harvester import RedisHarvester
+from improv.store import RedisStoreInterface
+
+from improv.link import ZmqLink
+from improv.messaging import HarvesterInfoReplyMsg
+from conftest import SignalManager
+
+
+def test_harvester_shuts_down_on_sigint(setup_store, harvester):
+ harvester_ports, broker_socket, p = harvester
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.REP)
+ s.bind(f"tcp://*:{harvester_ports[0]}")
+ s.recv_pyobj()
+ reply = HarvesterInfoReplyMsg("harvester", "OK", "")
+ s.send_pyobj(reply)
+ time.sleep(2)
+ p.terminate()
+ p.join(5)
+ if p.exitcode is None:
+ p.kill()
+ pytest.fail("Harvester did not exit in time")
+ else:
+ assert True
+
+ s.close(linger=0)
+ ctx.destroy(linger=0)
+
+
+def test_harvester_relieves_memory_pressure(setup_store, harvester):
+ store_interface = RedisStoreInterface()
+ harvester_ports, broker_socket, p = harvester
+ broker_link = ZmqLink(broker_socket, "test", "test topic")
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.REP)
+ s.bind(f"tcp://*:{harvester_ports[0]}")
+ s.recv_pyobj()
+ reply = HarvesterInfoReplyMsg("harvester", "OK", "")
+ s.send_pyobj(reply)
+ client = RedisStoreInterface()
+ for i in range(10):
+ message = [i for i in range(500000)]
+ key = client.put(message)
+ broker_link.put(key)
+ time.sleep(2)
+
+ db_info = store_interface.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+ assert used_max_ratio <= 0.75
+
+ p.terminate()
+ p.join(5)
+ if p.exitcode is None:
+ p.kill()
+ pytest.fail("Harvester did not exit in time")
+
+ s.close(linger=0)
+ ctx.destroy(linger=0)
+
+
+def test_harvester_stop_logs_and_halts_running():
+ with SignalManager():
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.PULL)
+ s.bind("tcp://*:0")
+ pull_port_string = s.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ pull_port = int(pull_port_string.split(":")[-1])
+ harvester = RedisHarvester(
+ nexus_hostname="localhost",
+ nexus_comm_port=10000, # never gets called in this test
+ redis_hostname="localhost",
+ redis_port=10000, # never gets called in this test
+ broker_hostname="localhost",
+ broker_port=10000, # never gets called in this test
+ logger_hostname="localhost",
+ logger_port=pull_port,
+ )
+ harvester.stop(signal.SIGINT, None)
+ assert not harvester.running
+ msg_available = s.poll(timeout=1000)
+ assert msg_available
+ record = s.recv_json()
+ assert (
+ record["message"]
+ == f"Harvester shutting down due to signal {signal.SIGINT}"
+ )
+
+ s.close(linger=0)
+ ctx.destroy(linger=0)
+
+
+# @pytest.mark.skip
+def test_harvester_relieves_memory_pressure_one_loop(ports, setup_store):
+ def harvest_and_quit(harvester_instance: RedisHarvester):
+ harvester_instance.collect()
+ harvester_instance.stop(signal.SIGINT, None)
+
+ with SignalManager():
+ try:
+ ctx = zmq.Context()
+ nex_s = ctx.socket(zmq.REP)
+ nex_s.bind(f"tcp://*:{ports[0]}")
+
+ broker_s = ctx.socket(zmq.PUB)
+ broker_s.bind("tcp://*:1234")
+ broker_link = ZmqLink(broker_s, "test", "test topic")
+
+ log_s = ctx.socket(zmq.PULL)
+ log_s.bind("tcp://*:0")
+ pull_port_string = log_s.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ pull_port = int(pull_port_string.split(":")[-1])
+
+ harvester = RedisHarvester(
+ nexus_hostname="localhost",
+ nexus_comm_port=ports[0], # never gets called in this test
+ redis_hostname="localhost",
+ redis_port=6379, # never gets called in this test
+ broker_hostname="localhost",
+ broker_port=1234, # never gets called in this test
+ logger_hostname="localhost",
+ logger_port=pull_port,
+ )
+
+ harvester.establish_connections()
+
+ client = RedisStoreInterface()
+ for i in range(9):
+ message = [i for i in range(500000)]
+ try:
+ key = client.put(message)
+ broker_link.put(key)
+ except Exception as e:
+ logging.warning(e)
+ logging.warning("Proceeding under the assumption Redis is full.")
+ break
+ time.sleep(2)
+
+ harvester.serve(harvest_and_quit, harvester_instance=harvester)
+
+ db_info = client.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+
+ assert used_max_ratio <= 0.5
+ assert not harvester.running
+ assert harvester.nexus_socket.closed
+ assert harvester.sub_socket.closed
+ assert harvester.zmq_context.closed
+ logging.info("Passed harvester one loop test")
+
+ except Exception as e:
+ logging.error(f"Encountered exception {e} during execution of test")
+ print(e)
+ pass
+
+ try:
+ nex_s.close(linger=0)
+ broker_s.close(linger=0)
+ log_s.close(linger=0)
+ ctx.destroy(linger=0)
+ except Exception as e:
+ logging.error(f"Encountered exception {e} during teardown")
+ print(e)
+ pass
+
+
+def test_harvester_loops_with_no_memory_pressure(ports, setup_store):
+ def harvest_and_quit(harvester_instance: RedisHarvester):
+ harvester_instance.collect()
+ harvester_instance.stop(signal.SIGINT, None)
+
+ with SignalManager():
+ ctx = zmq.Context()
+ nex_s = ctx.socket(zmq.REP)
+ nex_s.bind(f"tcp://*:{ports[0]}")
+
+ log_s = ctx.socket(zmq.PULL)
+ log_s.bind("tcp://*:0")
+ pull_port_string = log_s.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ pull_port = int(pull_port_string.split(":")[-1])
+
+ harvester = RedisHarvester(
+ nexus_hostname="localhost",
+ nexus_comm_port=ports[0], # never gets called in this test
+ redis_hostname="localhost",
+ redis_port=6379, # never gets called in this test
+ broker_hostname="localhost",
+ broker_port=1234, # never gets called in this test
+ logger_hostname="localhost",
+ logger_port=pull_port,
+ )
+
+ harvester.establish_connections()
+
+ client = RedisStoreInterface()
+
+ harvester.serve(harvest_and_quit, harvester_instance=harvester)
+
+ db_info = client.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+ assert used_max_ratio <= 0.5
+ assert not harvester.running
+ assert harvester.nexus_socket.closed
+ assert harvester.sub_socket.closed
+ assert harvester.zmq_context.closed
+
+ nex_s.close(linger=0)
+ log_s.close(linger=0)
+ ctx.destroy(linger=0)
diff --git a/test/test_link.py b/test/test_link.py
index eec9a521..55fc02f8 100644
--- a/test/test_link.py
+++ b/test/test_link.py
@@ -1,170 +1,137 @@
-import asyncio
-import queue
-import subprocess
+import json
import time
import pytest
+import zmq
+from improv.actor import ManagedActor
+from zmq import SocketOption
-from improv.actor import Actor
+from improv.link import ZmqLink
-from improv.link import Link
+@pytest.fixture
+def test_sub_link():
+ """Fixture to provide a commonly used Link object."""
+ ctx = zmq.Context()
+ link_socket = ctx.socket(zmq.SUB)
+ link_socket.bind("tcp://*:0")
+ link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ link_socket_port = int(link_socket_string.split(":")[-1])
-def init_actors(n=1):
- """Function to return n unique actors.
-
- Returns:
- list: A list of n actors, each being named after its index.
- """
+ link_socket.poll(timeout=0)
- # the links must be specified as an empty dictionary to avoid
- # actors sharing a dictionary of links
+ link_pub_socket = ctx.socket(zmq.PUB)
+ link_pub_socket.connect(f"tcp://localhost:{link_socket_port}")
- return [Actor("test " + str(i), "/tmp/store", links={}) for i in range(n)]
+ link_socket.poll(timeout=0) # open libzmq bug (not pyzmq)
+ link_socket.subscribe("test_topic")
+ time.sleep(0.5)
+ link_socket.poll(timeout=0)
+ link = ZmqLink(link_socket, "test_link", "test_topic")
-@pytest.fixture
-def example_link(setup_store):
- """Fixture to provide a commonly used Link object."""
- act = init_actors(2)
- lnk = Link("Example", act[0].name, act[1].name)
- yield lnk
- lnk = None
+ yield link, link_pub_socket
+ link.socket.close(linger=0)
+ link_pub_socket.close(linger=0)
+ ctx.destroy(linger=0)
@pytest.fixture
-def example_actor_system(setup_store):
- """Fixture to provide a list of 4 connected actors."""
-
- # store = setup_store
- acts = init_actors(4)
-
- L01 = Link("L01", acts[0].name, acts[1].name)
- L13 = Link("L13", acts[1].name, acts[3].name)
- L12 = Link("L12", acts[1].name, acts[2].name)
- L23 = Link("L23", acts[2].name, acts[3].name)
-
- links = [L01, L13, L12, L23]
+def test_pub_link():
+ """Fixture to provide a commonly used Link object."""
+ ctx = zmq.Context()
+ link_socket = ctx.socket(zmq.PUB)
+ link_socket.bind("tcp://*:0")
+ link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ link_socket_port = int(link_socket_string.split(":")[-1])
- acts[0].addLink("q_out_1", L01)
- acts[1].addLink("q_out_1", L13)
- acts[1].addLink("q_out_2", L12)
- acts[2].addLink("q_out_1", L23)
+ link_sub_socket = ctx.socket(zmq.SUB)
+ link_sub_socket.connect(f"tcp://localhost:{link_socket_port}")
+ link_sub_socket.poll(timeout=0)
+ link_sub_socket.subscribe("test_topic")
+ time.sleep(0.5)
+ link_sub_socket.poll(timeout=0)
- acts[1].addLink("q_in_1", L01)
- acts[2].addLink("q_in_1", L12)
- acts[3].addLink("q_in_1", L13)
- acts[3].addLink("q_in_2", L23)
+ link = ZmqLink(link_socket, "test_link", "test_topic")
- yield [acts, links] # also yield Links
- acts = None
+ yield link, link_sub_socket
+ link.socket.close(linger=0)
+ link_sub_socket.close(linger=0)
+ ctx.destroy(linger=0)
@pytest.fixture
-def _kill_pytest_processes():
- """Kills all processes with "pytest" in their name.
+def test_req_link():
+ """Fixture to provide a commonly used Link object."""
+ ctx = zmq.Context()
+ link_socket = ctx.socket(zmq.REQ)
+ link_socket.bind("tcp://*:0")
+ link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ link_socket_port = int(link_socket_string.split(":")[-1])
- NOTE:
- This fixture should only be used at the end of testing.
- """
+ link_rep_socket = ctx.socket(zmq.REP)
+ link_rep_socket.connect(f"tcp://localhost:{link_socket_port}")
- subprocess.Popen(
- ["kill", "`pgrep pytest`"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
- )
+ link = ZmqLink(link_socket, "test_link")
-
-@pytest.mark.parametrize(
- ("attribute", "expected"),
- [
- ("name", "Example"),
- ("real_executor", None),
- ("cancelled_join", False),
- ("status", "pending"),
- ("result", None),
- ],
-)
-def test_Link_init(setup_store, example_link, attribute, expected):
- """Tests if the default initialization attributes are set."""
-
- lnk = example_link
- atr = getattr(lnk, attribute)
- assert atr == expected
+ yield link, link_rep_socket
+ link.socket.close(linger=0)
+ link_rep_socket.close(linger=0)
+ ctx.destroy(linger=0)
-def test_Link_init_start_end(setup_store):
- """Tests if the initialization has the right actors."""
+@pytest.fixture
+def test_rep_link():
+ """Fixture to provide a commonly used Link object."""
+ ctx = zmq.Context()
+ link_socket = ctx.socket(zmq.REP)
+ link_socket.bind("tcp://*:0")
+ link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ link_socket_port = int(link_socket_string.split(":")[-1])
- act = init_actors(2)
- lnk = Link("example_link", act[0].name, act[1].name)
+ link_req_socket = ctx.socket(zmq.REQ)
+ link_req_socket.connect(f"tcp://localhost:{link_socket_port}")
- assert lnk.start == act[0].name
- assert lnk.end == act[1].name
+ link = ZmqLink(link_socket, "test_link")
+ yield link, link_req_socket
+ link.socket.close(linger=0)
+ link_req_socket.close(linger=0)
+ ctx.destroy(linger=0)
-def test_getstate(example_link):
- """Tests if __getstate__ has the right values on initialization.
- Gets the dictionary of the link, then compares them against known
- default values. Does not compare store and actors.
+def test_pub_put(test_pub_link):
+ """Tests if messages can be put into the link.
TODO:
- Compare store and actors.
+ Parametrize multiple test input types.
"""
- res = example_link.__getstate__()
- errors = []
- errors.append(res["_real_executor"] is None)
- errors.append(res["cancelled_join"] is False)
-
- assert all(errors)
-
-
-@pytest.mark.parametrize(
- "input",
- [([None]), ([1]), ([i for i in range(5)]), ([str(i**i) for i in range(10)])],
-)
-def test_qsize_empty(example_link, input):
- """Tests that the queue has the number of elements in "input"."""
-
- lnk = example_link
- for i in input:
- lnk.put(i)
-
- qsize = lnk.queue.qsize()
- assert qsize == len(input)
-
-
-def test_getStart(example_link):
- """Tests if getStart returns the starting actor."""
-
- lnk = example_link
-
- assert lnk.getStart() == Actor("test 0", "/tmp/store").name
-
-
-def test_getEnd(example_link):
- """Tests if getEnd returns the ending actor."""
-
- lnk = example_link
+ link, link_sub_socket = test_pub_link
+ msg = "message"
- assert lnk.getEnd() == Actor("test 1", "/tmp/store").name
+ link.put(msg)
+ res = link_sub_socket.recv_multipart()
+ assert res[0].decode("utf-8") == "test_topic"
+ assert json.loads(res[1].decode("utf-8")) == msg
-def test_put(example_link):
+def test_req_put(test_req_link):
"""Tests if messages can be put into the link.
TODO:
Parametrize multiple test input types.
"""
- lnk = example_link
+ link, link_rep_socket = test_req_link
msg = "message"
- lnk.put(msg)
- assert lnk.get() == "message"
+ link.put(msg)
+ res = link_rep_socket.recv_pyobj()
+ assert res == msg
-def test_put_unserializable(example_link, caplog, setup_store):
+def test_put_unserializable(test_pub_link):
"""Tests if an unserializable object raises an error.
Instantiates an actor, which is unserializable, and passes it into
@@ -173,185 +140,199 @@ def test_put_unserializable(example_link, caplog, setup_store):
Raises:
SerializationCallbackError: Actor objects are unserializable.
"""
- # store = setup_store
- act = Actor("test", "/tmp/store")
- lnk = example_link
+ act = ManagedActor("test", "/tmp/store")
+ link, link_sub_socket = test_pub_link
sentinel = True
try:
- lnk.put(act)
+ link.put(act)
except Exception:
sentinel = False
- assert sentinel, "Unable to put"
- assert str(lnk.get()) == str(act)
+ assert not sentinel
-def test_put_irreducible(example_link, setup_store):
+def test_put_irreducible(test_pub_link, setup_store):
"""Tests if an irreducible object raises an error."""
- lnk = example_link
+ link, link_sub_socket = test_pub_link
store = setup_store
with pytest.raises(TypeError):
- lnk.put(store)
+ link.put(store)
-def test_put_nowait(example_link):
+def test_put_nowait(test_pub_link):
"""Tests if messages can be put into the link without blocking.
TODO:
Parametrize multiple test input types.
"""
- lnk = example_link
+ link, link_sub_socket = test_pub_link
msg = "message"
t_0 = time.perf_counter()
- lnk.put_nowait(msg)
+ link.put(msg) # put is already async even in synchronous zmq
t_1 = time.perf_counter()
t_net = t_1 - t_0
assert t_net < 0.005 # 5 ms
-@pytest.mark.asyncio
-async def test_put_async_success(example_link):
- """Tests if put_async returns None.
-
- TODO:
- Parametrize test input.
- """
-
- lnk = example_link
- msg = "message"
- res = await lnk.put_async(msg)
- assert res is None
-
-
-@pytest.mark.asyncio
-async def test_put_async_multiple(example_link):
+def test_put_multiple(test_pub_link):
"""Tests if async putting multiple objects preserves their order."""
+ link, link_sub_socket = test_pub_link
+
messages = [str(i) for i in range(10)]
messages_out = []
for msg in messages:
- await example_link.put_async(msg)
+ link.put(msg)
for msg in messages:
- messages_out.append(example_link.get())
+ messages_out.append(
+ json.loads(link_sub_socket.recv_multipart()[1].decode("utf-8"))
+ )
assert messages_out == messages
@pytest.mark.asyncio
-async def test_put_and_get_async(example_link):
+async def test_put_and_get_async(test_pub_link):
"""Tests if async get preserves order after async put."""
messages = [str(i) for i in range(10)]
messages_out = []
+ link, link_sub_socket = test_pub_link
+
for msg in messages:
- await example_link.put_async(msg)
+ await link.put_async(msg)
for msg in messages:
- messages_out.append(await example_link.get_async())
+ messages_out.append(
+ json.loads(link_sub_socket.recv_multipart()[1].decode("utf-8"))
+ )
assert messages_out == messages
-@pytest.mark.skip(
- reason="This test needs additional work to cause an overflow in the datastore."
+@pytest.mark.parametrize(
+ "message",
+ [
+ "message",
+ "",
+ None,
+ [str(i) for i in range(5)],
+ ],
)
-def test_put_overflow(setup_store, server_port_num, caplog):
- """Tests if putting too large of an object raises an error."""
-
- p = subprocess.Popen(
- ["redis-server", "--port", str(server_port_num), "--maxmemory", str(1000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
-
- acts = init_actors(2)
- lnk = Link("L1", acts[0], acts[1])
-
- message = [i for i in range(10**6)] # 24000 bytes
+@pytest.mark.parametrize(
+ "timeout",
+ [
+ None,
+ 5,
+ ],
+)
+def test_sub_get(test_sub_link, message, timeout):
+ """Tests if get gets the correct element from the queue."""
- lnk.put(message)
+ link, link_pub_socket = test_sub_link
- p.kill()
- p.wait()
+ time.sleep(1)
- if caplog.records:
- for record in caplog.records:
- if "PlasmaStoreInterfaceFull" in record.msg:
- assert True
+ if type(message) is list:
+ for i in message:
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(i).encode("utf-8")]
+ )
+ expected = message[0]
else:
- pytest.fail("expected an error!")
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(message).encode("utf-8")]
+ )
+ expected = message
+
+ assert link.get(timeout=timeout) == expected
@pytest.mark.parametrize(
"message",
[
- ("message"),
- (""),
- (None),
- ([str(i) for i in range(5)]),
+ "message",
+ "",
+ None,
+ [str(i) for i in range(5)],
+ ],
+)
+@pytest.mark.parametrize(
+ "timeout",
+ [
+ None,
+ 5,
],
)
-def test_get(example_link, message):
+def test_rep_get(test_rep_link, message, timeout):
"""Tests if get gets the correct element from the queue."""
- lnk = example_link
+ link, link_req_socket = test_rep_link
- if type(message) is list:
- for i in message:
- lnk.put(i)
- expected = message[0]
- else:
- lnk.put(message)
- expected = message
+ link_req_socket.send_pyobj(message)
+
+ expected = message
- assert lnk.get() == expected
+ assert link.get(timeout=timeout) == expected
-def test_get_empty(example_link):
+def test_get_empty(test_sub_link):
"""Tests if get blocks if the queue is empty."""
- lnk = example_link
- if lnk.queue.empty:
- with pytest.raises(queue.Empty):
- lnk.get(timeout=5.0)
- else:
- pytest.fail("expected a timeout!")
+ link, unused = test_sub_link
+
+ time.sleep(0.1)
+
+ with pytest.raises(TimeoutError):
+ link.get(timeout=0.5)
@pytest.mark.parametrize(
"message",
[
- ("message"),
- (""),
+ "message",
+ "",
([str(i) for i in range(5)]),
],
)
-def test_get_nowait(example_link, message):
+def test_get_nowait(test_sub_link, message):
"""Tests if get_nowait gets the correct element from the queue."""
- lnk = example_link
+ link, link_pub_socket = test_sub_link
if type(message) is list:
for i in message:
- lnk.put(i)
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(i).encode("utf-8")]
+ )
expected = message[0]
else:
- lnk.put(message)
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(message).encode("utf-8")]
+ )
expected = message
+ for i in range(20):
+ time.sleep(0.1)
+ avail = link.socket.poll(timeout=100)
+ if avail:
+ break
+ else:
+ pytest.fail("Message was not sent to link after 4s")
+
t_0 = time.perf_counter()
- res = lnk.get_nowait()
+ res = link.get_nowait()
t_1 = time.perf_counter()
@@ -359,95 +340,31 @@ def test_get_nowait(example_link, message):
assert t_1 - t_0 < 0.005 # 5 msg
-def test_get_nowait_empty(example_link):
+def test_get_nowait_empty(test_sub_link):
"""Tests if get_nowait raises an error when the queue is empty."""
- lnk = example_link
- if lnk.queue.empty():
- with pytest.raises(queue.Empty):
- lnk.get_nowait()
- else:
- pytest.fail("the queue is not empty")
+ link, unused = test_sub_link
+ with pytest.raises(TimeoutError):
+ link.get_nowait()
@pytest.mark.asyncio
-async def test_get_async_success(example_link):
+async def test_get_async_success(test_sub_link):
"""Tests if async_get gets the correct element from the queue."""
- lnk = example_link
+ link, link_pub_socket = test_sub_link
msg = "message"
- await lnk.put_async(msg)
- res = await lnk.get_async()
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(msg).encode("utf-8")]
+ )
+ res = await link.get_async()
assert res == "message"
-@pytest.mark.asyncio
-async def test_get_async_empty(example_link):
- """Tests if get_async times out given an empty queue.
-
- TODO:
- Implement a way to kill the task after execution (subprocess)?
- """
-
- lnk = example_link
- timeout = 5.0
-
- with pytest.raises(asyncio.TimeoutError):
- task = asyncio.create_task(lnk.get_async())
- await asyncio.wait_for(task, timeout)
- task.cancel()
-
- lnk.put("exit") # this is here to break out of get_async()
-
-
-@pytest.mark.skip(reason="unfinished")
-def test_cancel_join_thread(example_link):
- """Tests cancel_join_thread. This test is unfinished
-
- TODO:
- Identify where and when cancel_join_thread is being called.
- """
-
- lnk = example_link
- lnk.cancel_join_thread()
-
- assert lnk._cancelled_join is True
-
-
-@pytest.mark.skip(reason="unfinished")
-@pytest.mark.asyncio
-async def test_join_thread(example_link):
- """Tests join_thread. This test is unfinished
-
- TODO:
- Identify where and when join_thread is being called.
- """
- lnk = example_link
- await lnk.put_async("message")
- # msg = await lnk.get_async()
- lnk.join_thread()
- assert True
-
-
-@pytest.mark.asyncio
-async def test_multi_actor_system(example_actor_system, setup_store):
- """Tests if async puts/gets with many actors have good messages."""
-
- setup_store
-
- graph = example_actor_system
-
- acts = graph[0]
-
- heavy_msg = [str(i) for i in range(10**6)]
- light_msgs = ["message" + str(i) for i in range(3)]
-
- await acts[0].links["q_out_1"].put_async(heavy_msg)
- await acts[1].links["q_out_1"].put_async(light_msgs[0])
- await acts[1].links["q_out_2"].put_async(light_msgs[1])
- await acts[2].links["q_out_1"].put_async(light_msgs[2])
-
- assert await acts[1].links["q_in_1"].get_async() == heavy_msg
- assert await acts[2].links["q_in_1"].get_async() == light_msgs[1]
- assert await acts[3].links["q_in_1"].get_async() == light_msgs[0]
- assert await acts[3].links["q_in_2"].get_async() == light_msgs[2]
+def test_pub_put_no_topic():
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.PUB)
+ with pytest.raises(Exception, match="Cannot open PUB link without topic"):
+ ZmqLink(s, "test")
+ s.close(linger=0)
+ ctx.destroy(linger=0)
diff --git a/test/test_messaging.py b/test/test_messaging.py
new file mode 100644
index 00000000..6b92c866
--- /dev/null
+++ b/test/test_messaging.py
@@ -0,0 +1,127 @@
+import improv.messaging
+
+
+def test_actor_state_msg():
+ name = "test name"
+ status = "test status"
+ port = 12345
+ info = "test info"
+ msg = improv.messaging.ActorStateMsg(name, status, port, info)
+ assert msg.actor_name == name
+ assert msg.status == status
+ assert msg.nexus_in_port == port
+ assert msg.info == info
+
+
+def test_actor_state_reply_msg():
+ name = "test name"
+ status = "test status"
+ info = "test info"
+ msg = improv.messaging.ActorStateReplyMsg(name, status, info)
+ assert msg.actor_name == name
+ assert msg.status == status
+ assert msg.info == info
+
+
+def test_actor_signal_msg():
+ name = "test name"
+ signal = "test signal"
+ info = "test info"
+ msg = improv.messaging.ActorSignalMsg(name, signal, info)
+ assert msg.actor_name == name
+ assert msg.signal == signal
+ assert msg.info == info
+
+
+def test_actor_signal_reply_msg():
+ name = "test name"
+ signal = "test_signal"
+ info = "test info"
+ msg = improv.messaging.ActorSignalReplyMsg(name, signal, info)
+ assert msg.actor_name == name
+ assert msg.info == info
+ assert msg.signal == signal
+
+
+def test_nexus_signal_msg():
+ name = "test name"
+ signal = "test signal"
+ info = "test info"
+ msg = improv.messaging.NexusSignalMsg(name, signal, info)
+ assert msg.actor_name == name
+ assert msg.signal == signal
+ assert msg.info == info
+
+
+def test_nexus_signal_reply_msg():
+ name = "test name"
+ status = "test status"
+ signal = "test_signal"
+ info = "test info"
+ msg = improv.messaging.NexusSignalReplyMsg(name, signal, status, info)
+ assert msg.actor_name == name
+ assert msg.status == status
+ assert msg.info == info
+ assert msg.signal == signal
+
+
+def test_broker_info_msg():
+ name = "test name"
+ pub_port = 54321
+ sub_port = 12345
+ info = "test info"
+ msg = improv.messaging.BrokerInfoMsg(name, pub_port, sub_port, info)
+ assert msg.name == name
+ assert msg.pub_port == pub_port
+ assert msg.sub_port == sub_port
+ assert msg.info == info
+
+
+def test_broker_info_reply_msg():
+ name = "test name"
+ status = "test status"
+ info = "test info"
+ msg = improv.messaging.BrokerInfoReplyMsg(name, status, info)
+ assert msg.name == name
+ assert msg.status == status
+ assert msg.info == info
+
+
+def test_log_info_msg():
+ name = "test name"
+ pull_port = 54321
+ pub_port = 12345
+ info = "test info"
+ msg = improv.messaging.LogInfoMsg(name, pull_port, pub_port, info)
+ assert msg.name == name
+ assert msg.pub_port == pub_port
+ assert msg.pull_port == pull_port
+ assert msg.info == info
+
+
+def test_log_info_reply_msg():
+ name = "test name"
+ status = "test status"
+ info = "test info"
+ msg = improv.messaging.LogInfoReplyMsg(name, status, info)
+ assert msg.name == name
+ assert msg.status == status
+ assert msg.info == info
+
+
+def test_harvester_info_msg():
+ name = "test name"
+ info = "test info"
+ msg = improv.messaging.HarvesterInfoMsg(name, info)
+ assert msg.name == name
+ assert msg.info == info
+
+
+def test_harvester_info_reply_msg():
+ name = "test name"
+ status = "test status"
+ info = "test info"
+ msg = improv.messaging.HarvesterInfoReplyMsg(name, status, info)
+ assert msg.name == name
+ assert msg.status == status
+ assert msg.info == info
diff --git a/test/test_nexus.py b/test/test_nexus.py
index e1a0e5f6..72cafd99 100644
--- a/test/test_nexus.py
+++ b/test/test_nexus.py
@@ -2,77 +2,21 @@
import shutil
import time
import os
+
import pytest
import logging
-import subprocess
-import signal
-import yaml
-
-from improv.nexus import Nexus
-from improv.store import StoreInterface
-
-# from improv.actor import Actor
-# from improv.store import StoreInterface
-
-SERVER_COUNTER = 0
-
-
-@pytest.fixture
-def ports():
- global SERVER_COUNTER
- CONTROL_PORT = 5555
- OUTPUT_PORT = 5556
- LOGGING_PORT = 5557
- yield (
- CONTROL_PORT + SERVER_COUNTER,
- OUTPUT_PORT + SERVER_COUNTER,
- LOGGING_PORT + SERVER_COUNTER,
- )
- SERVER_COUNTER += 3
-
-
-@pytest.fixture
-def setdir():
- prev = os.getcwd()
- os.chdir(os.path.dirname(__file__) + "/configs")
- yield None
- os.chdir(prev)
-
-
-@pytest.fixture
-def sample_nex(setdir, ports):
- nex = Nexus("test")
- nex.createNexus(
- file="good_config.yaml",
- store_size=40000000,
- control_port=ports[0],
- output_port=ports[1],
- )
- yield nex
- nex.destroyNexus()
-
-
-# @pytest.fixture
-# def setup_store(setdir):
-# """ Fixture to set up the store subprocess with 10 mb.
-# This fixture runs a subprocess that instantiates the store with a
-# memory of 10 megabytes. It specifies that "/tmp/store/" is the
-# location of the store socket.
+import zmq
-# Yields:
-# StoreInterface: An instance of the store.
-
-# TODO:
-# Figure out the scope.
-# """
-# setdir
-# p = subprocess.Popen(
-# ['plasma_store', '-s', '/tmp/store/', '-m', str(10000000)],\
-# stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-# store = StoreInterface(store_loc = "/tmp/store/")
-# yield store
-# p.kill()
+from improv.config import CannotCreateConfigException
+from improv.messaging import ActorStateMsg
+from improv.nexus import (
+ Nexus,
+ ConfigFileNotProvidedException,
+ ConfigFileNotValidException,
+)
+from improv.store import StoreInterface
+from conftest import SignalManager
def test_init(setdir):
@@ -85,32 +29,34 @@ def test_init(setdir):
"cfg_name",
[
"good_config.yaml",
- "good_config_plasma.yaml",
],
)
-def test_createNexus(setdir, ports, cfg_name):
+def test_create_nexus(setdir, ports, cfg_name):
nex = Nexus("test")
- nex.createNexus(file=cfg_name, control_port=ports[0], output_port=ports[1])
- assert list(nex.comm_queues.keys()) == [
- "GUI_comm",
- "Acquirer_comm",
- "Analysis_comm",
- ]
- assert list(nex.sig_queues.keys()) == ["Acquirer_sig", "Analysis_sig"]
- assert list(nex.data_queues.keys()) == ["Acquirer.q_out", "Analysis.q_in"]
+ nex.create_nexus(
+ file=cfg_name,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
assert list(nex.actors.keys()) == ["Acquirer", "Analysis"]
assert list(nex.flags.keys()) == ["quit", "run", "load"]
assert nex.processes == []
- nex.destroyNexus()
+ nex.destroy_nexus()
assert True
def test_config_logged(setdir, ports, caplog):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_settings.yaml", control_port=ports[0], output_port=ports[1]
+ nex.create_nexus(
+ file="minimal_with_settings.yaml",
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
)
- nex.destroyNexus()
+ nex.destroy_nexus()
assert any(
[
"not_relevant: for testing purposes" in record.msg
@@ -119,54 +65,52 @@ def test_config_logged(setdir, ports, caplog):
)
-def test_loadConfig(sample_nex):
+def test_load_config(sample_nex):
nex = sample_nex
- nex.loadConfig("good_config.yaml")
- assert set(nex.comm_queues.keys()) == set(
- ["Acquirer_comm", "Analysis_comm", "GUI_comm"]
+ assert any(
+ [
+ link_info.link_name == "q_out"
+ for link_info in nex.outgoing_topics["Acquirer"]
+ ]
+ )
+ assert any(
+ [link_info.link_name == "q_in" for link_info in nex.incoming_topics["Analysis"]]
)
def test_argument_config_precedence(setdir, ports):
nex = Nexus("test")
- nex.createNexus(
+ nex.create_nexus(
file="minimal_with_settings.yaml",
control_port=ports[0],
output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
store_size=11_000_000,
- use_watcher=True,
)
cfg = nex.config.settings
- nex.destroyNexus()
+ nex.destroy_nexus()
assert cfg["control_port"] == ports[0]
assert cfg["output_port"] == ports[1]
- assert cfg["store_size"] == 20_000_000
- assert not cfg["use_watcher"]
+ assert cfg["store_size"] == 11_000_000
-def test_settings_override_random_ports(setdir, ports):
- config_file = "minimal_with_settings.yaml"
- nex = Nexus("test")
- with open(config_file, "r") as ymlfile:
- cfg = yaml.safe_load(ymlfile)["settings"]
- control_port, output_port = nex.createNexus(
- file=config_file, control_port=0, output_port=0
- )
- nex.destroyNexus()
- assert control_port == cfg["control_port"]
- assert output_port == cfg["output_port"]
+# delete this comment later
+@pytest.mark.skip(reason="makes use of parameterized start_nexus")
+def test_start_nexus(sample_nex):
+ with SignalManager():
+ async def set_quit_flag(test_nex):
+ test_nex.flags["quit"] = True
-# delete this comment later
-@pytest.mark.skip(reason="unfinished")
-def test_startNexus(sample_nex):
- nex = sample_nex
- nex.startNexus()
- assert [p.name for p in nex.processes] == ["Acquirer", "Analysis"]
- nex.destroyNexus()
+ nex = sample_nex
+ nex.start_nexus(nex.poll_queues, poll_function=set_quit_flag, test_nex=nex)
+ assert [p.name for p in nex.processes] == ["Acquirer", "Analysis"]
-# @pytest.mark.skip(reason="This test is unfinished")
+@pytest.mark.skip(
+ reason="This test is unfinished - it does not validate link structure"
+)
@pytest.mark.parametrize(
("cfg_name", "actor_list", "link_list"),
[
@@ -197,20 +141,23 @@ def test_config_construction(cfg_name, actor_list, link_list, setdir, ports):
"""
nex = Nexus("test")
- nex.createNexus(file=cfg_name, control_port=ports[0], output_port=ports[1])
+ nex.create_nexus(
+ file=cfg_name,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
logging.info(cfg_name)
# Check for actors
act_lst = list(nex.actors)
- lnk_lst = list(nex.sig_queues)
- nex.destroyNexus()
+ nex.destroy_nexus()
assert actor_list == act_lst
- assert link_list == lnk_lst
act_lst = []
- lnk_lst = []
assert True
@@ -218,165 +165,270 @@ def test_config_construction(cfg_name, actor_list, link_list, setdir, ports):
"cfg_name",
[
"single_actor.yaml",
- "single_actor_plasma.yaml",
],
)
def test_single_actor(setdir, ports, cfg_name):
nex = Nexus("test")
with pytest.raises(AttributeError):
- nex.createNexus(
- file="single_actor.yaml", control_port=ports[0], output_port=ports[1]
+ nex.create_nexus(
+ file="single_actor.yaml",
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
)
- nex.destroyNexus()
+ nex.destroy_nexus()
def test_cyclic_graph(setdir, ports):
nex = Nexus("test")
- nex.createNexus(
- file="cyclic_config.yaml", control_port=ports[0], output_port=ports[1]
+ nex.create_nexus(
+ file="cyclic_config.yaml",
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
)
assert True
- nex.destroyNexus()
+ nex.destroy_nexus()
def test_blank_cfg(setdir, caplog, ports):
nex = Nexus("test")
- with pytest.raises(TypeError):
- nex.createNexus(
- file="blank_file.yaml", control_port=ports[0], output_port=ports[1]
+ with pytest.raises(CannotCreateConfigException):
+ nex.create_nexus(
+ file="blank_file.yaml",
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
)
assert any(
["The config file is empty" in record.msg for record in list(caplog.records)]
)
- nex.destroyNexus()
+ nex.destroy_nexus()
-# def test_hasGUI_True(setdir):
-# setdir
-# nex = Nexus("test")
-# nex.createNexus(file="basic_demo_with_GUI.yaml")
+def test_start_store(caplog):
+ nex = Nexus("test")
+ nex._start_store_interface(100_000_000) # 100 MB store
-# assert True
-# nex.destroyNexus()
+ assert any(
+ "StoreInterface start successful" in record.msg for record in caplog.records
+ )
-# @pytest.mark.skip(reason="This test is unfinished.")
-# def test_hasGUI_False():
-# assert True
+ nex._close_store_interface()
+ nex.destroy_nexus()
+ assert True
-@pytest.mark.skip(reason="unfinished")
-def test_queue_message(setdir, sample_nex):
- nex = sample_nex
- nex.startNexus()
- time.sleep(20)
- nex.setup()
- time.sleep(20)
- nex.run()
- time.sleep(10)
- acq_comm = nex.comm_queues["Acquirer_comm"]
- acq_comm.put("Test Message")
-
- assert nex.comm_queues is None
- nex.destroyNexus()
- assert True
+def test_close_store(caplog):
+ nex = Nexus("test")
+ nex._start_store_interface(10000)
+ nex._close_store_interface()
+ assert any(
+ "StoreInterface close successful" in record.msg for record in caplog.records
+ )
-@pytest.mark.asyncio
-@pytest.mark.skip(reason="This test is unfinished.")
-async def test_queue_readin(sample_nex, caplog):
- nex = sample_nex
- nex.startNexus()
- # cqs = nex.comm_queues
- # assert cqs == None
- assert [record.msg for record in caplog.records] is None
- # cqs["Acquirer_comm"].put('quit')
- # assert "quit" == cqs["Acquirer_comm"].get()
- # await nex.pollQueues()
- assert True
+ # write to store
+ with pytest.raises(AttributeError):
+ nex.p_StoreInterface.put("Message in", "Message in Label")
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_queue_sendout():
+ nex.destroy_nexus()
assert True
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_run_sig():
- assert True
+def test_start_harvester(caplog, setdir, ports):
+ nex = Nexus("test")
+ try:
+ nex.create_nexus(
+ file="minimal_harvester.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
+ time.sleep(3)
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_setup_sig():
- assert True
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
+ assert any("Harvester server started" in record.msg for record in caplog.records)
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_quit_sig():
- assert True
+def test_process_actor_state_update(caplog, setdir, ports):
+ nex = Nexus("test")
+ try:
+ nex.create_nexus(
+ file="minimal_harvester.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_usehdd_True():
- assert True
+ time.sleep(3)
+ new_actor_message = ActorStateMsg("test actor", "waiting", 1234, "test info")
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_usehdd_False():
- assert True
+ nex.process_actor_state_update(new_actor_message)
+ assert "test actor" in nex.actor_states
+ assert nex.actor_states["test actor"].actor_name == new_actor_message.actor_name
+ assert (
+ nex.actor_states["test actor"].nexus_in_port
+ == new_actor_message.nexus_in_port
+ )
+ assert nex.actor_states["test actor"].status == new_actor_message.status
+ update_actor_message = ActorStateMsg("test actor", "waiting", 1234, "test info")
-def test_startstore(caplog):
- nex = Nexus("test")
- nex._startStoreInterface(10000000) # 10 MB store
+ nex.process_actor_state_update(update_actor_message)
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(
- "StoreInterface start successful" in record.msg for record in caplog.records
+ "Received state message from new actor test actor" in record.msg
+ for record in caplog.records
+ )
+ assert any(
+ "Received state message from actor test actor" in record.msg
+ for record in caplog.records
)
- nex._closeStoreInterface()
- nex.destroyNexus()
- assert True
+def test_process_actor_state_update_allows_run(caplog, setdir, ports):
+ nex = Nexus("test")
+ try:
+ nex.create_nexus(
+ file="minimal_harvester.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
+
+ time.sleep(3)
+
+ nex.actor_states["test actor1"] = None
+ nex.actor_states["test actor2"] = None
+
+ actor1_message = ActorStateMsg("test actor1", "ready", 1234, "test info")
-def test_closestore(caplog):
+ nex.process_actor_state_update(actor1_message)
+ assert "test actor1" in nex.actor_states
+ assert nex.actor_states["test actor1"].actor_name == actor1_message.actor_name
+ assert (
+ nex.actor_states["test actor1"].nexus_in_port
+ == actor1_message.nexus_in_port
+ )
+ assert nex.actor_states["test actor1"].status == actor1_message.status
+
+ assert not nex.allowStart
+
+ actor2_message = ActorStateMsg("test actor2", "ready", 5678, "test info2")
+
+ nex.process_actor_state_update(actor2_message)
+ assert "test actor2" in nex.actor_states
+ assert nex.actor_states["test actor2"].actor_name == actor2_message.actor_name
+ assert (
+ nex.actor_states["test actor2"].nexus_in_port
+ == actor2_message.nexus_in_port
+ )
+ assert nex.actor_states["test actor2"].status == actor2_message.status
+
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
+ assert nex.allowStart
+
+
+@pytest.mark.asyncio
+async def test_process_actor_message(caplog, setdir, ports):
nex = Nexus("test")
- nex._startStoreInterface(10000)
- nex._closeStoreInterface()
+ try:
+ nex.create_nexus(
+ file="minimal_harvester.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
- assert any(
- "StoreInterface close successful" in record.msg for record in caplog.records
- )
+ time.sleep(3)
- # write to store
+ nex.actor_states["test actor1"] = None
+ nex.actor_states["test actor2"] = None
- with pytest.raises(AttributeError):
- nex.p_StoreInterface.put("Message in", "Message in Label")
+ actor1_message = ActorStateMsg("test actor1", "ready", 1234, "test info")
- nex.destroyNexus()
- assert True
+ ctx = nex.zmq_context
+ cfg = nex.config.settings
+ s = ctx.socket(zmq.REQ)
+ s.connect(f"tcp://localhost:{cfg['control_port']}")
+
+ s.send_pyobj(actor1_message)
+
+ await nex.process_actor_message()
+
+ nex.process_actor_state_update(actor1_message)
+ assert "test actor1" in nex.actor_states
+ assert nex.actor_states["test actor1"].actor_name == actor1_message.actor_name
+ assert (
+ nex.actor_states["test actor1"].nexus_in_port
+ == actor1_message.nexus_in_port
+ )
+ assert nex.actor_states["test actor1"].status == actor1_message.status
+
+ s.close(linger=0)
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
+
+ assert not nex.allowStart
def test_specified_free_port(caplog, setdir, ports):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_fixed_redis_port.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_fixed_redis_port.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
- store = StoreInterface(server_port_num=6378)
- store.connect_to_server()
- key = store.put("port 6378")
- assert store.get(key) == "port 6378"
+ store = StoreInterface(server_port_num=6378)
+ store.connect_to_server()
+ key = store.put("port 6378")
+ assert store.get(key) == "port 6378"
- assert any(
- "Successfully connected to redis datastore on port 6378" in record.msg
- for record in caplog.records
- )
+ assert any(
+ "Successfully connected to redis datastore on port 6378" in record.msg
+ for record in caplog.records
+ )
+
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(
"StoreInterface start successful on port 6378" in record.msg
@@ -386,32 +438,51 @@ def test_specified_free_port(caplog, setdir, ports):
def test_specified_busy_port(caplog, setdir, ports, setup_store):
nex = Nexus("test")
- with pytest.raises(Exception, match="Could not start Redis on specified port."):
- nex.createNexus(
+ try:
+ nex.create_nexus(
file="minimal_with_fixed_default_redis_port.yaml",
- store_size=10000000,
+ store_size=100_000_000,
control_port=ports[0],
output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
)
- nex.destroyNexus()
+ time.sleep(3)
+
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
+
+ assert any(
+ "Could not connect to port 6379" in record.msg for record in caplog.records
+ )
assert any(
- "Could not start Redis on specified port number." in record.msg
+ "StoreInterface start successful on port 6380" in record.msg
for record in caplog.records
)
def test_unspecified_port_default_free(caplog, setdir, ports):
nex = Nexus("test")
- nex.createNexus(
- file="minimal.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
+
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(
"StoreInterface start successful on port 6379" in record.msg
@@ -421,14 +492,22 @@ def test_unspecified_port_default_free(caplog, setdir, ports):
def test_unspecified_port_default_busy(caplog, setdir, ports, setup_store):
nex = Nexus("test")
- nex.createNexus(
- file="minimal.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
+
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(
"StoreInterface start successful on port 6380" in record.msg
for record in caplog.records
@@ -436,15 +515,27 @@ def test_unspecified_port_default_busy(caplog, setdir, ports, setup_store):
def test_no_aof_dir_by_default(caplog, setdir, ports):
- nex = Nexus("test")
- nex.createNexus(
- file="minimal.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ if "appendonlydir" in os.listdir("."):
+ shutil.rmtree("appendonlydir")
+ else:
+ logging.info("didn't find dbfilename")
- nex.destroyNexus()
+ nex = Nexus("test")
+
+ nex.create_nexus(
+ file="minimal.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
+
+ time.sleep(3)
+
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" not in os.listdir(".")
assert all(["improv_persistence_" not in name for name in os.listdir(".")])
@@ -452,19 +543,25 @@ def test_no_aof_dir_by_default(caplog, setdir, ports):
def test_default_aof_dir_if_none_specified(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_redis_saving.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_redis_saving.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
- store = StoreInterface(server_port_num=server_port_num)
- store.put(1)
+ store = StoreInterface(server_port_num=server_port_num)
+ store.put(1)
- time.sleep(3)
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" in os.listdir(".")
@@ -478,19 +575,25 @@ def test_default_aof_dir_if_none_specified(caplog, setdir, ports, server_port_nu
def test_specify_static_aof_dir(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_custom_aof_dirname.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_custom_aof_dirname.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
- store = StoreInterface(server_port_num=server_port_num)
- store.put(1)
+ store = StoreInterface(server_port_num=server_port_num)
+ store.put(1)
- time.sleep(3)
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "custom_aof_dirname" in os.listdir(".")
@@ -504,19 +607,25 @@ def test_specify_static_aof_dir(caplog, setdir, ports, server_port_num):
def test_use_ephemeral_aof_dir(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_ephemeral_aof_dirname.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_ephemeral_aof_dirname.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
- store = StoreInterface(server_port_num=server_port_num)
- store.put(1)
+ store = StoreInterface(server_port_num=server_port_num)
+ store.put(1)
- time.sleep(3)
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(["improv_persistence_" in name for name in os.listdir(".")])
@@ -527,18 +636,26 @@ def test_use_ephemeral_aof_dir(caplog, setdir, ports, server_port_num):
def test_save_no_schedule(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_no_schedule_saving.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_no_schedule_saving.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
+
+ store = StoreInterface(server_port_num=server_port_num)
- store = StoreInterface(server_port_num=server_port_num)
+ fsync_schedule = store.client.config_get("appendfsync")
- fsync_schedule = store.client.config_get("appendfsync")
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" in os.listdir(".")
shutil.rmtree("appendonlydir")
@@ -548,18 +665,26 @@ def test_save_no_schedule(caplog, setdir, ports, server_port_num):
def test_save_every_second(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_every_second_saving.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_every_second_saving.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
+
+ store = StoreInterface(server_port_num=server_port_num)
- store = StoreInterface(server_port_num=server_port_num)
+ fsync_schedule = store.client.config_get("appendfsync")
- fsync_schedule = store.client.config_get("appendfsync")
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" in os.listdir(".")
shutil.rmtree("appendonlydir")
@@ -569,18 +694,26 @@ def test_save_every_second(caplog, setdir, ports, server_port_num):
def test_save_every_write(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_every_write_saving.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_every_write_saving.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ log_server_pub_port=ports[2],
+ log_server_pull_port=ports[3],
+ )
+
+ store = StoreInterface(server_port_num=server_port_num)
- store = StoreInterface(server_port_num=server_port_num)
+ fsync_schedule = store.client.config_get("appendfsync")
- fsync_schedule = store.client.config_get("appendfsync")
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" in os.listdir(".")
shutil.rmtree("appendonlydir")
@@ -588,65 +721,98 @@ def test_save_every_write(caplog, setdir, ports, server_port_num):
assert fsync_schedule["appendfsync"] == "always"
-@pytest.mark.skip(reason="Nexus no longer deletes files on shutdown. Nothing to test.")
-def test_store_already_deleted_issues_warning(caplog):
- nex = Nexus("test")
- nex._startStoreInterface(10000)
- store_location = nex.store_loc
- StoreInterface(store_loc=nex.store_loc)
- os.remove(nex.store_loc)
- nex.destroyNexus()
- assert any(
- "StoreInterface file {} is already deleted".format(store_location) in record.msg
- for record in caplog.records
- )
+# def test_sigint_exits_cleanly(ports, set_dir_config_parent):
+# server_opts = [
+# "improv",
+# "server",
+# "-c",
+# str(ports[0]),
+# "-o",
+# str(ports[1]),
+# "-f",
+# "global.log",
+# "configs/minimal.yaml",
+# ]
+#
+# env = os.environ.copy()
+# env["PYTHONPATH"] += ":" + os.getcwd()
+#
+# server = subprocess.Popen(
+# server_opts, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env
+# )
+#
+# time.sleep(5)
+#
+# server.send_signal(signal.SIGINT)
+#
+# server.wait(10)
+# assert True
-@pytest.mark.skip(reason="unfinished")
-def test_actor_sub(setdir, capsys, monkeypatch, ports):
- monkeypatch.setattr("improv.nexus.input", lambda: "setup\n")
- cfg_file = "sample_config.yaml"
+def test_nexus_create_nexus_no_cfg_file(ports):
nex = Nexus("test")
+ with pytest.raises(ConfigFileNotProvidedException):
+ nex.create_nexus()
- nex.createNexus(
- file=cfg_file, store_size=4000, control_port=ports[0], output_port=ports[1]
- )
- print("Nexus Created")
-
- nex.startNexus()
- print("Nexus Started")
- # time.sleep(5)
- # print("Printing...")
- # subprocess.Popen(["setup"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
- # time.sleep(2)
- # subprocess.Popen(["run"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
- # time.sleep(5)
- # subprocess.Popen(["quit"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-
- nex.destroyNexus()
- assert True
+#
+# @pytest.mark.skip(reason="Blocking comms so this won't work as-is")
+# def test_nexus_actor_comm_setup(ports, setdir):
+# filename = "minimal_zmq.yaml"
+# nex = Nexus("test")
+# nex.create_nexus(
+# file=filename,
+# store_size=10000000,
+# control_port=ports[0],
+# output_port=ports[1],
+# actor_in_port=ports[2],
+# )
+#
+# actor = nex.actors["Generator"]
+# actor.register_with_nexus()
+#
+# nex.process_actor_message()
+#
+#
+# @pytest.mark.skip(reason="Test isn't meant to be used for coverage")
+# def test_debug_nex(ports, setdir):
+# filename = "minimal_zmq.yaml"
+# conftest.nex_startup(ports, filename)
+#
+#
+# @pytest.mark.skip(reason="Test isn't meant to be used for coverage")
+# def test_nex_cfg(ports, setdir):
+# filename = "minimal_zmq.yaml"
+# nex = Nexus("test")
+# nex.create_nexus(
+# file=filename,
+# store_size=100000000,
+# control_port=ports[0],
+# output_port=ports[1],
+# actor_in_port=ports[2],
+# )
+# nex.start_nexus()
-@pytest.mark.skip(
- reason="skipping to prevent issues with orphaned stores. TODO fix this"
-)
-def test_sigint_exits_cleanly(ports, tmp_path):
- server_opts = [
- "improv",
- "server",
- "-c",
- str(ports[0]),
- "-o",
- str(ports[1]),
- "-f",
- tmp_path / "global.log",
- ]
-
- server = subprocess.Popen(
- server_opts, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
- )
- server.send_signal(signal.SIGINT)
+def test_nexus_bad_config_actor_args(setdir):
+ nex = Nexus("test")
+ # with pytest.raises(ConfigFileNotValidException):
+ try:
+ nex.create_nexus("bad_args.yaml")
+ except ConfigFileNotValidException:
+ assert True
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
+
- server.wait(10)
- assert True
+def test_nexus_no_config_file():
+ nex = Nexus("test")
+ # with pytest.raises(ConfigFileNotProvidedException):
+ try:
+ nex.create_nexus()
+ except ConfigFileNotProvidedException:
+ assert True
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
diff --git a/test/test_store_with_errors.py b/test/test_store_with_errors.py
index b95781a0..c3f5b0c9 100644
--- a/test/test_store_with_errors.py
+++ b/test/test_store_with_errors.py
@@ -1,9 +1,7 @@
import pytest
-from pyarrow import plasma
-from improv.store import StoreInterface, RedisStoreInterface, PlasmaStoreInterface
+from improv.store import StoreInterface, RedisStoreInterface, ObjectNotFoundError
-from pyarrow._plasma import PlasmaObjectExists
from scipy.sparse import csc_matrix
import numpy as np
import redis
@@ -11,97 +9,34 @@
from improv.store import CannotConnectToStoreInterfaceError
-WAIT_TIMEOUT = 10
-
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-# TODO: add docstrings!!!
-# TODO: clean up syntax - consistent capitalization, function names, etc.
-# TODO: decide to keep classes
-# TODO: increase coverage!!! SEE store.py
-
-# Separate each class as individual file - individual tests???
-
-
def test_connect(setup_store, server_port_num):
store = StoreInterface(server_port_num=server_port_num)
assert isinstance(store.client, redis.Redis)
-def test_plasma_connect(setup_plasma_store, set_store_loc):
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- assert isinstance(store.client, plasma.PlasmaClient)
-
-
def test_redis_connect(setup_store, server_port_num):
store = RedisStoreInterface(server_port_num=server_port_num)
- assert isinstance(store.client, redis.Redis)
assert store.client.ping()
-def test_connect_incorrect_path(setup_plasma_store, set_store_loc):
- # TODO: shorter name???
- # TODO: passes, but refactor --- see comments
- store_loc = "asdf"
- # Handle exception thrown - assert name == 'CannotConnectToStoreInterfaceError'
- # and message == 'Cannot connect to store at {}'.format(str(store_loc))
- # with pytest.raises(Exception, match='CannotConnectToStoreInterfaceError') as cm:
- # store.connect_store(store_loc)
- # # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- # raise Exception('Cannot connect to store: {0}'.format(e))
- with pytest.raises(CannotConnectToStoreInterfaceError) as e:
- store = PlasmaStoreInterface(store_loc=store_loc)
- store.connect_store(store_loc)
- # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- assert e.value.message == "Cannot connect to store at {}".format(str(store_loc))
-
-
-def test_redis_connect_wrong_port(setup_store, server_port_num):
+def test_redis_connect_wrong_port(server_port_num):
bad_port_num = 1234
with pytest.raises(CannotConnectToStoreInterfaceError) as e:
RedisStoreInterface(server_port_num=bad_port_num)
assert e.value.message == "Cannot connect to store at {}".format(str(bad_port_num))
-def test_connect_none_path(setup_plasma_store):
- # BUT default should be store_loc = '/tmp/store' if not entered?
- store_loc = None
- # Handle exception thrown - assert name == 'CannotConnectToStoreInterfaceError'
- # and message == 'Cannot connect to store at {}'.format(str(store_loc))
- # with pytest.raises(Exception) as cm:
- # store.connnect_store(store_loc)
- # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- # assert cm.exception.name == 'CannotConnectToStoreInterfaceError'
- # with pytest.raises(Exception, match='CannotConnectToStoreInterfaceError') as cm:
- # store.connect_store(store_loc)
- # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- # raise Exception('Cannot connect to store: {0}'.format(e))
- with pytest.raises(CannotConnectToStoreInterfaceError) as e:
- store = PlasmaStoreInterface(store_loc=store_loc)
- store.connect_store(store_loc)
- # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- assert e.value.message == "Cannot connect to store at {}".format(str(store_loc))
-
-
-# class StoreInterfaceGet(self):
-
-
# TODO: @pytest.parameterize...store.get and store.getID for diff datatypes,
-# pickleable and not, etc.
-# Check raises...CannotGetObjectError (object never stored)
def test_init_empty(setup_store, server_port_num):
store = StoreInterface(server_port_num=server_port_num)
# logger.info(store.client.config_get())
assert store.get_all() == []
-def test_plasma_init_empty(setup_plasma_store, set_store_loc):
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- assert store.get_all() == {}
-
-
def test_is_csc_matrix_and_put(setup_store, server_port_num):
mat = csc_matrix((3, 4), dtype=np.int8)
store = StoreInterface(server_port_num=server_port_num)
@@ -109,21 +44,13 @@ def test_is_csc_matrix_and_put(setup_store, server_port_num):
assert isinstance(store.get(x), csc_matrix)
-def test_plasma_is_csc_matrix_and_put(setup_plasma_store, set_store_loc):
- mat = csc_matrix((3, 4), dtype=np.int8)
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- x = store.put(mat, "matrix")
- assert isinstance(store.getID(x), csc_matrix)
-
-
-@pytest.mark.skip
def test_get_list_and_all(setup_store, server_port_num):
store = StoreInterface(server_port_num=server_port_num)
- # id = store.put(1, "one")
- # id2 = store.put(2, "two")
- # id3 = store.put(3, "three")
- assert [1, 2] == store.getList(["one", "two"])
- assert [1, 2, 3] == store.get_all()
+ id = store.put(1)
+ id2 = store.put(2)
+ store.put(3)
+ assert [1, 2] == store.get_list([id, id2])
+ assert [1, 2, 3] == sorted(store.get_all())
def test_reset(setup_store, server_port_num):
@@ -133,13 +60,6 @@ def test_reset(setup_store, server_port_num):
assert store.get(id) == 1
-def test_plasma_reset(setup_plasma_store, set_store_loc):
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- store.reset()
- id = store.put(1, "one")
- assert store.get(id) == 1
-
-
def test_put_one(setup_store, server_port_num):
store = StoreInterface(server_port_num=server_port_num)
id = store.put(1)
@@ -152,27 +72,10 @@ def test_redis_put_one(setup_store, server_port_num):
assert 1 == store.get(key)
-def test_plasma_put_one(setup_plasma_store, set_store_loc):
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- id = store.put(1, "one")
- assert 1 == store.get(id)
-
-
-@pytest.mark.skip(reason="Error not being raised")
-def test_put_twice(setup_store):
- # store = StoreInterface()
- with pytest.raises(PlasmaObjectExists) as e:
- # id = store.put(2, "two")
- # id2 = store.put(2, "two")
- pass
- # Check that the exception thrown is an PlasmaObjectExists
- assert e.value.message == "Object already exists. Meant to call replace?"
-
-
-def test_getOne(setup_store, server_port_num):
- store = StoreInterface(server_port_num=server_port_num)
- id = store.put(1)
- assert 1 == store.get(id)
+def test_redis_get_unknown_object(setup_store, server_port_num):
+ store = RedisStoreInterface(server_port_num=server_port_num)
+ with pytest.raises(ObjectNotFoundError):
+ store.get("unknown")
def test_redis_get_one(setup_store, server_port_num):
diff --git a/test/test_tui.py b/test/test_tui.py
index d45bccfe..38d8174d 100644
--- a/test/test_tui.py
+++ b/test/test_tui.py
@@ -1,13 +1,12 @@
import pytest
import time
import improv.tui as tui
+from improv.messaging import ActorSignalReplyMsg
import logging
import zmq.asyncio as zmq
from zmq import PUB, REP
from zmq.log.handlers import PUBHandler
-from test_nexus import ports
-
@pytest.fixture
def logger(ports):
@@ -17,6 +16,7 @@ def logger(ports):
logger.addHandler(zmq_log_handler)
yield logger
logger.removeHandler(zmq_log_handler)
+ zmq_log_handler.close()
@pytest.fixture
@@ -31,17 +31,17 @@ async def sockets(ports):
@pytest.fixture
async def app(ports):
- mock = tui.TUI(*ports)
+ mock = tui.TUI(*ports, testing=True)
yield mock
time.sleep(0.5)
-async def test_console_panel_receives_broadcast(app, sockets, logger):
+async def test_console_panel_receives_broadcast(app, sockets):
async with app.run_test() as pilot:
await sockets[1].send_string("received")
await pilot.pause(0.1)
console = pilot.app.get_widget_by_id("console")
- console.history[0] == "received"
+ assert console.history[0] == "received"
async def test_quit_from_socket(app, sockets):
@@ -59,10 +59,14 @@ async def test_log_panel_receives_logging(app, logger):
assert "test" in log_window.history[0]
-async def test_input_box_echoed_to_console(app):
+async def test_input_box_echoed_to_console(app, sockets):
async with app.run_test() as pilot:
await pilot.press(*"foo", "enter")
+ request = await sockets[0].recv_pyobj()
+ reply_obj = ActorSignalReplyMsg("TUI", "OK", "foo")
+ await sockets[0].send_pyobj(reply_obj)
console = pilot.app.get_widget_by_id("console")
+ assert request.signal == "foo"
assert console.history[0] == "foo"