diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml
index 0c80283d..2d5fb25b 100644
--- a/.github/workflows/CI.yaml
+++ b/.github/workflows/CI.yaml
@@ -20,7 +20,7 @@ jobs:
max-parallel: 1
fail-fast: false
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-latest, macos-latest] # [ubuntu-latest, macos-latest, windows-latest]
steps:
@@ -65,7 +65,7 @@ jobs:
- name: Test with pytest
run: |
- python -m pytest --cov=improv
+ python -m pytest -x -s -l --cov=improv
- name: Coveralls
uses: coverallsapp/github-action@v2
diff --git a/.gitignore b/.gitignore
index b76c5666..bedf9d58 100644
--- a/.gitignore
+++ b/.gitignore
@@ -128,6 +128,9 @@ dmypy.json
arrow
.vscode/
+.idea/
*.code-workspace
improv/_version.py
+
+*venv
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 00000000..13566b81
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 00000000..105ce2da
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 00000000..69d3d2cc
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 00000000..35eb1ddf
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/demos/1p_caiman/1p_demo.yaml b/demos/1p_caiman/1p_demo.yaml
index 6f754860..e60e469a 100644
--- a/demos/1p_caiman/1p_demo.yaml
+++ b/demos/1p_caiman/1p_demo.yaml
@@ -34,6 +34,3 @@ connections:
Processor.q_out: [Analysis.q_in]
Analysis.q_out: [Visual.q_in]
InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
\ No newline at end of file
diff --git a/demos/basic/Behavior_demo.py b/demos/basic/Behavior_demo.py
index 8117189c..4a26aea4 100644
--- a/demos/basic/Behavior_demo.py
+++ b/demos/basic/Behavior_demo.py
@@ -7,7 +7,7 @@
loadFile = "./Behavior_demo.yaml"
nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
+nexus.create_nexus(file=loadFile)
# All modules needed have been imported
# so we can change the level of logging here
@@ -20,4 +20,4 @@
# logger = logging.getLogger("improv")
# logger.setLevel(logging.INFO)
-nexus.startNexus()
+nexus.start_nexus()
diff --git a/demos/basic/Behavior_demo.yaml b/demos/basic/Behavior_demo.yaml
index 8f105507..992f1528 100644
--- a/demos/basic/Behavior_demo.yaml
+++ b/demos/basic/Behavior_demo.yaml
@@ -1,6 +1,3 @@
-settings:
- use_watcher: [Acquirer, Processor, Visual, Analysis, Behavior, Motion]
-
actors:
GUI:
package: actors.visual
diff --git a/demos/basic/basic_demo.py b/demos/basic/basic_demo.py
index 953948ef..971c42fd 100644
--- a/demos/basic/basic_demo.py
+++ b/demos/basic/basic_demo.py
@@ -7,7 +7,7 @@
loadFile = "./basic_demo.yaml"
nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
+nexus.create_nexus(file=loadFile)
# All modules needed have been imported
# so we can change the level of logging here
@@ -20,4 +20,4 @@
# logger = logging.getLogger("improv")
# logger.setLevel(logging.INFO)
-nexus.startNexus()
+nexus.start_nexus()
diff --git a/demos/basic/basic_demo.yaml b/demos/basic/basic_demo.yaml
index b007147c..1736bd64 100644
--- a/demos/basic/basic_demo.yaml
+++ b/demos/basic/basic_demo.yaml
@@ -33,6 +33,3 @@ connections:
Processor.q_out: [Analysis.q_in]
Analysis.q_out: [Visual.q_in]
InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
\ No newline at end of file
diff --git a/demos/bubblewrap/actors/bubble.py b/demos/bubblewrap/actors/bubble.py
index c63c1e3c..5f4a33ea 100644
--- a/demos/bubblewrap/actors/bubble.py
+++ b/demos/bubblewrap/actors/bubble.py
@@ -55,7 +55,7 @@ def setup(self):
self.bw.init_nodes()
logger.info("Nodes initialized")
- self._getStoreInterface()
+ self._get_store_interface()
def runStep(self):
"""Observe new data from dim reduction and update bubblewrap"""
diff --git a/demos/live/live_demo.py b/demos/live/live_demo.py
index fff6839f..52f39fb8 100644
--- a/demos/live/live_demo.py
+++ b/demos/live/live_demo.py
@@ -7,7 +7,7 @@
loadFile = "./live_demo.yaml"
nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
+nexus.create_nexus(file=loadFile)
# All modules needed have been imported
# so we can change the level of logging here
@@ -20,4 +20,4 @@
# logger = logging.getLogger("improv")
# logger.setLevel(logging.INFO)
-nexus.startNexus()
+nexus.start_nexus()
diff --git a/demos/zmq/actors/zmq_ps_sample_generator.py b/demos/minimal/actors/sample_generator_zmq.py
similarity index 54%
rename from demos/zmq/actors/zmq_ps_sample_generator.py
rename to demos/minimal/actors/sample_generator_zmq.py
index 42b0afc4..585423f9 100644
--- a/demos/zmq/actors/zmq_ps_sample_generator.py
+++ b/demos/minimal/actors/sample_generator_zmq.py
@@ -1,15 +1,14 @@
+from improv.actor import ZmqActor
+from datetime import date # used for saving
import numpy as np
import logging
-from demos.sample_actors.zmqActor import ZmqActor
-
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Generator(ZmqActor):
- """Sample actor to generate data to pass into a sample processor
- using sync ZMQ to communicate.
+ """Sample actor to generate data to pass into a sample processor.
Intended for use along with sample_processor.py.
"""
@@ -25,39 +24,53 @@ def __str__(self):
def setup(self):
"""Generates an array that serves as an initial source of data.
- Sets up a ZmqPSActor to send data to the processor.
Initial array is a 100 row, 5 column numpy matrix that contains
integers from 1-99, inclusive.
"""
- logger.info("Beginning setup for Generator")
self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
- logger.info("Completed setup for Generator")
+ self.improv_logger.info("Completed setup for Generator")
+
+ # def run(self):
+ # """ Send array into the store.
+ # """
+ # self.fcns = {}
+ # self.fcns['setup'] = self.setup
+ # self.fcns['run'] = self.runStep
+ # self.fcns['stop'] = self.stop
+
+ # with RunManager(self.name, self.fcns, self.links) as rm:
+ # logger.info(rm)
def stop(self):
"""Save current randint vector to a file."""
- logger.info("Generator stopping")
- np.save("sample_generator_data.npy", self.data)
+ self.improv_logger.info("Generator stopping")
+ np.save(f"sample_generator_data", self.data)
+ # This is not the best example of a save function,
+ # will overwrite previous files with the same name.
return 0
- def runStep(self):
+ def run_step(self):
"""Generates additional data after initial setup data is exhausted.
- Sends data to the processor using a ZmqPSActor.
Data is of a different form as the setup data in that although it is
the same size (5x1 vector), it is uniformly distributed in [1, 10]
instead of in [1, 100]. Therefore, the average over time should
converge to 5.5.
"""
+
if self.frame_num < np.shape(self.data)[0]:
- data_id = self.client.put(self.data[self.frame_num], str(f"Gen_raw: {self.frame_num}"))
+ data_id = self.client.put(self.data[self.frame_num])
try:
- self.put(data_id) #[data_id, str(self.frame_num)])
+ self.q_out.put(data_id)
+ # self.improv_logger.info(f"Sent {self.data[self.frame_num]} with key {data_id}")
self.frame_num += 1
+
except Exception as e:
- logger.error(f"---------Generator Exception: {e}")
+ self.improv_logger.error(f"Generator Exception: {e}")
else:
- new_data = np.asmatrix(np.random.randint(10, size=(1, 5)))
- self.data = np.concatenate((self.data, new_data), axis=0)
+ self.data = np.concatenate(
+ (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0
+ )
diff --git a/demos/minimal/actors/sample_persistence_generator.py b/demos/minimal/actors/sample_persistence_generator.py
new file mode 100644
index 00000000..5f6afbf1
--- /dev/null
+++ b/demos/minimal/actors/sample_persistence_generator.py
@@ -0,0 +1,78 @@
+import random
+import time
+
+from improv.actor import ZmqActor
+from datetime import date # used for saving
+import numpy as np
+import logging
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class Generator(ZmqActor):
+ """Sample actor to generate data to pass into a sample processor.
+
+ Intended for use along with sample_processor.py.
+ """
+
+ def __init__(self, output_filename, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.data = None
+ self.name = "Generator"
+ self.frame_num = 0
+ self.output_filename = output_filename
+
+ def __str__(self):
+ return f"Name: {self.name}, Data: {self.data}"
+
+ def setup(self):
+ """Generates an array that serves as an initial source of data.
+
+ Initial array is a 100 row, 5 column numpy matrix that contains
+ integers from 1-99, inclusive.
+ """
+
+ self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
+ self.improv_logger.info("Completed setup for Generator")
+
+ def stop(self):
+ """Save current randint vector to a file."""
+
+ self.improv_logger.info("Generator stopping")
+ return 0
+
+ def run_step(self):
+ """Generates additional data after initial setup data is exhausted.
+
+ Data is of a different form as the setup data in that although it is
+ the same size (5x1 vector), it is uniformly distributed in [1, 10]
+ instead of in [1, 100]. Therefore, the average over time should
+ converge to 5.5.
+ """
+
+
+ device_time = time.time_ns()
+ time.sleep(0.0003)
+ acquired_time = time.time_ns() # mock the time the generator "actually received" the data
+
+ if self.frame_num < np.shape(self.data)[0]:
+
+ with open(self.output_filename, "a+") as f:
+ device_data = self.data[self.frame_num]
+ packaged_data = (acquired_time, (device_time, device_data))
+ # save the data to the flat file before sending it downstream
+ f.write(f"{packaged_data[0]}, {packaged_data[1][0]}, {packaged_data[1][1]}\n")
+
+ data_id = self.client.put(packaged_data)
+ try:
+ self.q_out.put(data_id)
+ # self.improv_logger.info(f"Sent {self.data[self.frame_num]} with key {data_id}")
+ self.frame_num += 1
+
+ except Exception as e:
+ self.improv_logger.error(f"Generator Exception: {e}")
+ else:
+ self.data = np.concatenate(
+ (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0
+ )
diff --git a/demos/zmq/actors/zmq_ps_sample_processor.py b/demos/minimal/actors/sample_persistence_processor.py
similarity index 53%
rename from demos/zmq/actors/zmq_ps_sample_processor.py
rename to demos/minimal/actors/sample_persistence_processor.py
index f54c00c8..3a38d3fc 100644
--- a/demos/zmq/actors/zmq_ps_sample_processor.py
+++ b/demos/minimal/actors/sample_persistence_processor.py
@@ -1,44 +1,46 @@
+from improv.actor import ZmqActor
import numpy as np
import logging
-from demos.sample_actors.zmqActor import ZmqActor
-
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Processor(ZmqActor):
- """Sample processor used to calculate the average of an array of integers
- using sync ZMQ to communicate.
+ """Sample processor used to calculate the average of an array of integers.
Intended for use with sample_generator.py.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ if "name" in kwargs:
+ self.name = kwargs["name"]
def setup(self):
"""Initializes all class variables.
- Sets up a ZmqRRActor to receive data from the generator.
self.name (string): name of the actor.
- self.frame (ObjectID): Store object id referencing data from the store.
+ self.frame (ObjectID): StoreInterface object id referencing data from the store.
self.avg_list (list): list that contains averages of individual vectors.
self.frame_num (int): index of current frame.
"""
- self.name = "Processor"
+
+ if not hasattr(self, "name"):
+ self.name = "Processor"
self.frame = None
self.avg_list = []
self.frame_num = 1
- logger.info("Completed setup for Processor")
+ self.improv_logger.info("Completed setup for Processor")
def stop(self):
"""Trivial stop function for testing purposes."""
- logger.info("Processor stopping; have received {} frames so far".format(self.frame_num))
- def runStep(self):
+ self.improv_logger.info("Processor stopping")
+ return 0
+
+ def run_step(self):
"""Gets from the input queue and calculates the average.
- Receives data from the generator using a ZmqRRActor.
Receives an ObjectID, references data in the store using that
ObjectID, calculates the average of that data, and finally prints
@@ -47,19 +49,18 @@ def runStep(self):
frame = None
try:
- frame = self.get()
-
- except:
- logger.error("Could not get frame!")
+ frame = self.q_in.get(timeout=0.05)
+ except Exception as e:
+ # logger.error(f"{self.name} could not get frame! At {self.frame_num}: {e}")
pass
- if frame is not None:
+ if frame is not None and self.frame_num is not None:
self.done = False
- self.frame = self.client.getID(frame)
- avg = np.mean(self.frame[0])
-
- # logger.info(f"Average: {avg}")
+ self.frame = self.client.get(frame)
+ device_data = self.frame[1][1]
+ avg = np.mean(device_data)
+ # self.improv_logger.info(f"Average: {avg}")
self.avg_list.append(avg)
- logger.info(f"Overall Average: {np.mean(self.avg_list)}")
- # logger.info(f"Frame number: {self.frame_num}")
+ # self.improv_logger.info(f"Overall Average: {np.mean(self.avg_list)}")
+ # self.improv_logger.info(f"Frame number: {self.frame_num}")
self.frame_num += 1
diff --git a/demos/minimal/actors/sample_processor.py b/demos/minimal/actors/sample_processor_zmq.py
similarity index 64%
rename from demos/minimal/actors/sample_processor.py
rename to demos/minimal/actors/sample_processor_zmq.py
index 40fbf4d5..3ff9cd25 100644
--- a/demos/minimal/actors/sample_processor.py
+++ b/demos/minimal/actors/sample_processor_zmq.py
@@ -1,4 +1,4 @@
-from improv.actor import Actor
+from improv.actor import ZmqActor
import numpy as np
import logging
@@ -6,7 +6,7 @@
logger.setLevel(logging.INFO)
-class Processor(Actor):
+class Processor(ZmqActor):
"""Sample processor used to calculate the average of an array of integers.
Intended for use with sample_generator.py.
@@ -14,6 +14,8 @@ class Processor(Actor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ if "name" in kwargs:
+ self.name = kwargs["name"]
def setup(self):
"""Initializes all class variables.
@@ -24,18 +26,20 @@ def setup(self):
self.frame_num (int): index of current frame.
"""
- self.name = "Processor"
+ if not hasattr(self, "name"):
+ self.name = "Processor"
self.frame = None
self.avg_list = []
self.frame_num = 1
- logger.info("Completed setup for Processor")
+ self.improv_logger.info("Completed setup for Processor")
def stop(self):
"""Trivial stop function for testing purposes."""
- logger.info("Processor stopping")
+ self.improv_logger.info("Processor stopping")
+ return 0
- def runStep(self):
+ def run_step(self):
"""Gets from the input queue and calculates the average.
Receives an ObjectID, references data in the store using that
@@ -46,22 +50,16 @@ def runStep(self):
frame = None
try:
frame = self.q_in.get(timeout=0.05)
-
- except Exception:
- logger.error("Could not get frame!")
+ except Exception as e:
+ # logger.error(f"{self.name} could not get frame! At {self.frame_num}: {e}")
pass
if frame is not None and self.frame_num is not None:
self.done = False
- if self.store_loc:
- self.frame = self.client.getID(frame[0][0])
- else:
- self.frame = self.client.get(frame)
+ self.frame = self.client.get(frame)
avg = np.mean(self.frame[0])
-
- logger.info(f"Average: {avg}")
+ # self.improv_logger.info(f"Average: {avg}")
self.avg_list.append(avg)
- logger.info(f"Overall Average: {np.mean(self.avg_list)}")
- logger.info(f"Frame number: {self.frame_num}")
-
+ # self.improv_logger.info(f"Overall Average: {np.mean(self.avg_list)}")
+ # self.improv_logger.info(f"Frame number: {self.frame_num}")
self.frame_num += 1
diff --git a/demos/minimal/actors/sample_spawn_processor.py b/demos/minimal/actors/sample_spawn_processor.py
index cdc8bbb7..0f8962ee 100644
--- a/demos/minimal/actors/sample_spawn_processor.py
+++ b/demos/minimal/actors/sample_spawn_processor.py
@@ -1,4 +1,4 @@
-from improv.actor import Actor
+from improv.actor import ZmqActor
import numpy as np
from queue import Empty
import logging
@@ -7,7 +7,7 @@
logger.setLevel(logging.INFO)
-class Processor(Actor):
+class Processor(ZmqActor):
"""Sample processor used to calculate the average of an array of integers.
Intended for use with sample_generator.py.
@@ -31,7 +31,7 @@ def setup(self):
self.frame_num = 1
logger.info("Completed setup for Processor")
- self._getStoreInterface()
+ self._get_store_interface()
def stop(self):
"""Trivial stop function for testing purposes."""
diff --git a/demos/minimal/minimal.yaml b/demos/minimal/minimal.yaml
index c8e0b24e..d752b415 100644
--- a/demos/minimal/minimal.yaml
+++ b/demos/minimal/minimal.yaml
@@ -1,14 +1,15 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
-
-redis_config:
- port: 6379
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
diff --git a/demos/minimal/minimal_persistence.yaml b/demos/minimal/minimal_persistence.yaml
new file mode 100644
index 00000000..65226762
--- /dev/null
+++ b/demos/minimal/minimal_persistence.yaml
@@ -0,0 +1,16 @@
+actors:
+ Generator:
+ package: actors.sample_persistence_generator
+ class: Generator
+ output_filename: test_persistence.csv
+
+ Processor:
+ package: actors.sample_persistence_processor
+ class: Processor
+
+connections:
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
diff --git a/demos/minimal/minimal_plasma.yaml b/demos/minimal/minimal_plasma.yaml
deleted file mode 100644
index ab869678..00000000
--- a/demos/minimal/minimal_plasma.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-actors:
- Generator:
- package: actors.sample_generator
- class: Generator
-
- Processor:
- package: actors.sample_processor
- class: Processor
-
-connections:
- Generator.q_out: [Processor.q_in]
-
-plasma_config:
\ No newline at end of file
diff --git a/demos/minimal/minimal_spawn.yaml b/demos/minimal/minimal_spawn.yaml
index 6f22e653..7baab190 100644
--- a/demos/minimal/minimal_spawn.yaml
+++ b/demos/minimal/minimal_spawn.yaml
@@ -1,6 +1,6 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
@@ -9,7 +9,8 @@ actors:
method: spawn
connections:
- Generator.q_out: [Processor.q_in]
-
-redis_config:
- port: 6378
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
\ No newline at end of file
diff --git a/demos/neurofinder/neurofind_demo.py b/demos/neurofinder/neurofind_demo.py
index 5f49399c..ad327d73 100644
--- a/demos/neurofinder/neurofind_demo.py
+++ b/demos/neurofinder/neurofind_demo.py
@@ -7,7 +7,7 @@
loadFile = "./neurofind_demo.yaml"
nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
+nexus.create_nexus(file=loadFile)
# All modules needed have been imported
# so we can change the level of logging here
@@ -20,4 +20,4 @@
# logger = logging.getLogger("improv")
# logger.setLevel(logging.INFO)
-nexus.startNexus()
+nexus.start_nexus()
diff --git a/demos/neurofinder/neurofind_demo.yaml b/demos/neurofinder/neurofind_demo.yaml
index eb7796bc..483f6dc8 100644
--- a/demos/neurofinder/neurofind_demo.yaml
+++ b/demos/neurofinder/neurofind_demo.yaml
@@ -1,6 +1,3 @@
-settings:
- use_watcher: [Acquirer, Processor, Visual, Analysis]
-
actors:
GUI:
package: demos.basic.actors.visual
diff --git a/demos/sample_actors/acquire.py b/demos/sample_actors/acquire.py
index 28dfb4f7..c355ed43 100644
--- a/demos/sample_actors/acquire.py
+++ b/demos/sample_actors/acquire.py
@@ -5,7 +5,7 @@
import numpy as np
from skimage.io import imread
-from improv.actor import Actor
+from improv.actor import ZmqActor
import logging
@@ -15,7 +15,7 @@
# Classes: File Acquirer, Stim, Behavior, Tiff
-class FileAcquirer(Actor):
+class FileAcquirer(ZmqActor):
"""Class to import data from files and output
frames in a buffer, or discrete.
"""
@@ -109,7 +109,7 @@ def saveFrame(self, frame):
self.f.flush()
-class StimAcquirer(Actor):
+class StimAcquirer(ZmqActor):
"""Class to load visual stimuli data from file
and stream into the pipeline
"""
@@ -149,7 +149,7 @@ def runStep(self):
self.n += 1
-class BehaviorAcquirer(Actor):
+class BehaviorAcquirer(ZmqActor):
"""Actor that acquires information of behavioral stimulus
during the experiment
@@ -188,7 +188,7 @@ def runStep(self):
self.n += 1
-class FileStim(Actor):
+class FileStim(ZmqActor):
"""Actor that acquires information of behavioral stimulus
during the experiment from a file
"""
@@ -216,7 +216,7 @@ def runStep(self):
self.n += 1
-class TiffAcquirer(Actor):
+class TiffAcquirer(ZmqActor):
"""Loops through a TIF file."""
def __init__(self, *args, filename=None, framerate=30, **kwargs):
diff --git a/demos/sample_actors/analysis.py b/demos/sample_actors/analysis.py
index 664ddaf3..854e7a50 100644
--- a/demos/sample_actors/analysis.py
+++ b/demos/sample_actors/analysis.py
@@ -4,7 +4,7 @@
from queue import Empty
import os
-from improv.actor import Actor, RunManager
+from improv.actor import ZmqActor
from improv.store import ObjectNotFoundError
import logging
@@ -13,7 +13,7 @@
logger.setLevel(logging.INFO)
-class MeanAnalysis(Actor):
+class MeanAnalysis(ZmqActor):
# TODO: Add additional error handling
# TODO: this is too complex for a sample actor?
def __init__(self, *args, **kwargs):
diff --git a/demos/sample_actors/simple_analysis.py b/demos/sample_actors/simple_analysis.py
index ff350b58..78890a43 100644
--- a/demos/sample_actors/simple_analysis.py
+++ b/demos/sample_actors/simple_analysis.py
@@ -3,7 +3,7 @@
from queue import Empty
import os
-from improv.actor import Actor, RunManager
+from improv.actor import ZmqActor
from improv.store import ObjectNotFoundError
import logging
@@ -12,7 +12,7 @@
logger.setLevel(logging.INFO)
-class SimpleAnalysis(Actor):
+class SimpleAnalysis(ZmqActor):
# TODO: Add additional error handling
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/demos/sample_actors/zmqActor.py b/demos/sample_actors/zmqActor.py
deleted file mode 100644
index 6455b0b4..00000000
--- a/demos/sample_actors/zmqActor.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import asyncio
-import time
-
-import zmq
-from zmq import (
- PUB,
- SUB,
- SUBSCRIBE,
- REQ,
- REP,
- LINGER,
- Again,
- NOBLOCK,
- ZMQError,
- EAGAIN,
- ETERM,
-)
-from zmq.log.handlers import PUBHandler
-import zmq.asyncio
-
-from improv.actor import Actor
-
-import logging
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class ZmqActor(Actor):
- """
- Zmq actor with pub/sub or rep/req pattern.
- """
-
- def __init__(self, *args, type="PUB", ip="127.0.0.1", port=5555, **kwargs):
- super().__init__(*args, **kwargs)
- logger.info("Constructed Zmq Actor")
- if str(type) in "PUB" or str(type) in "SUB":
- self.pub_sub_flag = True # default
- else:
- self.pub_sub_flag = False
- self.rep_req_flag = not self.pub_sub_flag
- self.ip = ip
- self.port = port
- self.address = "tcp://{}:{}".format(self.ip, self.port)
-
- self.send_socket = None
- self.recv_socket = None
- self.req_socket = None
- self.rep_socket = None
-
- self.context = zmq.Context.instance()
-
- def sendMsg(self, msg, msg_type="pyobj"):
- """
- Sends a message to the controller.
- """
- if not self.send_socket:
- self.setSendSocket()
-
- if msg_type == "multipart":
- self.send_socket.send_multipart(msg)
- if msg_type == "pyobj":
- self.send_socket.send_pyobj(msg)
- elif msg_type == "single":
- self.send_socket.send(msg)
-
- def recvMsg(self, msg_type="pyobj", flags=0):
- """
- Receives a message from the controller.
-
- NOTE: default flag=0 instead of flag=NOBLOCK
- """
- if not self.recv_socket:
- self.setRecvSocket()
-
- while True:
- try:
- if msg_type == "multipart":
- recv_msg = self.recv_socket.recv_multipart(flags=flags)
- elif msg_type == "pyobj":
- recv_msg = self.recv_socket.recv_pyobj(flags=flags)
- elif msg_type == "single":
- recv_msg = self.recv_socket.recv(flags=flags)
- break
- except Again:
- pass
- except ZMQError as e:
- logger.info(f"ZMQ error: {e}")
- if e.errno == ETERM:
- pass # interrupted - pass or break if in try loop
- if e.errno == EAGAIN:
- pass # no message was ready (yet!)
-
- # self.recv_socket.close()
- return recv_msg
-
- def requestMsg(self, msg):
- """Safe version of send/receive with controller.
- Based on the Lazy Pirate pattern [here]
- (https://zguide.zeromq.org/docs/chapter4/#Client-Side-Reliability-Lazy-Pirate-Pattern)
- """
- REQUEST_TIMEOUT = 500
- REQUEST_RETRIES = 3
- retries_left = REQUEST_RETRIES
-
- self.setReqSocket()
-
- reply = None
- try:
- logger.debug(f"Sending {msg} to controller.")
- self.req_socket.send_pyobj(msg)
-
- while True:
- ready = self.req_socket.poll(REQUEST_TIMEOUT)
-
- if ready:
- reply = self.req_socket.recv_pyobj()
- logger.debug(f"Received {reply} from controller.")
- break
- else:
- retries_left -= 1
- logger.debug("No response from server.")
-
- # try to close and reconnect
- self.req_socket.setsockopt(LINGER, 0)
- self.req_socket.close()
- if retries_left == 0:
- logger.debug("Server seems to be offline. Giving up.")
- break
-
- logger.debug("Attempting to reconnect to server...")
-
- self.setReqSocket()
-
- logger.debug(f"Resending {msg} to controller.")
- self.req_socket.send_pyobj(msg)
-
- except asyncio.CancelledError:
- pass
-
- self.req_socket.close()
- return reply
-
- def replyMsg(self, reply, delay=0.0001):
- """
- Safe version of receive/reply with controller.
- """
-
- self.setRepSocket()
-
- msg = self.rep_socket.recv_pyobj()
- time.sleep(delay)
- self.rep_socket.send_pyobj(reply)
- self.rep_socket.close()
-
- return msg
-
- def put(self, msg=None):
- logger.debug(f"Putting message {msg}")
- if self.pub_sub_flag:
- logger.debug(f"putting message {msg} using pub/sub")
- return self.sendMsg(msg)
- else:
- logger.debug(f"putting message {msg} using rep/req")
- return self.requestMsg(msg)
-
- def get(self, reply=None):
- if self.pub_sub_flag:
- logger.debug(f"getting message with pub/sub")
- return self.recvMsg()
- else:
- logger.debug(f"getting message using reply {reply} with pub/sub")
- return self.replyMsg(reply)
-
- def setSendSocket(self, timeout=1.001):
- """
- Sets up the send socket for the actor.
- """
- self.send_socket = self.context.socket(PUB)
- self.send_socket.bind(self.address)
- time.sleep(timeout)
-
- def setRecvSocket(self, timeout=1.001):
- """
- Sets up the receive socket for the actor.
- """
- self.recv_socket = self.context.socket(SUB)
- self.recv_socket.connect(self.address)
- self.recv_socket.setsockopt(SUBSCRIBE, b"")
- time.sleep(timeout)
-
- def setReqSocket(self, timeout=0.0001):
- """
- Sets up the request socket for the actor.
- """
- self.req_socket = self.context.socket(REQ)
- self.req_socket.connect(self.address)
- time.sleep(timeout)
-
- def setRepSocket(self, timeout=0.0001):
- """
- Sets up the reply socket for the actor.
- """
- self.rep_socket = self.context.socket(REP)
- self.rep_socket.bind(self.address)
- time.sleep(timeout)
diff --git a/demos/spike/spike_demo.py b/demos/spike/spike_demo.py
index ca4f110d..24a37b04 100644
--- a/demos/spike/spike_demo.py
+++ b/demos/spike/spike_demo.py
@@ -7,7 +7,7 @@
loadFile = "./spike_demo.yaml"
nexus = Nexus("Nexus")
-nexus.createNexus(file=loadFile)
+nexus.create_nexus(file=loadFile)
# All modules needed have been imported
# so we can change the level of logging here
@@ -20,4 +20,4 @@
# logger = logging.getLogger("improv")
# logger.setLevel(logging.INFO)
-nexus.startNexus()
+nexus.start_nexus()
diff --git a/demos/spike/spike_demo.yaml b/demos/spike/spike_demo.yaml
index 6aad9992..8b46f7a7 100644
--- a/demos/spike/spike_demo.yaml
+++ b/demos/spike/spike_demo.yaml
@@ -1,6 +1,3 @@
-# settings:
-# use_watcher: [Acquirer, Analysis]
-
actors:
# GUI:
# package: actors.visual
diff --git a/demos/zmq/README.md b/demos/zmq/README.md
deleted file mode 100644
index f02edc26..00000000
--- a/demos/zmq/README.md
+++ /dev/null
@@ -1,14 +0,0 @@
-# ZMQ Demo
-
-This demo is intended to show how actors can communicate using zmq. There are two options for the demo, each corresponding to a different config file. The difference between these two options is in how they send and receive messages. One uses the publish/send concept, while the other uses the request/reply concept.
-
-## Instructions
-
-Just like any other demo, you can run `improv run ` in order to run the demo. For example, from within the `improv/demos/zmq` directory, running `improv run zmq_rr_demo.yaml` will run the zmq_rr demo.
-
-After running the demo, a tui (text user interface) should show up. From here, we can type `setup` followed by `run` to run the demo. After we are done, we can type `stop` to pause, or `quit` to exit.
-
-## Expected Output
-
-Currently, only logging output is supported. There will be no live output during the run.
-
diff --git a/demos/zmq/actors/zmq_rr_sample_generator.py b/demos/zmq/actors/zmq_rr_sample_generator.py
deleted file mode 100644
index ef941ae6..00000000
--- a/demos/zmq/actors/zmq_rr_sample_generator.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import numpy as np
-import logging
-
-from demos.sample_actors.zmqActor import ZmqActor
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-class Generator(ZmqActor):
- """Sample actor to generate data to pass into a sample processor
- using async ZMQ to communicate.
-
- Intended for use along with sample_processor.py.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.data = None
- self.name = "Generator"
- self.frame_num = 0
-
- def __str__(self):
- return f"Name: {self.name}, Data: {self.data}"
-
- def setup(self):
- """Generates an array that serves as an initial source of data.
- Sets up a ZmqRRActor to send data to the processor.
-
- Initial array is a 100 row, 5 column numpy matrix that contains
- integers from 1-99, inclusive.
- """
- logger.info("Beginning setup for Generator")
- self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
- logger.info("Completed setup for Generator")
-
- def stop(self):
- """Save current randint vector to a file."""
- logger.info("Generator stopping")
- np.save("sample_generator_data.npy", self.data)
- return 0
-
- def runStep(self):
- """Generates additional data after initial setup data is exhausted.
- Sends data to the processor using a ZmqRRActor.
-
- Data is of a different form as the setup data in that although it is
- the same size (5x1 vector), it is uniformly distributed in [1, 10]
- instead of in [1, 100]. Therefore, the average over time should
- converge to 5.5.
- """
- if self.frame_num < np.shape(self.data)[0]:
- data_id = self.client.put(self.data[self.frame_num], str(f"Gen_raw_{self.frame_num}"))
- try:
- self.put(data_id)
- self.frame_num += 1
- except Exception as e:
- logger.error(f"Generator Exception: {e}")
- else:
- self.data = np.concatenate((self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0)
diff --git a/demos/zmq/config/zmq_ps_demo.yaml b/demos/zmq/config/zmq_ps_demo.yaml
deleted file mode 100644
index 230282d1..00000000
--- a/demos/zmq/config/zmq_ps_demo.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-actors:
- Generator:
- package: actors.sample_generator
- class: Generator
-
- Processor:
- package: actors.sample_processor
- class: Processor
-
-connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
diff --git a/demos/zmq/config/zmq_rr_demo.yaml b/demos/zmq/config/zmq_rr_demo.yaml
deleted file mode 100644
index 230282d1..00000000
--- a/demos/zmq/config/zmq_rr_demo.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-actors:
- Generator:
- package: actors.sample_generator
- class: Generator
-
- Processor:
- package: actors.sample_processor
- class: Processor
-
-connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
diff --git a/demos/zmq/zmq_ps_demo.yaml b/demos/zmq/zmq_ps_demo.yaml
deleted file mode 100644
index 63fca91c..00000000
--- a/demos/zmq/zmq_ps_demo.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-actors:
- Generator:
- package: actors.zmq_ps_sample_generator
- class: Generator
- ip: 127.0.0.1
- port: 5555
- type: PUB
-
- Processor:
- package: actors.zmq_ps_sample_processor
- class: Processor
- ip: 127.0.0.1
- port: 5555
- type: SUB
-
-connections:
- Generator.q_out: [Processor.q_in]
diff --git a/demos/zmq/zmq_rr_demo.yaml b/demos/zmq/zmq_rr_demo.yaml
deleted file mode 100644
index 585cebd1..00000000
--- a/demos/zmq/zmq_rr_demo.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-actors:
- Generator:
- package: actors.zmq_rr_sample_generator
- class: Generator
- ip: 127.0.0.1
- port: 5555
- type: REQ
-
- Processor:
- package: actors.zmq_rr_sample_processor
- class: Processor
- ip: 127.0.0.1
- port: 5555
- type: REP
-
-connections:
- Generator.q_out: [Processor.q_in]
diff --git a/docs/running.ipynb b/docs/running.ipynb
index 47b7b2b0..2ed13a09 100644
--- a/docs/running.ipynb
+++ b/docs/running.ipynb
@@ -19,10 +19,10 @@
"languageId": "shellscript"
}
},
- "outputs": [],
"source": [
"!improv --help"
- ]
+ ],
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -43,10 +43,10 @@
"languageId": "shellscript"
}
},
- "outputs": [],
"source": [
"!improv run --help"
- ]
+ ],
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -110,40 +110,10 @@
"languageId": "shellscript"
}
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "usage: improv server [-h] [-c CONTROL_PORT] [-o OUTPUT_PORT] [-l LOGGING_PORT]\n",
- " [-f LOGFILE] [-a ACTOR_PATH]\n",
- " configfile\n",
- "\n",
- "Start the improv server\n",
- "\n",
- "positional arguments:\n",
- " configfile YAML file specifying improv pipeline\n",
- "\n",
- "options:\n",
- " -h, --help show this help message and exit\n",
- " -c CONTROL_PORT, --control-port CONTROL_PORT\n",
- " local port on which control signals are received\n",
- " -o OUTPUT_PORT, --output-port OUTPUT_PORT\n",
- " local port on which output messages are broadcast\n",
- " -l LOGGING_PORT, --logging-port LOGGING_PORT\n",
- " local port on which logging messages are broadcast\n",
- " -f LOGFILE, --logfile LOGFILE\n",
- " name of log file\n",
- " -a ACTOR_PATH, --actor-path ACTOR_PATH\n",
- " search path to add to sys.path when looking for\n",
- " actors; defaults to the directory containing\n",
- " configfile\n"
- ]
- }
- ],
"source": [
"!improv server --help"
- ]
+ ],
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -160,30 +130,10 @@
"languageId": "shellscript"
}
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "usage: improv client [-h] [-c CONTROL_PORT] [-s SERVER_PORT] [-l LOGGING_PORT]\n",
- "\n",
- "Start the improv client\n",
- "\n",
- "options:\n",
- " -h, --help show this help message and exit\n",
- " -c CONTROL_PORT, --control-port CONTROL_PORT\n",
- " address on which control signals are sent to the\n",
- " server\n",
- " -s SERVER_PORT, --server-port SERVER_PORT\n",
- " address on which messages from the server are received\n",
- " -l LOGGING_PORT, --logging-port LOGGING_PORT\n",
- " address on which logging messages are broadcast\n"
- ]
- }
- ],
"source": [
"!improv client --help"
- ]
+ ],
+ "outputs": []
},
{
"cell_type": "markdown",
diff --git a/improv/actor.py b/improv/actor.py
index 3d6a2a23..09dd3b91 100644
--- a/improv/actor.py
+++ b/improv/actor.py
@@ -1,10 +1,17 @@
+from __future__ import annotations
+
import time
import signal
import asyncio
import traceback
from queue import Empty
-import improv.store
+import zmq
+from zmq import SocketOption
+
+from improv import log
+from improv.link import ZmqLink
+from improv.messaging import ActorStateMsg, ActorStateReplyMsg, ActorSignalReplyMsg
from improv.store import StoreInterface
import logging
@@ -13,6 +20,17 @@
logger.setLevel(logging.INFO)
+# TODO construct log handler within post-fork setup based on
+# TODO arguments sent to actor by nexus
+# TODO and add it below; fine to leave logger here as-is
+
+
+class LinkInfo:
+ def __init__(self, link_name, link_topic):
+ self.link_name = link_name
+ self.link_topic = link_topic
+
+
class AbstractActor:
"""Base class for an actor that Nexus
controls and interacts with.
@@ -20,9 +38,7 @@ class AbstractActor:
Also needs to be responsive to sent Signals (e.g. run, setup, etc)
"""
- def __init__(
- self, name, store_loc=None, method="fork", store_port_num=None, *args, **kwargs
- ):
+ def __init__(self, name, method="fork", store_port_num=None, *args, **kwargs):
"""Require a name for multiple instances of the same actor/class
Create initial empty dict of Links for easier referencing
"""
@@ -31,8 +47,8 @@ def __init__(
self.links = {}
self.method = method
self.client = None
- self.store_loc = store_loc
self.lower_priority = False
+ self.improv_logger = None
self.store_port_num = store_port_num
# Start with no explicit data queues.
@@ -49,7 +65,7 @@ def __repr__(self):
"""
return self.name + ": " + str(self.links.keys())
- def setStoreInterface(self, client):
+ def set_store_interface(self, client):
"""Sets the client interface to the store
Args:
@@ -57,17 +73,14 @@ def setStoreInterface(self, client):
"""
self.client = client
- def _getStoreInterface(self):
+ def _get_store_interface(self):
# TODO: Where do we require this be run? Add a Signal and include in RM?
if not self.client:
store = None
- if StoreInterface == improv.store.RedisStoreInterface:
- store = StoreInterface(self.name, self.store_port_num)
- else:
- store = StoreInterface(self.name, self.store_loc)
- self.setStoreInterface(store)
+ store = StoreInterface(self.name, self.store_port_num)
+ self.set_store_interface(store)
- def setLinks(self, links):
+ def set_links(self, links):
"""General full dict set for links
Args:
@@ -75,7 +88,7 @@ def setLinks(self, links):
"""
self.links = links
- def setCommLinks(self, q_comm, q_sig):
+ def set_comm_links(self, q_comm, q_sig):
"""Set explicit communication links to/from Nexus (q_comm, q_sig)
Args:
@@ -86,7 +99,7 @@ def setCommLinks(self, q_comm, q_sig):
self.q_sig = q_sig
self.links.update({"q_comm": self.q_comm, "q_sig": self.q_sig})
- def setLinkIn(self, q_in):
+ def set_link_in(self, q_in):
"""Set the dedicated input queue
Args:
@@ -95,7 +108,7 @@ def setLinkIn(self, q_in):
self.q_in = q_in
self.links.update({"q_in": self.q_in})
- def setLinkOut(self, q_out):
+ def set_link_out(self, q_out):
"""Set the dedicated output queue
Args:
@@ -104,7 +117,7 @@ def setLinkOut(self, q_out):
self.q_out = q_out
self.links.update({"q_out": self.q_out})
- def setLinkWatch(self, q_watch):
+ def set_link_watch(self, q_watch):
"""Set the dedicated watchout queue
Args:
@@ -113,7 +126,7 @@ def setLinkWatch(self, q_watch):
self.q_watchout = q_watch
self.links.update({"q_watchout": self.q_watchout})
- def addLink(self, name, link):
+ def add_link(self, name, link):
"""Function provided to add additional data links by name
using same form as q_in or q_out
Must be done during registration and not during run
@@ -126,7 +139,7 @@ def addLink(self, name, link):
# User can then use: self.my_queue = self.links['my_queue'] in a setup fcn,
# or continue to reference it using self.links['my_queue']
- def getLinks(self):
+ def get_links(self):
"""Returns dictionary of links for the current actor
Returns:
@@ -134,24 +147,6 @@ def getLinks(self):
"""
return self.links
- def put(self, idnames, q_out=None, save=None):
- """TODO: This is deprecated? Prefer using Links explicitly"""
- if save is None:
- save = [False] * len(idnames)
-
- if len(save) < len(idnames):
- save = save + [False] * (len(idnames) - len(save))
-
- if q_out is None:
- q_out = self.q_out
-
- q_out.put(idnames)
-
- for i in range(len(idnames)):
- if save[i]:
- if self.q_watchout:
- self.q_watchout.put(idnames[i])
-
def setup(self):
"""Essenitally the registration process
Can also be an initialization for the actor
@@ -174,7 +169,7 @@ def stop(self):
"""
pass
- def changePriority(self):
+ def change_priority(self):
"""Try to lower this process' priority
Only changes priority if lower_priority is set
TODO: Only works on unix machines. Add Windows functionality
@@ -188,6 +183,18 @@ def changePriority(self):
logger.info("Lowered priority of this process: {}".format(self.name))
print("Lowered ", os.getpid(), " for ", self.name)
+ def register_with_nexus(self):
+ pass
+
+ def register_with_broker(self):
+ pass
+
+ def setup_links(self):
+ pass
+
+ def setup_logging(self):
+ raise NotImplementedError
+
class ManagedActor(AbstractActor):
def __init__(self, *args, **kwargs):
@@ -196,17 +203,146 @@ def __init__(self, *args, **kwargs):
# Define dictionary of actions for the RunManager
self.actions = {}
self.actions["setup"] = self.setup
- self.actions["run"] = self.runStep
+ self.actions["run"] = self.run_step
self.actions["stop"] = self.stop
+ self.nexus_sig_port: int = None
def run(self):
- with RunManager(self.name, self.actions, self.links):
+ self.register_with_nexus()
+ self.setup_logging()
+ self.register_with_broker()
+ self.setup_links()
+ with RunManager(
+ self.name, self.actions, self.links, self.nexus_sig_port, self.improv_logger
+ ):
pass
- def runStep(self):
+ def run_step(self):
raise NotImplementedError
+class ZmqActor(ManagedActor):
+ def __init__(
+ self,
+ nexus_comm_port,
+ broker_sub_port,
+ broker_pub_port,
+ log_pull_port,
+ outgoing_links,
+ incoming_links,
+ broker_host="localhost",
+ log_host="localhost",
+ nexus_host="localhost",
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.broker_pub_socket: zmq.Socket | None = None
+ self.zmq_context: zmq.Context | None = None
+ self.nexus_host: str = nexus_host
+ self.nexus_comm_port: int = nexus_comm_port
+ self.nexus_sig_port: int | None = None
+ self.nexus_comm_socket: zmq.Socket | None = None
+ self.nexus_sig_socket: zmq.Socket | None = None
+ self.broker_sub_port: int = broker_sub_port
+ self.broker_pub_port: int = broker_pub_port
+ self.broker_host: str = broker_host
+ self.log_pull_port: int = log_pull_port
+ self.log_host: str = log_host
+ self.outgoing_links: list[LinkInfo] = outgoing_links
+ self.incoming_links: list[LinkInfo] = incoming_links
+ self.incoming_sockets: dict[str, zmq.Socket] = dict()
+
+ # Redefine dictionary of actions for the RunManager
+ self.actions = {"setup": self.setup, "run": self.run_step, "stop": self.stop}
+
+ def register_with_nexus(self):
+ logger.info(f"Actor {self.name} registering with nexus")
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ # create a REQ socket pointed at nexus' global actor in port
+ self.nexus_comm_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_comm_socket.connect(
+ f"tcp://{self.nexus_host}:{self.nexus_comm_port}"
+ )
+
+ # create a REP socket for comms from Nexus and save its state
+ self.nexus_sig_socket = self.zmq_context.socket(zmq.REP)
+ self.nexus_sig_socket.bind("tcp://*:0") # find any available port
+ sig_socket_addr = self.nexus_sig_socket.getsockopt_string(
+ SocketOption.LAST_ENDPOINT
+ )
+ self.nexus_sig_port = int(sig_socket_addr.split(":")[-1])
+
+ # build and send a message to nexus
+ actor_state = ActorStateMsg(
+ self.name,
+ Signal.waiting(),
+ self.nexus_sig_port,
+ "comms opened and actor ready to initialize",
+ )
+
+ self.nexus_comm_socket.send_pyobj(actor_state)
+
+ rep: ActorStateReplyMsg = self.nexus_comm_socket.recv_pyobj()
+ logger.info(
+ f"Got response from nexus:\n"
+ f"Status: {rep.status}\n"
+ f"Info: {rep.info}\n"
+ )
+
+ self.links["q_comm"] = ZmqLink(self.nexus_comm_socket, f"{self.name}.q_comm")
+ self.links["q_sig"] = ZmqLink(self.nexus_sig_socket, f"{self.name}.q_sig")
+
+ logger.info(f"Actor {self.name} registered with Nexus")
+
+ def register_with_broker(self): # really opening sockets here
+ self.improv_logger.info(f"Actor {self.name} registering with broker")
+ if len(self.outgoing_links) > 0:
+ self.broker_pub_socket = self.zmq_context.socket(zmq.PUB)
+ self.broker_pub_socket.connect(
+ f"tcp://{self.broker_host}:{self.broker_sub_port}"
+ )
+
+ for incoming_link in self.incoming_links:
+ new_socket: zmq.Socket = self.zmq_context.socket(zmq.SUB)
+ new_socket.connect(f"tcp://{self.broker_host}:{self.broker_pub_port}")
+ new_socket.subscribe(incoming_link.link_topic)
+ self.incoming_sockets[incoming_link.link_name] = new_socket
+ self.improv_logger.info(f"Actor {self.name} registered with broker")
+
+ def setup_links(self):
+ self.improv_logger.info(f"Actor {self.name} setting up links")
+ for outgoing_link in self.outgoing_links:
+ self.links[outgoing_link.link_name] = ZmqLink(
+ self.broker_pub_socket,
+ outgoing_link.link_name,
+ outgoing_link.link_topic,
+ )
+
+ for incoming_link in self.incoming_links:
+ self.links[incoming_link.link_name] = ZmqLink(
+ self.incoming_sockets[incoming_link.link_name],
+ incoming_link.link_name,
+ incoming_link.link_topic,
+ )
+
+ if "q_out" in self.links.keys():
+ self.q_out = self.links["q_out"]
+ if "q_in" in self.links.keys():
+ self.q_in = self.links["q_in"]
+ self.improv_logger.info(f"Actor {self.name} finished setting up links")
+
+ def setup_logging(self):
+ self.improv_logger = logging.getLogger(self.name)
+ self.improv_logger.setLevel(logging.INFO)
+ for handler in logger.handlers:
+ self.improv_logger.addHandler(handler)
+ self.improv_logger.addHandler(
+ log.ZmqLogHandler(self.log_host, self.log_pull_port, self.zmq_context)
+ )
+
+
class AsyncActor(AbstractActor):
def __init__(self, *args, **kwargs):
super().__init__(*args)
@@ -214,7 +350,7 @@ def __init__(self, *args, **kwargs):
# Define dictionary of actions for the RunManager
self.actions = {}
self.actions["setup"] = self.setup
- self.actions["run"] = self.runStep
+ self.actions["run"] = self.run_step
self.actions["stop"] = self.stop
def run(self):
@@ -231,7 +367,7 @@ async def setup(self):
"""
pass
- async def runStep(self):
+ async def run_step(self):
raise NotImplementedError
async def stop(self):
@@ -243,13 +379,24 @@ async def stop(self):
class RunManager:
- def __init__(self, name, actions, links, runStoreInterface=None, timeout=1e-6):
+ def __init__(
+ self,
+ name,
+ actions,
+ links,
+ nexus_sig_port,
+ improv_logger,
+ runStoreInterface=None,
+ timeout=1e-6,
+ ):
self.run = False
self.stop = False
self.config = False
+ self.nexus_sig_port = nexus_sig_port
+ self.improv_logger = improv_logger
self.actorName = name
- logger.debug("RunManager for {} created".format(self.actorName))
+ self.improv_logger.debug("RunManager for {} created".format(self.actorName))
self.actions = actions
self.links = links
@@ -269,58 +416,82 @@ def __enter__(self):
try:
self.actions["run"]()
except Exception as e:
- logger.error("Actor {} error in run: {}".format(an, e))
- logger.error(traceback.format_exc())
+ self.improv_logger.error("Actor {} error in run: {}".format(an, e))
+ self.improv_logger.error(traceback.format_exc())
elif self.stop:
try:
self.actions["stop"]()
except Exception as e:
- logger.error("Actor {} error in stop: {}".format(an, e))
- logger.error(traceback.format_exc())
+ self.improv_logger.error("Actor {} error in stop: {}".format(an, e))
+ self.improv_logger.error(traceback.format_exc())
self.stop = False # Run once
elif self.config:
try:
if self.runStoreInterface:
self.runStoreInterface()
self.actions["setup"]()
- self.q_comm.put([Signal.ready()])
+ self.q_comm.put(
+ ActorStateMsg(
+ self.actorName, Signal.ready(), self.nexus_sig_port, ""
+ )
+ )
+ res = self.q_comm.get()
+ self.improv_logger.info(
+ f"Actor {res.actor_name} got state update reply:\n"
+ f"Status: {res.status}\n"
+ f"Info: {res.info}\n"
+ )
except Exception as e:
- logger.error("Actor {} error in setup: {}".format(an, e))
- logger.error(traceback.format_exc())
+ self.improv_logger.error(
+ "Actor {} error in setup: {}".format(an, e)
+ )
+ self.improv_logger.error(traceback.format_exc())
self.config = False
# Check for new Signals received from Nexus
try:
- signal = self.q_sig.get(timeout=self.timeout)
- logger.debug("{} received Signal {}".format(self.actorName, signal))
+ signal_msg = self.q_sig.get(timeout=self.timeout)
+ signal = signal_msg.signal
+ self.q_sig.put(ActorSignalReplyMsg(an, signal, "OK", ""))
+ self.improv_logger.warning(
+ "{} received Signal {}".format(self.actorName, signal)
+ )
if signal == Signal.run():
self.run = True
- logger.warning("Received run signal, begin running")
+ self.improv_logger.warning("Received run signal, begin running")
elif signal == Signal.setup():
self.config = True
elif signal == Signal.stop():
self.run = False
self.stop = True
- logger.warning(f"actor {self.actorName} received stop signal")
+ self.improv_logger.warning(
+ f"actor {self.actorName} received stop signal"
+ )
elif signal == Signal.quit():
- logger.warning("Received quit signal, aborting")
+ self.improv_logger.warning("Received quit signal, aborting")
break
elif signal == Signal.pause():
- logger.warning("Received pause signal, pending...")
+ self.improv_logger.warning("Received pause signal, pending...")
self.run = False
elif signal == Signal.resume(): # currently treat as same as run
- logger.warning("Received resume signal, resuming")
+ self.improv_logger.warning("Received resume signal, resuming")
self.run = True
+ elif signal == Signal.status():
+ self.improv_logger.info(
+ f"Actor {self.actorName} received status request"
+ )
except KeyboardInterrupt:
break
except Empty:
pass # No signal from Nexus
+ except TimeoutError:
+ pass # No signal from Nexus over zmq
return None
def __exit__(self, type, value, traceback):
- logger.info("Ran for " + str(time.time() - self.start) + " seconds")
- logger.warning("Exiting RunManager")
+ self.improv_logger.info("Ran for " + str(time.time() - self.start) + " seconds")
+ self.improv_logger.warning("Exiting RunManager")
return None
@@ -473,3 +644,11 @@ def stop():
@staticmethod
def stop_success():
return "stop success"
+
+ @staticmethod
+ def status():
+ return "status"
+
+ @staticmethod
+ def waiting():
+ return "waiting"
diff --git a/improv/broker.py b/improv/broker.py
new file mode 100644
index 00000000..e9055080
--- /dev/null
+++ b/improv/broker.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+import logging
+import signal
+
+import zmq
+from zmq import SocketOption
+
+from improv.messaging import BrokerInfoMsg
+
+DEBUG = True
+
+local_log = logging.getLogger(__name__)
+
+
+def bootstrap_broker(nexus_hostname, nexus_port):
+ if DEBUG:
+ local_log.addHandler(logging.FileHandler("broker_server.log"))
+ try:
+ broker = PubSubBroker(nexus_hostname, nexus_port)
+ broker.register_with_nexus()
+ broker.serve(broker.read_and_pub_message)
+ except Exception as e:
+ local_log.error(e)
+ for handler in local_log.handlers:
+ handler.close()
+
+
+class PubSubBroker:
+ def __init__(self, nexus_hostname, nexus_comm_port):
+ self.running = True
+ self.nexus_hostname: str = nexus_hostname
+ self.nexus_comm_port: int = nexus_comm_port
+ self.zmq_context: zmq.Context | None = None
+ self.nexus_socket: zmq.Socket | None = None
+ self.pub_port: int
+ self.sub_port: int
+ self.pub_socket: zmq.Socket | None = None
+ self.sub_socket: zmq.Socket | None = None
+
+ signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+ for s in signals:
+ signal.signal(s, self.stop)
+
+ def register_with_nexus(self):
+ # connect to nexus
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ self.nexus_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_socket.connect(f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}")
+
+ self.sub_socket = self.zmq_context.socket(zmq.SUB)
+ self.sub_socket.bind("tcp://*:0")
+ sub_port_string = self.sub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ self.sub_port = int(sub_port_string.split(":")[-1])
+ self.sub_socket.subscribe("") # receive all incoming messages
+
+ self.pub_socket = self.zmq_context.socket(zmq.PUB)
+ self.pub_socket.bind("tcp://*:0")
+ pub_port_string = self.pub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ self.pub_port = int(pub_port_string.split(":")[-1])
+
+ port_info = BrokerInfoMsg(
+ "broker",
+ self.pub_port,
+ self.sub_port,
+ "Ports up and running, ready to serve messages",
+ )
+
+ self.nexus_socket.send_pyobj(port_info)
+
+ local_log.info("broker attempting to get message from nexus")
+ msg_available = 0
+ retries = 3
+ while (retries > 3) and (msg_available == 0):
+ msg_available = self.nexus_socket.poll(timeout=1000)
+ if msg_available == 0:
+ local_log.info(
+ "broker didn't get a reply from nexus. cycling socket and resending"
+ )
+ self.nexus_socket.close(linger=0)
+ self.nexus_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_socket.connect(
+ f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}"
+ )
+ self.nexus_socket.send_pyobj(port_info)
+ local_log.info("broker resent message")
+ retries -= 1
+
+ self.nexus_socket.recv_pyobj()
+
+ local_log.info("broker got message back from nexus")
+
+ return
+
+ def serve(self, message_process_func):
+ local_log.info("broker serving")
+ while self.running:
+ # this is more testable but may have a performance overhead
+ message_process_func()
+ self.shutdown()
+
+ def read_and_pub_message(self):
+ try:
+ msg_ready = self.sub_socket.poll(timeout=0)
+ if msg_ready != 0:
+ msg = self.sub_socket.recv_multipart()
+ self.pub_socket.send_multipart(msg)
+ except zmq.error.ZMQError:
+ self.running = False
+
+ def shutdown(self):
+ for handler in local_log.handlers:
+ handler.close()
+
+ if self.sub_socket:
+ self.sub_socket.close(linger=0)
+
+ if self.pub_socket:
+ self.sub_socket.close(linger=0)
+
+ if self.nexus_socket:
+ self.sub_socket.close(linger=0)
+
+ if self.zmq_context:
+ self.zmq_context.destroy(linger=0)
+
+ def stop(self, signum, frame):
+ local_log.info(f"Log server shutting down due to signal {signum}")
+
+ self.running = False
diff --git a/improv/cli.py b/improv/cli.py
index 720d4e2f..b3f6e586 100644
--- a/improv/cli.py
+++ b/improv/cli.py
@@ -8,15 +8,10 @@
import psutil
import time
import datetime
-from zmq import SocketOption
-from zmq.log.handlers import PUBHandler
from improv.tui import TUI
from improv.nexus import Nexus
MAX_PORT = 2**16 - 1
-DEFAULT_CONTROL_PORT = "0"
-DEFAULT_OUTPUT_PORT = "0"
-DEFAULT_LOGGING_PORT = "0"
def file_exists(fname):
@@ -80,21 +75,18 @@ def parse_cli_args(args):
"-c",
"--control-port",
type=is_valid_port,
- default=DEFAULT_CONTROL_PORT,
help="local port on which control are sent to/from server",
)
run_parser.add_argument(
"-o",
"--output-port",
type=is_valid_port,
- default=DEFAULT_OUTPUT_PORT,
help="local port on which server output messages are broadcast",
)
run_parser.add_argument(
"-l",
"--logging-port",
type=is_valid_port,
- default=DEFAULT_LOGGING_PORT,
help="local port on which logging messages are broadcast",
)
run_parser.add_argument(
@@ -121,21 +113,18 @@ def parse_cli_args(args):
"-c",
"--control-port",
type=is_valid_ip_addr,
- default=DEFAULT_CONTROL_PORT,
help="address on which control signals are sent to the server",
)
client_parser.add_argument(
"-s",
"--server-port",
type=is_valid_ip_addr,
- default=DEFAULT_OUTPUT_PORT,
help="address on which messages from the server are received",
)
client_parser.add_argument(
"-l",
"--logging-port",
type=is_valid_ip_addr,
- default=DEFAULT_LOGGING_PORT,
help="address on which logging messages are broadcast",
)
client_parser.set_defaults(func=run_client)
@@ -147,21 +136,18 @@ def parse_cli_args(args):
"-c",
"--control-port",
type=is_valid_port,
- default=DEFAULT_CONTROL_PORT,
help="local port on which control signals are received",
)
server_parser.add_argument(
"-o",
"--output-port",
type=is_valid_port,
- default=DEFAULT_OUTPUT_PORT,
help="local port on which output messages are broadcast",
)
server_parser.add_argument(
"-l",
"--logging-port",
type=is_valid_port,
- default=DEFAULT_LOGGING_PORT,
help="local port on which logging messages are broadcast",
)
server_parser.add_argument(
@@ -212,18 +198,11 @@ def run_server(args):
"""
Runs the improv server in headless mode.
"""
- zmq_log_handler = PUBHandler("tcp://*:%s" % args.logging_port)
- # in case we bound to a random port (default), get port number
- logging_port = int(
- zmq_log_handler.socket.getsockopt_string(SocketOption.LAST_ENDPOINT).split(":")[
- -1
- ]
- )
logging.basicConfig(
level=logging.DEBUG,
format="%(name)s %(message)s",
- handlers=[logging.FileHandler(args.logfile), zmq_log_handler],
+ handlers=[logging.FileHandler("improv-debug.log")],
)
if not args.actor_path:
@@ -232,18 +211,25 @@ def run_server(args):
sys.path.extend(args.actor_path)
server = Nexus()
- control_port, output_port = server.createNexus(
+ control_port, output_port, log_port = server.create_nexus(
file=args.configfile,
control_port=args.control_port,
output_port=args.output_port,
+ log_server_pub_port=args.logging_port,
+ logfile=args.logfile,
)
curr_dt = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(
f"{curr_dt} Server running on (control, output, log) ports "
- f"({control_port}, {output_port}, {logging_port}).\n"
+ f"({control_port}, {output_port}, {log_port}).\n"
f"Press Ctrl-C to quit."
)
- server.startNexus()
+ try:
+ server.start_nexus(server.poll_queues, poll_function=server.poll_kernel)
+ except Exception as e:
+ print(f"CLI-started server run encountered uncaught error {e}")
+ logging.error(f"CLI-started server run encountered uncaught error {e}")
+ raise e
if args.actor_path:
for p in args.actor_path:
@@ -254,7 +240,7 @@ def run_server(args):
def run_list(args, printit=True):
out_list = []
- pattern = re.compile(r"(improv (run|client|server)|plasma_store|redis-server)")
+ pattern = re.compile(r"(improv (run|client|server)|redis-server)")
# mp_pattern = re.compile(r"-c from multiprocessing") # TODO is this right?
for proc in psutil.process_iter(["pid", "name", "cmdline"]):
if proc.info["cmdline"]:
@@ -281,22 +267,38 @@ def run_cleanup(args, headless=False):
if res.lower() == "y":
for proc in proc_list:
- if not proc.status() == psutil.STATUS_STOPPED:
- logging.info(
- f"process {proc.pid} {proc.name()}"
- f" has status {proc.status()}. Interrupting."
- )
- try:
+ try:
+ if not proc.status() == psutil.STATUS_STOPPED:
+ logging.info(
+ f"process {proc.pid} {proc.name()}"
+ f" has status {proc.status()}. Interrupting."
+ )
proc.send_signal(signal.SIGINT)
- except psutil.NoSuchProcess:
- pass
+ except psutil.NoSuchProcess:
+ pass
gone, alive = psutil.wait_procs(proc_list, timeout=3)
for p in alive:
- p.send_signal(signal.SIGINT)
try:
+ p.terminate()
p.wait(timeout=10)
except psutil.TimeoutExpired as e:
logging.warning(f"{e}: Process did not exit on time.")
+ try:
+ p.kill()
+ except psutil.NoSuchProcess as e:
+ logging.warning(
+ f"{e}: Process exited after wait timeout"
+ f" but before kill signal attempted."
+ )
+ # this happens sometimes because Nexus uses gracious
+ # timeout periods.
+ except psutil.NoSuchProcess as e:
+ logging.warning(
+ f"{e}: Process exited after wait timeout"
+ f" but before kill signal attempted."
+ )
+ # this happens sometimes because Nexus uses gracious
+ # timeout periods.
else:
if not headless:
@@ -313,19 +315,28 @@ def run(args, timeout=10):
server_opts = [
"improv",
"server",
- "-c",
- str(args.control_port),
- "-o",
- str(args.output_port),
- "-l",
- str(args.logging_port),
"-f",
args.logfile,
]
+
+ if args.control_port:
+ server_opts.append("-c")
+ server_opts.append(str(args.control_port))
+
+ if args.output_port:
+ server_opts.append("-o")
+ server_opts.append(str(args.output_port))
+
+ if args.logging_port:
+ server_opts.append("-l")
+ server_opts.append(str(args.logging_port))
+
server_opts.extend(apath_opts)
server_opts.append(args.configfile)
- with open(args.logfile, mode="a+") as logfile:
+ print(" ".join(server_opts))
+
+ with open("improv-debug.log", mode="a+") as logfile:
server = subprocess.Popen(server_opts, stdout=logfile, stderr=logfile)
# wait for server to start up
@@ -338,7 +349,9 @@ def run(args, timeout=10):
run_client(args)
try:
- server.wait(timeout=2)
+ wait_timeout = 60
+ print(f"Waiting {wait_timeout} seconds for Nexus to complete shutdown.")
+ server.wait(timeout=wait_timeout)
except subprocess.TimeoutExpired:
print("Cleaning up the hard way. May have exited dirty.")
server.terminate()
@@ -354,9 +367,9 @@ def get_server_ports(args, timeout):
time_now = 0
ports = None
while time_now < timeout:
- server_start_time = _server_start_logged(args.logfile)
+ server_start_time = _server_start_logged("improv-debug.log")
if server_start_time and server_start_time >= curr_dt:
- ports = _get_ports(args.logfile)
+ ports = _get_ports("improv-debug.log")
if ports:
break
@@ -365,12 +378,12 @@ def get_server_ports(args, timeout):
if not server_start_time:
print(
- f"Unable to read server start time from {args.logfile}.\n"
+ f"Unable to read server start time from {'improv-debug.log'}.\n"
"This may be because the server could not be started or "
"did not log its activity."
)
elif not ports:
- print(f"Unable to read ports from {args.logfile}.")
+ print(f"Unable to read ports from {'improv-debug.log'}.")
return ports
@@ -393,13 +406,16 @@ def _get_ports(logfile):
# read logfile to get ports
with open(logfile, mode="r") as logfile:
contents = logfile.read()
+ return _read_log_contents_for_ports(contents)
- pattern = re.compile(r"(?<=\(control, output, log\) ports \()\d*, \d*, \d*")
- # get most recent match (log file may contain old runs)
- port_str_list = pattern.findall(contents)
- if port_str_list:
- port_str = port_str_list[-1]
- return (int(p) for p in port_str.split(", "))
- else:
- return None
+def _read_log_contents_for_ports(logfile_contents):
+ pattern = re.compile(r"(?<=\(control, output, log\) ports \()\d*, \d*, \d*")
+
+ # get most recent match (log file may contain old runs)
+ port_str_list = pattern.findall(logfile_contents)
+ if port_str_list:
+ port_str = port_str_list[-1]
+ return (int(p) for p in port_str.split(", "))
+ else:
+ return None
diff --git a/improv/config.py b/improv/config.py
index 5f73c18b..59446fa5 100644
--- a/improv/config.py
+++ b/improv/config.py
@@ -6,46 +6,48 @@
logger = logging.getLogger(__name__)
+class CannotCreateConfigException(Exception):
+ def __init__(self, msg):
+ super().__init__("Cannot create config: {}".format(msg))
+
+
class Config:
"""Handles configuration and logs of configs for
the entire server/processing pipeline.
"""
- def __init__(self, configFile):
- if configFile is None:
- logger.error("Need to specify a config file")
- raise Exception
- else:
- # Reading config from other yaml file
- self.configFile = configFile
-
- with open(self.configFile, "r") as ymlfile:
- cfg = yaml.safe_load(ymlfile)
-
- try:
- if "settings" in cfg:
- self.settings = cfg["settings"]
- else:
- self.settings = {}
+ def __init__(self, config_file):
+ self.actors = {}
+ self.connections = {}
+ self.hasGUI = False
+ self.config_file = config_file
- if "use_watcher" not in self.settings:
- self.settings["use_watcher"] = False
+ with open(self.config_file, "r") as ymlfile:
+ self.config = yaml.safe_load(ymlfile)
- except TypeError:
- if cfg is None:
- logger.error("Error: The config file is empty")
+ if self.config is None:
+ logger.error("The config file is empty")
+ raise CannotCreateConfigException("The config file is empty")
- if type(cfg) is not dict:
+ if type(self.config) is not dict:
logger.error("Error: The config file is not in dictionary format")
raise TypeError
- self.config = cfg
+ def parse_config(self):
+ self.populate_defaults()
+ self.validate_config()
- self.actors = {}
- self.connections = {}
- self.hasGUI = False
+ self.settings = self.config["settings"]
+ self.redis_config = self.config["redis_config"]
+
+ def populate_defaults(self):
+ self.populate_settings_defaults()
+ self.popoulate_redis_defaults()
+
+ def validate_config(self):
+ self.validate_redis_config()
- def createConfig(self):
+ def create_config(self):
"""Read yaml config file and create config for Nexus
TODO: check for config file compliance, error handle it
beyond what we have below.
@@ -53,9 +55,6 @@ def createConfig(self):
cfg = self.config
for name, actor in cfg["actors"].items():
- if name in self.actors.keys():
- raise RepeatedActorError(name)
-
packagename = actor.pop("package")
classname = actor.pop("class")
@@ -63,11 +62,15 @@ def createConfig(self):
__import__(packagename, fromlist=[classname])
mod = import_module(packagename)
- clss = getattr(mod, classname)
- sig = signature(clss)
- configModule = ConfigModule(name, packagename, classname, options=actor)
- sig.bind(configModule.options)
+ actor_class = getattr(mod, classname)
+ sig = signature(actor_class)
+ config_module = ConfigModule(
+ name, packagename, classname, options=actor
+ )
+ sig.bind(config_module.options)
+ # TODO: this is not trivial to test, since our code formatting
+ # tools won't allow a file with a syntax error to exist
except SyntaxError as e:
logger.error(f"Error: syntax error when initializing actor {name}: {e}")
return -1
@@ -87,101 +90,123 @@ def createConfig(self):
except TypeError:
logger.error("Error: Invalid arguments passed")
params = ""
- for parameter in sig.parameters:
- params = params + " " + parameter.name
+ for param_name, param in sig.parameters.items():
+ params = params + ", " + param.name
logger.warning("Expected Parameters:" + params)
return -1
- except Exception as e:
+ except Exception as e: # TODO: figure out how to test this
logger.error(f"Error: {e}")
return -1
if "GUI" in name:
logger.info(f"Config detected a GUI actor: {name}")
self.hasGUI = True
- self.gui = configModule
+ self.gui = config_module
else:
- self.actors.update({name: configModule})
+ self.actors.update({name: config_module})
for name, conn in cfg["connections"].items():
- if name in self.connections.keys():
- raise RepeatedConnectionsError(name)
-
self.connections.update({name: conn})
- if "datastore" in cfg.keys():
- self.datastore = cfg["datastore"]
-
return 0
- def addParams(self, type, param):
- """Function to add paramter param of type type
- TODO: Future work
- """
- pass
-
- def saveActors(self):
+ def save_actors(self):
"""Saves the actors config to a specific file."""
wflag = True
- saveFile = self.configFile.split(".")[0]
+ saveFile = self.config_file.split(".")[0]
pathName = saveFile + "_actors.yaml"
for a in self.actors.values():
- wflag = a.saveConfigModules(pathName, wflag)
-
- def use_plasma(self):
- return "plasma_config" in self.config.keys()
-
- def get_redis_port(self):
- if self.redis_port_specified():
- return self.config["redis_config"]["port"]
- else:
- return Config.get_default_redis_port()
-
- def redis_port_specified(self):
- if "redis_config" in self.config.keys():
- return "port" in self.config["redis_config"]
- return False
-
- def redis_saving_enabled(self):
- if "redis_config" in self.config.keys():
- return (
- self.config["redis_config"]["enable_saving"]
- if "enable_saving" in self.config["redis_config"]
- else None
+ wflag = a.save_config_modules(pathName, wflag)
+
+ def populate_settings_defaults(self):
+ if "settings" not in self.config:
+ self.config["settings"] = {}
+
+ if "store_size" not in self.config["settings"]:
+ self.config["settings"]["store_size"] = 250_000_000
+ if "control_port" not in self.config["settings"]:
+ self.config["settings"]["control_port"] = 5555
+ if "output_port" not in self.config["settings"]:
+ self.config["settings"]["output_port"] = 5556
+ if "actor_in_port" not in self.config["settings"]:
+ self.config["settings"]["actor_in_port"] = 0
+ if "harvest_data_from_memory" not in self.config["settings"]:
+ self.config["settings"]["harvest_data_from_memory"] = None
+
+ def popoulate_redis_defaults(self):
+ if "redis_config" not in self.config:
+ self.config["redis_config"] = {}
+
+ if "enable_saving" not in self.config["redis_config"]:
+ self.config["redis_config"]["enable_saving"] = None
+ if "aof_dirname" not in self.config["redis_config"]:
+ self.config["redis_config"]["aof_dirname"] = None
+ if "generate_ephemeral_aof_dirname" not in self.config["redis_config"]:
+ self.config["redis_config"]["generate_ephemeral_aof_dirname"] = False
+ if "fsync_frequency" not in self.config["redis_config"]:
+ self.config["redis_config"]["fsync_frequency"] = None
+
+ # enable saving automatically if the user configured a saving option
+ if (
+ self.config["redis_config"]["aof_dirname"]
+ or self.config["redis_config"]["generate_ephemeral_aof_dirname"]
+ or self.config["redis_config"]["fsync_frequency"]
+ ) and self.config["redis_config"]["enable_saving"] is None:
+ self.config["redis_config"]["enable_saving"] = True
+
+ if "port" not in self.config["redis_config"]:
+ self.config["redis_config"]["port"] = 6379
+
+ def validate_redis_config(self):
+ fsync_name_dict = {
+ "every_write": "always",
+ "every_second": "everysec",
+ "no_schedule": "no",
+ }
+ if (
+ self.config["redis_config"]["aof_dirname"]
+ and self.config["redis_config"]["generate_ephemeral_aof_dirname"]
+ ):
+ logger.error(
+ "Cannot both generate a unique dirname and use the one provided."
)
-
- def generate_ephemeral_aof_dirname(self):
- if "redis_config" in self.config.keys():
- return (
- self.config["redis_config"]["generate_ephemeral_aof_dirname"]
- if "generate_ephemeral_aof_dirname" in self.config["redis_config"]
- else None
- )
- return False
-
- def get_redis_aof_dirname(self):
- if "redis_config" in self.config.keys():
- return (
- self.config["redis_config"]["aof_dirname"]
- if "aof_dirname" in self.config["redis_config"]
- else None
+ raise Exception("Cannot use unique dirname and use the one provided.")
+
+ if (
+ self.config["redis_config"]["aof_dirname"]
+ or self.config["redis_config"]["generate_ephemeral_aof_dirname"]
+ or self.config["redis_config"]["fsync_frequency"]
+ ):
+ if not self.config["redis_config"]["enable_saving"]:
+ logger.error(
+ "Invalid configuration. Cannot save to disk with saving disabled."
+ )
+ raise Exception("Cannot persist to disk with saving disabled.")
+
+ if self.config["redis_config"]["fsync_frequency"] and self.config[
+ "redis_config"
+ ]["fsync_frequency"] not in [
+ "every_write",
+ "every_second",
+ "no_schedule",
+ ]:
+ logger.error(
+ f"Cannot use unknown fsync frequency "
+ f'{self.config["redis_config"]["fsync_frequency"]}'
)
- return None
-
- def get_redis_fsync_frequency(self):
- if "redis_config" in self.config.keys():
- frequency = (
- self.config["redis_config"]["fsync_frequency"]
- if "fsync_frequency" in self.config["redis_config"]
- else None
+ raise Exception(
+ f"Cannot use unknown fsync frequency "
+ f'{self.config["redis_config"]["fsync_frequency"]}'
)
- return frequency
+ if self.config["redis_config"]["fsync_frequency"] is None:
+ self.config["redis_config"]["fsync_frequency"] = "no_schedule"
- @staticmethod
- def get_default_redis_port():
- return "6379"
+ self.config["redis_config"]["fsync_frequency"] = (fsync_name_dict)[
+ self.config["redis_config"]["fsync_frequency"]
+ ]
class ConfigModule:
@@ -191,11 +216,11 @@ def __init__(self, name, packagename, classname, options=None):
self.classname = classname
self.options = options
- def saveConfigModules(self, pathName, wflag):
+ def save_config_modules(self, path_name, wflag):
"""Loops through each actor to save the modules to the config file.
Args:
- pathName:
+ path_name:
wflag (bool):
Returns:
@@ -213,32 +238,7 @@ def saveConfigModules(self, pathName, wflag):
for key, value in self.options.items():
cfg[self.name].update({key: value})
- with open(pathName, writeOption) as file:
+ with open(path_name, writeOption) as file:
yaml.dump(cfg, file)
return wflag
-
-
-class RepeatedActorError(Exception):
- def __init__(self, repeat):
- super().__init__()
-
- self.name = "RepeatedActorError"
- self.repeat = repeat
-
- self.message = 'Actor name has already been used: "{}"'.format(repeat)
-
- def __str__(self):
- return self.message
-
-
-class RepeatedConnectionsError(Exception):
- def __init__(self, repeat):
- super().__init__()
- self.name = "RepeatedConnectionsError"
- self.repeat = repeat
-
- self.message = 'Connection name has already been used: "{}"'.format(repeat)
-
- def __str__(self):
- return self.message
diff --git a/improv/harvester.py b/improv/harvester.py
new file mode 100644
index 00000000..ba66bdf6
--- /dev/null
+++ b/improv/harvester.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+import logging
+import signal
+import time
+
+import zmq
+
+from improv.link import ZmqLink
+from improv.log import ZmqLogHandler
+from improv.store import RedisStoreInterface
+from zmq import SocketOption
+
+from improv.messaging import HarvesterInfoMsg
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def bootstrap_harvester(
+ nexus_hostname,
+ nexus_port,
+ redis_hostname,
+ redis_port,
+ broker_hostname,
+ broker_port,
+ logger_hostname,
+ logger_port,
+):
+ harvester = RedisHarvester(
+ nexus_hostname,
+ nexus_port,
+ redis_hostname,
+ redis_port,
+ broker_hostname,
+ broker_port,
+ logger_hostname,
+ logger_port,
+ )
+ harvester.establish_connections()
+ harvester.register_with_nexus()
+ harvester.serve(harvester.collect)
+
+
+class RedisHarvester:
+ def __init__(
+ self,
+ nexus_hostname,
+ nexus_comm_port,
+ redis_hostname,
+ redis_port,
+ broker_hostname,
+ broker_port,
+ logger_hostname,
+ logger_port,
+ ):
+ self.link: ZmqLink | None = None
+ self.running = True
+ self.nexus_hostname: str = nexus_hostname
+ self.nexus_comm_port: int = nexus_comm_port
+ self.redis_hostname: str = redis_hostname
+ self.redis_port: int = redis_port
+ self.broker_hostname: str = broker_hostname
+ self.broker_port: int = broker_port
+ self.zmq_context: zmq.Context | None = None
+ self.nexus_socket: zmq.Socket | None = None
+ self.sub_port: int | None = None
+ self.sub_socket: zmq.Socket | None = None
+ self.store_client: RedisStoreInterface | None = None
+ self.logger_hostname: str = logger_hostname
+ self.logger_port: int = logger_port
+
+ logger.addHandler(ZmqLogHandler(logger_hostname, logger_port))
+
+ signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+ for s in signals:
+ signal.signal(s, self.stop)
+
+ def establish_connections(self):
+ logger.info("Registering with Nexus")
+ # connect to nexus
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ self.nexus_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_socket.connect(f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}")
+
+ self.sub_socket = self.zmq_context.socket(zmq.SUB)
+ self.sub_socket.connect(f"tcp://{self.broker_hostname}:{self.broker_port}")
+ sub_port_string = self.sub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ self.sub_port = int(sub_port_string.split(":")[-1])
+ self.sub_socket.subscribe("") # receive all incoming messages
+
+ self.store_client = RedisStoreInterface(
+ "harvester", self.redis_port, self.redis_hostname
+ )
+
+ self.link = ZmqLink(self.sub_socket, "harvester", "")
+
+ # TODO: this is a small function that is currently easy to verify
+ # it could be unit-tested with a multiprocess that spins up a socket
+ # which this can communicate with, but isn't currenlty worth the time
+ def register_with_nexus(self):
+ port_info = HarvesterInfoMsg(
+ "harvester",
+ "Ports up and running, ready to serve messages",
+ )
+
+ self.nexus_socket.send_pyobj(port_info)
+ self.nexus_socket.recv_pyobj()
+
+ return
+
+ def serve(self, message_process_func, *args, **kwargs):
+ logger.info("Harvester beginning harvest")
+ while self.running:
+ message_process_func(*args, **kwargs)
+ self.shutdown()
+
+ def collect(self, *args, **kwargs):
+ db_info = self.store_client.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+ if used_max_ratio > 0.75:
+ while self.running and (used_max_ratio > 0.50):
+ try:
+ key = self.link.get(timeout=100) # 100ms
+ self.store_client.client.delete(key)
+ except TimeoutError:
+ pass
+ db_info = self.store_client.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+ time.sleep(0.1)
+ return
+
+ def shutdown(self):
+ for handler in logger.handlers:
+ if isinstance(handler, ZmqLogHandler):
+ handler.close()
+ logger.removeHandler(handler)
+
+ if self.sub_socket:
+ self.sub_socket.close(linger=0)
+
+ if self.nexus_socket:
+ self.nexus_socket.close(linger=0)
+
+ if self.zmq_context:
+ self.zmq_context.destroy(linger=0)
+
+ def stop(self, signum, frame):
+ self.running = False
+ logger.info(f"Harvester shutting down due to signal {signum}")
diff --git a/improv/link.py b/improv/link.py
index 5014204b..81fee4e5 100644
--- a/improv/link.py
+++ b/improv/link.py
@@ -1,91 +1,27 @@
import asyncio
+import json
import logging
-from multiprocessing import Manager, cpu_count
+from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
-from concurrent.futures._base import CancelledError
+
+import zmq
+from zmq import SocketOption
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-def Link(name, start, end):
- """Function to construct a queue that Nexus uses for
- inter-process (actor) signaling and information passing.
-
- A Link has an internal queue that can be synchronous (put, get)
- as inherited from multiprocessing.Manager.Queue
- or asynchronous (put_async, get_async) using async executors.
-
- Args:
- See AsyncQueue constructor
-
- Returns:
- AsyncQueue: queue for communicating between actors and with Nexus
- """
-
- m = Manager()
- q = AsyncQueue(m.Queue(maxsize=0), name, start, end)
- return q
-
-
-class AsyncQueue(object):
- """Single-output and asynchronous queue class.
-
- Attributes:
- queue:
- real_executor:
- cancelled_join: boolean
- name:
- start:
- end:
- status:
- result:
- num:
- dict:
- """
-
- def __init__(self, q, name, start, end):
- """Constructor for the queue class.
-
- Args:
- q (Queue): A queue from the Manager class
- name (str): String description of this queue
- start (str): The producer (input) actor name for the queue
- end (str): The consumer (output) actor name for the queue
-
- """
- self.queue = q
- self.real_executor = None
- self.cancelled_join = False
-
+class ZmqLink:
+ def __init__(self, socket: zmq.Socket, name, topic=None):
self.name = name
- self.start = start
- self.end = end
-
+ self.real_executor = None
+ self.socket = socket # this must already be set up
+ self.socket_type = self.socket.getsockopt(SocketOption.TYPE)
+ if self.socket_type == zmq.PUB and topic is None:
+ raise Exception("Cannot open PUB link without topic")
self.status = "pending"
self.result = None
-
- def getStart(self):
- """Gets the starting actor.
-
- The starting actor is the actor that is at the tail of the link.
- This actor is the one that gives output.
-
- Returns:
- start (str): The starting actor name
- """
- return self.start
-
- def getEnd(self):
- """Gets the ending actor.
-
- The ending actor is the actor that is at the head of the link.
- This actor is the one that takes input.
-
- Returns:
- end (str): The ending actor name
- """
- return self.end
+ self.topic = topic
@property
def _executor(self):
@@ -93,195 +29,57 @@ def _executor(self):
self.real_executor = ThreadPoolExecutor(max_workers=cpu_count())
return self.real_executor
- def __getstate__(self):
- """Gets a dictionary of attributes.
-
- This function gets a dictionary, with keys being the names of
- the attributes, and values being the values of the attributes.
+ def get(self, timeout=None):
+ if timeout is None:
+ if self.socket_type == zmq.SUB:
+ res = self.socket.recv_multipart()
+ return json.loads(res[1].decode("utf-8"))
- Returns:
- self_dict (dict): A dictionary containing attributes.
- """
- self_dict = self.__dict__
- self_dict["_real_executor"] = None
- return self_dict
+ return self.socket.recv_pyobj()
- def __getattr__(self, name):
- """Gets the attribute specified by "name".
+ msg_ready = self.socket.poll(timeout=timeout)
+ if msg_ready == 0:
+ raise TimeoutError
- Args:
- name (str): Name of the attribute to be returned.
+ if self.socket_type == zmq.SUB:
+ res = self.socket.recv_multipart()
+ return json.loads(res[1])
- Raises:
- AttributeError: Restricts the available attributes to a
- specific list. This error is raised if a different attribute
- of the queue is requested.
+ return self.socket.recv_pyobj()
- TODO:
- Don't raise this?
+ def get_nowait(self):
+ return self.get(timeout=0)
- Returns:
- (object): Value of the attribute specified by "name".
- """
- if name in ["qsize", "empty", "full", "get", "get_nowait", "close"]:
- return getattr(self.queue, name)
+ def put(
+ self, item
+ ): # TODO: is it a problem we're implicitly handling inputs differently here?
+ if self.socket_type == zmq.PUB:
+ self.socket.send_multipart(
+ [self.topic.encode("utf-8"), json.dumps(item).encode("utf-8")]
+ )
else:
- cn = self.__class__.__name__
- raise AttributeError("{} object has no attribute {}".format(cn, name))
-
- def __repr__(self):
- """String representation for Link.
-
- Returns:
- (str): "Link" followed by the name given in the constructor.
- """
- return "Link " + self.name
-
- def put(self, item):
- """Function wrapper for put.
-
- Args:
- item (object): Any item that can be sent through a queue
- """
- self.queue.put(item)
-
- def put_nowait(self, item):
- """Function wrapper for put without waiting
-
- Args:
- item (object): Any item that can be sent through a queue
- """
- self.queue.put_nowait(item)
+ self.socket.send_pyobj(item)
async def put_async(self, item):
- """Coroutine for an asynchronous put
-
- It adds the put request to the event loop and awaits.
-
- Args:
- item (object): Any item that can be sent through a queue
-
- Returns:
- Awaitable or result of the put
- """
loop = asyncio.get_event_loop()
- try:
- res = await loop.run_in_executor(self._executor, self.put, item)
- return res
- except EOFError:
- logger.warn("Link probably killed (EOF)")
- except FileNotFoundError:
- logger.warn("probably killed (file not found)")
+ res = await loop.run_in_executor(self._executor, self.put, item)
+ return res
async def get_async(self):
- """Coroutine for an asynchronous get
-
- It adds the get request to the event loop and awaits, setting
- the status to pending. Once the get has returned, it returns the
- result of the get and sets its status as done.
-
- Explicitly passes any exceptions to not hinder execution.
- Errors are logged with the get_async tag.
-
- Returns:
- Awaitable or result of the get.
-
- Raises:
- CancelledError: task is cancelled
- EOFError:
- FileNotFoundError:
- Exception:
- """
loop = asyncio.get_event_loop()
self.status = "pending"
- try:
- self.result = await loop.run_in_executor(self._executor, self.get)
- self.status = "done"
- return self.result
- except CancelledError:
- logger.info("Task {} Canceled".format(self.name))
- except EOFError:
- logger.info("probably killed")
- except FileNotFoundError:
- logger.info("probably killed")
- except Exception as e:
- logger.exception("Error in get_async: {}".format(e))
-
- def cancel_join_thread(self):
- """Function wrapper for cancel_join_thread."""
- self._cancelled_join = True
- self._queue.cancel_join_thread()
-
- def join_thread(self):
- """Function wrapper for join_thread."""
- self._queue.join_thread()
- if self._real_executor and not self._cancelled_join:
- self._real_executor.shutdown()
-
-
-def MultiLink(name, start, end):
- """Function to generate links for the multi-output queue case.
-
- Args:
- See constructor for AsyncQueue or MultiAsyncQueue
-
- Returns:
- MultiAsyncQueue: Producer end of the queue
- List: AsyncQueues for consumers
- """
- m = Manager()
-
- q_out = []
- for endpoint in end:
- q = AsyncQueue(m.Queue(maxsize=0), name, start, endpoint)
- q_out.append(q)
-
- q = MultiAsyncQueue(m.Queue(maxsize=0), q_out, name, start, end)
-
- return q, q_out
-
-
-class MultiAsyncQueue(AsyncQueue):
- """Extension of AsyncQueue to have multiple endpoints.
-
- Inherits from AsyncQueue.
- A single producer queue's 'put' is copied to multiple consumer's
- queues, q_in is the producer queue, q_out are the consumer queues.
-
- TODO:
- Test the async nature of this group of queues
- """
-
- def __init__(self, q_in, q_out, name, start, end):
- self.queue = q_in
- self.output = q_out
-
- self.real_executor = None
- self.cancelled_join = False
-
- self.name = name
- self.start = start
- self.end = end[0]
- self.status = "pending"
- self.result = None
-
- def __repr__(self):
- return "MultiLink " + self.name
-
- def __getattr__(self, name):
- # Remove put and put_nowait and define behavior specifically
- # TODO: remove get capability?
- if name in ["qsize", "empty", "full", "get", "get_nowait", "close"]:
- return getattr(self.queue, name)
- else:
- raise AttributeError(
- "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)
- )
-
- def put(self, item):
- for q in self.output:
- q.put(item)
-
- def put_nowait(self, item):
- for q in self.output:
- q.put_nowait(item)
+ # try:
+ self.result = await loop.run_in_executor(self._executor, self.get)
+ self.status = "done"
+ return self.result
+ # TODO: explicitly commenting these out because testing them is hard.
+ # It's better to let them bubble up so that we don't miss them
+ # due to being caught without a clear reproducible mechanism
+ # except CancelledError:
+ # logger.info("Task {} Canceled".format(self.name))
+ # except EOFError:
+ # logger.info("probably killed")
+ # except FileNotFoundError:
+ # logger.info("probably killed")
+ # except Exception as e:
+ # logger.exception("Error in get_async: {}".format(e))
diff --git a/improv/log.py b/improv/log.py
new file mode 100644
index 00000000..70373d57
--- /dev/null
+++ b/improv/log.py
@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+import logging
+import signal
+from logging import handlers
+from logging.handlers import QueueHandler
+
+import zmq
+from zmq import SocketOption
+from zmq.log.handlers import PUBHandler
+
+from improv.messaging import LogInfoMsg
+
+local_log = logging.getLogger(__name__)
+
+DEBUG = True
+
+# TODO: ideally there should be some kind of drain at shutdown time
+# so we don't miss any log messages, but that would make shutdown
+# also take longer. TBD?
+
+
+def bootstrap_log_server(
+ nexus_hostname, nexus_port, log_filename="global.log", logger_pull_port=None
+):
+ if DEBUG:
+ local_log.addHandler(logging.FileHandler("log_server.log"))
+ try:
+ log_server = LogServer(
+ nexus_hostname, nexus_port, log_filename, logger_pull_port
+ )
+ log_server.register_with_nexus()
+ log_server.serve(log_server.read_and_log_message)
+ except Exception as e:
+ local_log.error(e)
+ for handler in local_log.handlers:
+ handler.close()
+
+
+class ZmqPullListener(handlers.QueueListener):
+ def __init__(self, ctx, /, *handlers, **kwargs):
+ self.sentinel = False
+ self.ctx = ctx
+ self.pull_socket = self.ctx.socket(zmq.PULL)
+ self.pull_socket.bind("tcp://*:0")
+ pull_port_string = self.pull_socket.getsockopt_string(
+ SocketOption.LAST_ENDPOINT
+ )
+ self.pull_port = int(pull_port_string.split(":")[-1])
+ super().__init__(self.pull_socket, *handlers, **kwargs)
+
+ def dequeue(self, block=True):
+ msg = None
+ while msg is None:
+ if self.sentinel:
+ return handlers.QueueListener._sentinel
+ msg_ready = self.queue.poll(timeout=1000)
+ if msg_ready != 0:
+ msg = self.queue.recv_json()
+ return logging.makeLogRecord(msg)
+
+ def enqueue_sentinel(self):
+ self.sentinel = True
+
+
+class ZmqLogHandler(QueueHandler):
+ def __init__(self, hostname, port, ctx=None):
+ self.ctx = ctx if ctx else zmq.Context()
+ self.ctx.setsockopt(SocketOption.LINGER, 0)
+ self.socket = self.ctx.socket(zmq.PUSH)
+ self.socket.connect(f"tcp://{hostname}:{port}")
+ super().__init__(self.socket)
+
+ def enqueue(self, record):
+ self.queue.send_json(record.__dict__)
+
+ def close(self):
+ self.queue.close(linger=0)
+
+
+class LogServer:
+ def __init__(self, nexus_hostname, nexus_comm_port, log_filename, pub_port):
+ self.running = True
+ self.pub_port: int | None = pub_port if pub_port else 0
+ self.pub_socket: zmq.Socket | None = None
+ self.log_filename = log_filename
+ self.nexus_hostname: str = nexus_hostname
+ self.nexus_comm_port: int = nexus_comm_port
+ self.zmq_context: zmq.Context | None = None
+ self.nexus_socket: zmq.Socket | None = None
+ self.pull_socket: zmq.Socket | None = None
+ self.listener: ZmqPullListener | None = None
+
+ signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+ for s in signals:
+ signal.signal(s, self.stop)
+
+ def register_with_nexus(self):
+ # connect to nexus
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ self.nexus_socket = self.zmq_context.socket(zmq.REQ)
+ self.nexus_socket.connect(f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}")
+
+ self.pub_socket = self.zmq_context.socket(zmq.PUB)
+ self.pub_socket.bind(f"tcp://*:{self.pub_port}")
+ pub_port_string = self.pub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ self.pub_port = int(pub_port_string.split(":")[-1])
+
+ self.listener = ZmqPullListener(
+ self.zmq_context,
+ # logging.StreamHandler(sys.stdout),
+ logging.FileHandler(self.log_filename),
+ PUBHandler(self.pub_socket, self.zmq_context, "nexus_logging"),
+ )
+
+ self.listener.start()
+
+ local_log.info("logger started listening")
+
+ port_info = LogInfoMsg(
+ "broker",
+ self.listener.pull_port,
+ self.pub_port,
+ "Port up and running, ready to log messages",
+ )
+
+ self.nexus_socket.send_pyobj(port_info)
+ local_log.info("logger sent message to nexus")
+ self.nexus_socket.recv_pyobj()
+
+ local_log.info("logger got message from nexus")
+
+ return
+
+ def serve(self, log_func):
+ local_log.info("logger serving")
+ while self.running:
+ log_func() # this is more testable but may have a performance overhead
+ self.shutdown()
+
+ def read_and_log_message(self): # receive and send back out
+ pass
+
+ def shutdown(self):
+ if self.listener:
+ self.listener.stop()
+
+ for handler in self.listener.handlers:
+ try:
+ handler.close()
+ except Exception as e:
+ local_log.error(e)
+
+ if self.pull_socket:
+ self.pull_socket.close(linger=0)
+
+ if self.nexus_socket:
+ self.nexus_socket.close(linger=0)
+
+ if self.pub_socket:
+ self.pub_socket.close(linger=0)
+
+ for handler in local_log.handlers:
+ handler.close()
+
+ if self.zmq_context:
+ self.zmq_context.destroy(linger=0)
+
+ def stop(self, signum, frame):
+ local_log.info(f"Log server shutting down due to signal {signum}")
+
+ self.running = False
diff --git a/improv/messaging.py b/improv/messaging.py
new file mode 100644
index 00000000..29b0461f
--- /dev/null
+++ b/improv/messaging.py
@@ -0,0 +1,71 @@
+class ActorStateMsg:
+ def __init__(self, actor_name, status, nexus_in_port, info):
+ self.actor_name = actor_name
+ self.status = status
+ self.nexus_in_port = nexus_in_port
+ self.info = info
+
+
+class ActorStateReplyMsg:
+ def __init__(self, actor_name, status, info):
+ self.actor_name = actor_name
+ self.status = status
+ self.info = info
+
+
+class ActorSignalMsg:
+ def __init__(self, actor_name, signal, info):
+ self.actor_name = actor_name
+ self.signal = signal
+ self.info = info
+
+
+class ActorSignalReplyMsg:
+ def __init__(self, actor_name, signal, status, info):
+ self.actor_name = actor_name
+ self.signal = signal
+ self.status = status
+ self.info = info
+
+
+class BrokerInfoMsg:
+ def __init__(self, name, pub_port, sub_port, info):
+ self.name = name
+ self.pub_port = pub_port
+ self.sub_port = sub_port
+ self.info = info
+
+
+class BrokerInfoReplyMsg:
+ def __init__(self, name, status, info):
+ self.name = name
+ self.status = status
+ self.info = info
+
+
+class LogInfoMsg:
+ def __init__(self, name, pull_port, pub_port, info):
+ self.name = name
+ self.pull_port = pull_port
+ self.pub_port = pub_port
+ self.info = info
+
+
+class LogInfoReplyMsg:
+ def __init__(self, name, status, info):
+ self.name = name
+ self.status = status
+ self.info = info
+
+
+class HarvesterInfoMsg:
+ def __init__(self, name, info):
+ self.name = name
+ self.info = info
+
+
+class HarvesterInfoReplyMsg:
+ def __init__(self, name, status, info):
+ self.name = name
+ self.status = status
+ self.info = info
diff --git a/improv/nexus.py b/improv/nexus.py
index b5fe19c4..8089f666 100644
--- a/improv/nexus.py
+++ b/improv/nexus.py
@@ -1,4 +1,6 @@
-import os
+from __future__ import annotations
+
+import multiprocessing
import time
import uuid
import signal
@@ -7,46 +9,118 @@
import concurrent
import subprocess
-from queue import Full
from datetime import datetime
-from multiprocessing import Process, get_context
+from multiprocessing import get_context
from importlib import import_module
+import zmq as zmq_sync
import zmq.asyncio as zmq
-from zmq import PUB, REP, SocketOption
-
-from improv.store import StoreInterface, RedisStoreInterface, PlasmaStoreInterface
-from improv.actor import Signal
+from zmq import PUB, REP, REQ, SocketOption
+
+from improv import log
+from improv.broker import bootstrap_broker
+from improv.harvester import bootstrap_harvester
+from improv.log import bootstrap_log_server
+from improv.messaging import (
+ ActorStateMsg,
+ ActorStateReplyMsg,
+ ActorSignalMsg,
+ BrokerInfoReplyMsg,
+ BrokerInfoMsg,
+ LogInfoMsg,
+ LogInfoReplyMsg,
+ HarvesterInfoMsg,
+ HarvesterInfoReplyMsg,
+)
+from improv.store import StoreInterface, RedisStoreInterface
+from improv.actor import Signal, Actor, LinkInfo
from improv.config import Config
-from improv.link import Link, MultiLink
+
+ASYNC_DEBUG = False
logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)
+if logger.level == logging.DEBUG:
+ logger.addHandler(logging.StreamHandler())
+
+# TODO: redo docsctrings since things are pretty different now
+
+# TODO: socket setup can fail - need to check it
+
+# TODO: redo how actors register with nexus so that we get actor states
+# earlier
+
+
+class ConfigFileNotProvidedException(Exception):
+ def __init__(self):
+ super().__init__("Config file not provided")
-# TODO: Set up store.notify in async function (?)
+
+class ConfigFileNotValidException(Exception):
+ def __init__(self):
+ super().__init__("Config file not valid")
+
+
+class ActorState:
+ def __init__(
+ self, actor_name, status, nexus_in_port, hostname="localhost", sig_socket=None
+ ):
+ self.actor_name = actor_name
+ self.status = status
+ self.nexus_in_port = nexus_in_port
+ self.hostname = hostname
+ self.sig_socket = None
class Nexus:
"""Main server class for handling objects in improv"""
def __init__(self, name="Server"):
+ self.logger_in_port: int | None = None
+ self.zmq_sync_context: zmq_sync.Context | None = None
+ self.logfile: str | None = None
+ self.p_harvester: multiprocessing.Process | None = None
+ self.p_broker: multiprocessing.Process | None = None
+ self.actor_in_socket_port: int | None = None
+ self.actor_in_socket: zmq.Socket | None = None
+ self.in_socket: zmq.Socket | None = None
+ self.out_socket: zmq.Socket | None = None
+ self.zmq_context: zmq.Context | None = None
+ self.logger_pub_port: int | None = None
+ self.logger_pull_port: int | None = None
+ self.logger_in_socket: zmq.Socket | None = None
+ self.p_logger: multiprocessing.Process | None = None
+ self.broker_pub_port = None
+ self.broker_sub_port = None
+ self.broker_in_port: int | None = None
+ self.broker_in_socket: zmq_sync.Socket | None = None
+ self.actor_states: dict[str, ActorState | None] = dict()
self.redis_fsync_frequency = None
self.store = None
self.config = None
self.name = name
self.aof_dir = None
self.redis_saving_enabled = False
+ self.allow_setup = False
+ self.outgoing_topics = dict()
+ self.incoming_topics = dict()
+ self.data_queues = {}
+ self.actors = {}
+ self.flags = {}
+ self.processes: list[multiprocessing.Process] = []
def __str__(self):
return self.name
- def createNexus(
+ def create_nexus(
self,
file=None,
- use_watcher=None,
- store_size=10_000_000,
- control_port=0,
- output_port=0,
+ store_size=None,
+ control_port=None,
+ output_port=None,
+ log_server_pub_port=None,
+ actor_in_port=None,
+ logfile="global.log",
):
"""Function to initialize class variables based on config file.
@@ -56,101 +130,68 @@ def createNexus(
Args:
file (string): Name of the config file.
- use_watcher (bool): Whether to use watcher for the store.
store_size (int): initial store size
control_port (int): port number for input socket
output_port (int): port number for output socket
+ actor_in_port (int): port number for the socket which receives
+ actor communications
Returns:
string: "Shutting down", to notify start() that pollQueues has completed.
"""
+ self.logfile = logfile
+
curr_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"************ new improv server session {curr_dt} ************")
if file is None:
logger.exception("Need a config file!")
- raise Exception # TODO
- else:
- logger.info(f"Loading configuration file {file}:")
- self.loadConfig(file=file)
- with open(file, "r") as f: # write config file to log
- logger.info(f.read())
-
- # set config options loaded from file
- # in Python 3.9, can just merge dictionaries using precedence
- cfg = self.config.settings
- if "use_watcher" not in cfg:
- cfg["use_watcher"] = use_watcher
- if "store_size" not in cfg:
- cfg["store_size"] = store_size
- if "control_port" not in cfg or control_port != 0:
- cfg["control_port"] = control_port
- if "output_port" not in cfg or output_port != 0:
- cfg["output_port"] = output_port
-
- # set up socket in lieu of printing to stdout
- self.zmq_context = zmq.Context()
- self.zmq_context.setsockopt(SocketOption.LINGER, 1)
- self.out_socket = self.zmq_context.socket(PUB)
- self.out_socket.bind("tcp://*:%s" % cfg["output_port"])
- out_port_string = self.out_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
- cfg["output_port"] = int(out_port_string.split(":")[-1])
-
- self.in_socket = self.zmq_context.socket(REP)
- self.in_socket.bind("tcp://*:%s" % cfg["control_port"])
- in_port_string = self.in_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
- cfg["control_port"] = int(in_port_string.split(":")[-1])
-
- self.configure_redis_persistence()
+ raise ConfigFileNotProvidedException
- # default size should be system-dependent
- if self.config and self.config.use_plasma():
- self._startStoreInterface(store_size)
- else:
- self._startStoreInterface(store_size)
- logger.info("Redis server started")
-
- self.out_socket.send_string("StoreInterface started")
+ logger.info(f"Loading configuration file {file}:")
+ self.config = Config(config_file=file)
+ self.config.parse_config()
- # connect to store and subscribe to notifications
- logger.info("Create new store object")
- if self.config and self.config.use_plasma():
- self.store = PlasmaStoreInterface(store_loc=self.store_loc)
- else:
- self.store = StoreInterface(server_port_num=self.store_port)
- logger.info(f"Redis server connected on port {self.store_port}")
+ with open(file, "r") as f: # write config file to log
+ logger.info(f.read())
- self.store.subscribe()
+ logger.debug("Applying CLI parameter configuration overrides")
+ self.apply_cli_config_overrides(
+ store_size=store_size,
+ control_port=control_port,
+ output_port=output_port,
+ actor_in_port=actor_in_port,
+ )
- # TODO: Better logic/flow for using watcher as an option
- self.p_watch = None
- if cfg["use_watcher"]:
- self.startWatcher()
+ logger.debug("Setting up sockets")
+ self.set_up_sockets(actor_in_port=self.config.settings["actor_in_port"])
- # Create dicts for reading config and creating actors
- self.comm_queues = {}
- self.sig_queues = {}
- self.data_queues = {}
- self.actors = {}
- self.flags = {}
- self.processes = []
+ logger.debug("Setting up services")
+ self.start_improv_services(
+ log_server_pub_port=log_server_pub_port,
+ store_size=self.config.settings["store_size"],
+ )
- self.initConfig()
+ logger.debug("initializing config")
+ self.init_config()
self.flags.update({"quit": False, "run": False, "load": False})
self.allowStart = False
self.stopped = False
- return (cfg["control_port"], cfg["output_port"])
-
- def loadConfig(self, file):
- """Load configuration file.
- file: a YAML configuration file name
- """
- self.config = Config(configFile=file)
+ logger.info(
+ f"control: {self.config.settings['control_port']},"
+ f" output: {self.config.settings['output_port']},"
+ f" logging: {self.logger_pub_port}"
+ )
+ return (
+ self.config.settings["control_port"],
+ self.config.settings["output_port"],
+ self.logger_pub_port,
+ )
- def initConfig(self):
+ def init_config(self):
"""For each connection:
create a Link with a name (purpose), start, and end
Start links to one actor's name, end to the other.
@@ -168,121 +209,74 @@ def initConfig(self):
"""
# TODO load from file or user input, as in dialogue through FrontEnd?
- flag = self.config.createConfig()
+ logger.info("initializing config")
+ flag = self.config.create_config()
if flag == -1:
logger.error(
"An error occurred when loading the configuration file. "
"Please see the log file for more details."
)
+ self.destroy_nexus()
+ raise ConfigFileNotValidException
# create all data links requested from Config config
- self.createConnections()
-
- if self.config.hasGUI:
- # Have to load GUI first (at least with Caiman)
- name = self.config.gui.name
- m = self.config.gui # m is ConfigModule
- # treat GUI uniquely since user communication comes from here
- try:
- visualClass = m.options["visual"]
- # need to instantiate this actor
- visualActor = self.config.actors[visualClass]
- self.createActor(visualClass, visualActor)
- # then add links for visual
- for k, l in {
- key: self.data_queues[key]
- for key in self.data_queues.keys()
- if visualClass in key
- }.items():
- self.assignLink(k, l)
-
- # then give it to our GUI
- self.createActor(name, m)
- self.actors[name].setup(visual=self.actors[visualClass])
-
- self.p_GUI = Process(target=self.actors[name].run, name=name)
- self.p_GUI.daemon = True
- self.p_GUI.start()
-
- except Exception as e:
- logger.error(f"Exception in setting up GUI {name}: {e}")
-
- else:
- # have fake GUI for communications
- q_comm = Link("GUI_comm", "GUI", self.name)
- self.comm_queues.update({q_comm.name: q_comm})
+ self.create_connections()
+
+ # if self.config.hasGUI:
+ # # Have to load GUI first (at least with Caiman)
+ # name = self.config.gui.name
+ # m = self.config.gui # m is ConfigModule
+ # # treat GUI uniquely since user communication comes from here
+ # try:
+ # visualClass = m.options["visual"]
+ # # need to instantiate this actor
+ # visualActor = self.config.actors[visualClass]
+ # self.create_actor(visualClass, visualActor)
+ # # then add links for visual
+ # for k, l in {
+ # key: self.data_queues[key]
+ # for key in self.data_queues.keys()
+ # if visualClass in key
+ # }.items():
+ # self.assign_link(k, l)
+ #
+ # # then give it to our GUI
+ # self.create_actor(name, m)
+ # self.actors[name].setup(visual=self.actors[visualClass])
+ #
+ # self.p_GUI = Process(target=self.actors[name].run, name=name)
+ # self.p_GUI.daemon = True
+ # self.p_GUI.start()
+ #
+ # except Exception as e:
+ # logger.error(f"Exception in setting up GUI {name}: {e}")
# First set up each class/actor
for name, actor in self.config.actors.items():
if name not in self.actors.keys():
# Check for actors being instantiated twice
try:
- self.createActor(name, actor)
+ self.create_actor(name, actor)
logger.info(f"Setting up actor {name}")
except Exception as e:
logger.error(f"Exception in setting up actor {name}: {e}.")
self.quit()
+ raise e
# Second set up each connection b/t actors
# TODO: error handling for if a user tries to use q_in without defining it
- for name, link in self.data_queues.items():
- self.assignLink(name, link)
-
- if self.config.settings["use_watcher"]:
- watchin = []
- for name in self.config.settings["use_watcher"]:
- watch_link = Link(name + "_watch", name, "Watcher")
- self.assignLink(name + ".watchout", watch_link)
- watchin.append(watch_link)
- self.createWatcher(watchin)
+ # for name, link in self.data_queues.items():
+ # self.assign_link(name, link)
def configure_redis_persistence(self):
# invalid configs: specifying filename and using an ephemeral filename,
# specifying that saving is off but providing either filename option
- aof_dirname = self.config.get_redis_aof_dirname()
- generate_unique_dirname = self.config.generate_ephemeral_aof_dirname()
- redis_saving_enabled = self.config.redis_saving_enabled()
- redis_fsync_frequency = self.config.get_redis_fsync_frequency()
-
- if aof_dirname and generate_unique_dirname:
- logger.error(
- "Cannot both generate a unique dirname and use the one provided."
- )
- raise Exception("Cannot use unique dirname and use the one provided.")
-
- if aof_dirname or generate_unique_dirname or redis_fsync_frequency:
- if redis_saving_enabled is None:
- redis_saving_enabled = True
- elif not redis_saving_enabled:
- logger.error(
- "Invalid configuration. Cannot save to disk with saving disabled."
- )
- raise Exception("Cannot persist to disk with saving disabled.")
-
- self.redis_saving_enabled = redis_saving_enabled
-
- if redis_fsync_frequency and redis_fsync_frequency not in [
- "every_write",
- "every_second",
- "no_schedule",
- ]:
- logger.error("Cannot use unknown fsync frequency ", redis_fsync_frequency)
- raise Exception(
- "Cannot use unknown fsync frequency ", redis_fsync_frequency
- )
-
- if redis_fsync_frequency is None:
- redis_fsync_frequency = "no_schedule"
-
- if redis_fsync_frequency == "every_write":
- self.redis_fsync_frequency = "always"
- elif redis_fsync_frequency == "every_second":
- self.redis_fsync_frequency = "everysec"
- elif redis_fsync_frequency == "no_schedule":
- self.redis_fsync_frequency = "no"
- else:
- logger.error("Unknown fsync frequency ", redis_fsync_frequency)
- raise Exception("Unknown fsync frequency ", redis_fsync_frequency)
+ aof_dirname = self.config.redis_config["aof_dirname"]
+ generate_unique_dirname = self.config.redis_config[
+ "generate_ephemeral_aof_dirname"
+ ]
+ self.redis_saving_enabled = self.config.redis_config["enable_saving"]
+ self.redis_fsync_frequency = self.config.redis_config["fsync_frequency"]
if aof_dirname:
self.aof_dir = aof_dirname
@@ -307,7 +301,7 @@ def configure_redis_persistence(self):
return
- def startNexus(self):
+ def start_nexus(self, serve_function, *args, **kwargs):
"""
Puts all actors in separate processes and begins polling
to listen to comm queues
@@ -321,21 +315,21 @@ def startNexus(self):
p = ctx.Process(target=m.run, name=name)
else:
ctx = get_context("fork")
- p = ctx.Process(target=self.runActor, name=name, args=(m,))
- if "Watcher" not in name:
- if "daemon" in self.config.actors[name].options:
- p.daemon = self.config.actors[name].options["daemon"]
- logger.info("Setting daemon for {}".format(name))
- else:
- p.daemon = True # default behavior
+ p = ctx.Process(target=self.run_actor, name=name, args=(m,))
+ if "daemon" in self.config.actors[name].options:
+ p.daemon = self.config.actors[name].options["daemon"]
+ logger.info("Setting daemon for {}".format(name))
+ else:
+ p.daemon = True # default behavior
self.processes.append(p)
self.start()
loop = asyncio.get_event_loop()
+ res = ""
try:
self.out_socket.send_string("Awaiting input:")
- res = loop.run_until_complete(self.pollQueues())
+ res = loop.run_until_complete(self.serve(serve_function, *args, **kwargs))
except asyncio.CancelledError:
logger.info("Loop is cancelled")
@@ -346,10 +340,9 @@ def startNexus(self):
logger.info(f"Current loop: {asyncio.get_event_loop()}")
- loop.stop()
- loop.close()
+ # loop.stop()
+ # loop.close()
logger.info("Shutdown loop")
- self.zmq_context.destroy()
def start(self):
"""
@@ -363,30 +356,41 @@ def start(self):
logger.info("All processes started")
- def destroyNexus(self):
+ def destroy_nexus(self):
"""Method that calls the internal method
- to kill the process running the store (plasma server)
+ to kill the processes running the store
+ and the message broker
"""
logger.warning("Destroying Nexus")
- self._closeStoreInterface()
+ self._shutdown_harvester()
+ self._close_store_interface()
- if hasattr(self, "store_loc"):
- try:
- os.remove(self.store_loc)
- except FileNotFoundError:
- logger.warning(
- "StoreInterface file {} is already deleted".format(self.store_loc)
- )
- logger.warning("Delete the store at location {0}".format(self.store_loc))
-
- if hasattr(self, "out_socket"):
+ if self.out_socket:
self.out_socket.close(linger=0)
- if hasattr(self, "in_socket"):
+ if self.in_socket:
self.in_socket.close(linger=0)
- if hasattr(self, "zmq_context"):
+ if self.actor_in_socket:
+ self.actor_in_socket.close(linger=0)
+ if self.broker_in_socket:
+ self.broker_in_socket.close(linger=0)
+ if self.logger_in_socket:
+ self.logger_in_socket.close(linger=0)
+
+ self._shutdown_broker()
+ self._shutdown_logger()
+
+ for handler in logger.handlers:
+ # need to close this one since we're about to destroy the zmq context
+ if isinstance(handler, log.ZmqLogHandler):
+ handler.close()
+ logger.removeHandler(handler)
+
+ if self.zmq_context:
self.zmq_context.destroy(linger=0)
+ if self.zmq_sync_context:
+ self.zmq_sync_context.destroy(linger=0)
- async def pollQueues(self):
+ async def poll_queues(self, poll_function, *args, **kwargs):
"""
Listens to links and processes their signals.
@@ -400,6 +404,7 @@ async def pollQueues(self):
string: "Shutting down", Notifies start() that pollQueues has completed.
"""
self.actorStates = dict.fromkeys(self.actors.keys())
+ self.actor_states = dict.fromkeys(self.actors.keys(), None)
if not self.config.hasGUI:
# Since Visual is not started, it cannot send a ready signal.
try:
@@ -407,13 +412,11 @@ async def pollQueues(self):
except Exception as e:
logger.info("Visual is not started: {0}".format(e))
pass
- polling = list(self.comm_queues.values())
- pollingNames = list(self.comm_queues.keys())
- self.tasks = []
- for q in polling:
- self.tasks.append(asyncio.create_task(q.get_async()))
+ self.tasks = []
+ self.tasks.append(asyncio.create_task(self.process_actor_message()))
self.tasks.append(asyncio.create_task(self.remote_input()))
+
self.early_exit = False
# add signal handlers
@@ -421,40 +424,17 @@ async def pollQueues(self):
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for s in signals:
loop.add_signal_handler(
- s, lambda s=s: self.stop_polling_and_quit(s, polling)
+ s, lambda s=s: asyncio.create_task(self.stop_polling_and_quit(s))
)
+ logger.info("Nexus signal handler added")
+
while not self.flags["quit"]:
- try:
- done, pending = await asyncio.wait(
- self.tasks, return_when=concurrent.futures.FIRST_COMPLETED
- )
- except asyncio.CancelledError:
- pass
+ await poll_function(*args, **kwargs)
- # sort through tasks to see where we got input from
- # (so we can choose a handler)
- for i, t in enumerate(self.tasks):
- if i < len(polling):
- if t in done or polling[i].status == "done":
- # catch tasks that complete await wait/gather
- r = polling[i].result
- if r:
- if "GUI" in pollingNames[i]:
- self.processGuiSignal(r, pollingNames[i])
- else:
- self.processActorSignal(r, pollingNames[i])
- self.tasks[i] = asyncio.create_task(polling[i].get_async())
- elif t in done:
- logger.debug("t.result = " + str(t.result()))
- self.tasks[i] = asyncio.create_task(self.remote_input())
-
- if not self.early_exit: # don't run this again if we already have
- self.stop_polling(Signal.quit(), polling)
- logger.warning("Shutting down polling")
return "Shutting Down"
- def stop_polling_and_quit(self, signal, queues):
+ async def stop_polling_and_quit(self, signal):
"""
quit the process and stop polling signals from queues
@@ -463,26 +443,108 @@ def stop_polling_and_quit(self, signal, queues):
One of: signal.SIGHUP, signal.SIGTERM, signal.SIGINT
queues (improv.link.AsyncQueue): Comm queues for links.
"""
- logger.warn(
+ logger.warning(
"Shutting down via signal handler due to {}. \
Steps may be out of order or dirty.".format(
signal
)
)
- self.stop_polling(signal, queues)
+ await self.stop_polling(signal)
+ logger.info("Nexus waiting for async tasks to have a chance to send")
+ await asyncio.sleep(0)
self.flags["quit"] = True
self.early_exit = True
self.quit()
+ def process_actor_state_update(self, msg: ActorStateMsg):
+ actor_state = None
+ if msg.actor_name in self.actor_states.keys():
+ actor_state = self.actor_states[msg.actor_name]
+ if not actor_state:
+ logger.info(
+ f"Received state message from new actor {msg.actor_name}"
+ f" with info: {msg.info}\n"
+ )
+ self.actor_states[msg.actor_name] = ActorState(
+ msg.actor_name, msg.status, msg.nexus_in_port
+ )
+ actor_state = self.actor_states[msg.actor_name]
+ else:
+ logger.info(
+ f"Received state message from actor {msg.actor_name}"
+ f" with info: {msg.info}\n"
+ f"Current state:\n"
+ f"Name: {actor_state.actor_name}\n"
+ f"Status: {actor_state.status}\n"
+ f"Nexus control port: {actor_state.nexus_in_port}\n"
+ )
+ if msg.nexus_in_port != actor_state.nexus_in_port:
+ pass
+ # TODO: this actor's signal socket changed
+
+ actor_state.actor_name = msg.actor_name
+ actor_state.status = msg.status
+ actor_state.nexus_in_port = msg.nexus_in_port
+
+ logger.info(
+ "Updated actor state:\n"
+ f"Name: {actor_state.actor_name}\n"
+ f"Status: {actor_state.status}\n"
+ f"Nexus control port: {actor_state.nexus_in_port}\n"
+ )
+
+ if all(
+ [
+ actor_state is not None and actor_state.status == Signal.ready()
+ for actor_state in self.actor_states.values()
+ ]
+ ):
+ self.allowStart = True
+
+ return True
+
+ async def process_actor_message(self):
+ msg = await self.actor_in_socket.recv_pyobj()
+ if isinstance(msg, ActorStateMsg):
+ if self.process_actor_state_update(msg):
+ await self.actor_in_socket.send_pyobj(
+ ActorStateReplyMsg(
+ msg.actor_name, "OK", "actor state updated successfully"
+ )
+ )
+ else:
+ await self.actor_in_socket.send_pyobj(
+ ActorStateReplyMsg(
+ msg.actor_name, "ERROR", "actor state update failed"
+ )
+ )
+ if (not self.allow_setup) and all(
+ [
+ actor_state is not None and actor_state.status == Signal.waiting()
+ for actor_state in self.actor_states.values()
+ ]
+ ):
+ logger.info("All actors connected to Nexus. Allowing setup.")
+ self.allow_setup = True
+
+ if (not self.allowStart) and all(
+ [
+ actor_state is not None and actor_state.status == Signal.ready()
+ for actor_state in self.actor_states.values()
+ ]
+ ):
+ logger.info("All actors ready. Allowing run.")
+ self.allowStart = True
+
async def remote_input(self):
msg = await self.in_socket.recv_multipart()
command = msg[0].decode("utf-8")
await self.in_socket.send_string("Awaiting input:")
if command == Signal.quit():
await self.out_socket.send_string("QUIT")
- self.processGuiSignal([command], "TUI_Nexus")
+ await self.process_gui_signal([command], "TUI_Nexus")
- def processGuiSignal(self, flag, name):
+ async def process_gui_signal(self, flag, name):
"""Receive flags from the Front End as user input"""
name = name.split("_")[0]
if flag:
@@ -490,20 +552,24 @@ def processGuiSignal(self, flag, name):
if flag[0] == Signal.run():
logger.info("Begin run!")
# self.flags['run'] = True
- self.run()
+ await self.run()
elif flag[0] == Signal.setup():
logger.info("Running setup")
- self.setup()
+ await self.setup()
elif flag[0] == Signal.ready():
logger.info("GUI ready")
self.actorStates[name] = flag[0]
elif flag[0] == Signal.quit():
logger.warning("Quitting the program!")
+ task = asyncio.create_task(self.stop_polling_and_quit(Signal.quit()))
+ done, pending = await asyncio.wait(task)
+ while len(done) == 0:
+ done, pending = await asyncio.wait(task)
self.flags["quit"] = True
- self.quit()
elif flag[0] == Signal.load():
logger.info("Loading Config config from file " + flag[1])
- self.loadConfig(flag[1])
+ self.config = Config(flag[1])
+ self.config.parse_config()
elif flag[0] == Signal.pause():
logger.info("Pausing processes")
# TODO. Also resume, reset
@@ -526,20 +592,19 @@ def processGuiSignal(self, flag, name):
p = ctx.Process(target=m.run, name=name)
else:
ctx = get_context("fork")
- p = ctx.Process(target=self.runActor, name=name, args=(m,))
- if "Watcher" not in name:
- if "daemon" in actor.options:
- p.daemon = actor.options["daemon"]
- logger.info("Setting daemon for {}".format(name))
- else:
- p.daemon = True
+ p = ctx.Process(target=self.run_actor, name=name, args=(m,))
+ if "daemon" in actor.options:
+ p.daemon = actor.options["daemon"]
+ logger.info("Setting daemon for {}".format(name))
+ else:
+ p.daemon = True
# Setting the stores for each actor to be the same
# TODO: test if this works for fork -- don't think it does?
al = [act for act in self.actors.values() if act.name != pro.name]
- m.setStoreInterface(al[0].client)
+ m.set_store_interface(al[0].client)
m.client = None
- m._getStoreInterface()
+ m._get_store_interface()
self.processes.append(p)
p.start()
@@ -550,45 +615,33 @@ def processGuiSignal(self, flag, name):
self.processes = [p for p in list(self.processes) if p.exitcode is None]
elif flag[0] == Signal.stop():
logger.info("Nexus received stop signal")
- self.stop()
+ await self.stop()
elif flag:
logger.error("Unknown signal received from Nexus: {}".format(flag))
- def processActorSignal(self, sig, name):
- if sig is not None:
- logger.info("Received signal " + str(sig[0]) + " from " + name)
- state_val = self.actorStates.values()
- if not self.stopped and sig[0] == Signal.ready():
- self.actorStates[name.split("_")[0]] = sig[0]
- if all(val == Signal.ready() for val in state_val):
- self.allowStart = True
- # TODO: replace with q_sig to FE/Visual
- logger.info("Allowing start")
-
- elif self.stopped and sig[0] == Signal.stop_success():
- self.actorStates[name.split("_")[0]] = sig[0]
- if all(val == Signal.stop_success() for val in state_val):
- self.allowStart = True # TODO: replace with q_sig to FE/Visual
- self.stoppped = False
- logger.info("All stops were successful. Allowing start.")
-
- def setup(self):
- for q in self.sig_queues.values():
- try:
- logger.info("Starting setup: " + str(q))
- q.put_nowait(Signal.setup())
- except Full:
- logger.warning("Signal queue" + q.name + "is full")
+ async def setup(self):
+ if not self.allow_setup:
+ logger.error(
+ "Not all actors connected to Nexus. Please wait, then try again."
+ )
+ return
+
+ for actor in self.actor_states.values():
+ logger.info("Starting setup: " + str(actor.actor_name))
+ actor.sig_socket = self.zmq_context.socket(REQ)
+ actor.sig_socket.connect(f"tcp://{actor.hostname}:{actor.nexus_in_port}")
+ await actor.sig_socket.send_pyobj(
+ ActorSignalMsg(actor.actor_name, Signal.setup(), "")
+ )
+ await actor.sig_socket.recv_pyobj()
- def run(self):
+ async def run(self):
if self.allowStart:
- for q in self.sig_queues.values():
- try:
- q.put_nowait(Signal.run())
- except Full:
- logger.warning("Signal queue" + q.name + "is full")
- # queue full, keep going anyway
- # TODO: add repeat trying as async task
+ for actor in self.actor_states.values():
+ await actor.sig_socket.send_pyobj(
+ ActorSignalMsg(actor.actor_name, Signal.run(), "")
+ )
+ await actor.sig_socket.recv_pyobj()
else:
logger.error("Not all actors ready yet, please wait and then try again.")
@@ -596,43 +649,54 @@ def quit(self):
logger.warning("Killing child processes")
self.out_socket.send_string("QUIT")
- for q in self.sig_queues.values():
- try:
- q.put_nowait(Signal.quit())
- except Full:
- logger.warning("Signal queue {} full, cannot quit".format(q.name))
- except FileNotFoundError:
- logger.warning("Queue {} corrupted.".format(q.name))
-
if self.config.hasGUI:
self.processes.append(self.p_GUI)
- if self.p_watch:
- self.processes.append(self.p_watch)
-
for p in self.processes:
p.terminate()
- p.join()
+ p.join(timeout=5)
+ if p.exitcode is None:
+ p.kill()
+ logger.error("Process did not exit in time. Kill signal sent.")
logger.warning("Actors terminated")
- self.destroyNexus()
+ self.destroy_nexus()
- def stop(self):
+ async def stop(self):
logger.warning("Starting stop procedure")
self.allowStart = False
- for q in self.sig_queues.values():
+ for actor in self.actor_states.values():
try:
- q.put_nowait(Signal.stop())
- except Full:
- logger.warning("Signal queue" + q.name + "is full")
+ await actor.sig_socket.send_pyobj(
+ ActorSignalMsg(
+ actor.actor_name, Signal.stop(), "Nexus sending stop signal"
+ )
+ )
+ msg_ready = await actor.sig_socket.poll(timeout=1000)
+ if msg_ready == 0:
+ raise TimeoutError
+ await actor.sig_socket.recv_pyobj()
+ except TimeoutError:
+ logger.info(
+ f"Timed out waiting for reply to stop message "
+ f"from actor {actor.actor_name}. "
+ f"Closing connection."
+ )
+ actor.sig_socket.close(linger=0)
+ except Exception as e:
+ logger.info(
+ f"Unable to send stop message "
+ f"to actor {actor.actor_name}: "
+ f"{e}"
+ )
self.allowStart = True
def revive(self):
logger.warning("Starting revive")
- def stop_polling(self, stop_signal, queues):
+ async def stop_polling(self, stop_signal):
"""Cancels outstanding tasks and fills their last request.
Puts a string into all active queues, then cancels their
@@ -647,11 +711,30 @@ def stop_polling(self, stop_signal, queues):
logger.info(f"Stop signal: {stop_signal}")
shutdown_message = Signal.quit()
- for q in queues:
+ for actor in self.actor_states.values():
try:
- q.put(shutdown_message)
- except Exception:
- logger.info("Unable to send shutdown message to {}.".format(q.name))
+ await actor.sig_socket.send_pyobj(
+ ActorSignalMsg(
+ actor.actor_name, shutdown_message, "Nexus sending quit signal"
+ )
+ )
+ msg_ready = await actor.sig_socket.poll(timeout=1000)
+ if msg_ready == 0:
+ raise TimeoutError
+ await actor.sig_socket.recv_pyobj()
+ except TimeoutError:
+ logger.info(
+ f"Timed out waiting for reply to quit message "
+ f"from actor {actor.actor_name}. "
+ f"Closing connection."
+ )
+ actor.sig_socket.close(linger=0)
+ except Exception as e:
+ logger.info(
+ f"Unable to send shutdown message "
+ f"to actor {actor.actor_name}: "
+ f"{e}"
+ )
logger.info("Canceling outstanding tasks")
@@ -659,19 +742,14 @@ def stop_polling(self, stop_signal, queues):
logger.info("Polling has stopped.")
- def createStoreInterface(self, name):
+ def create_store_interface(self):
"""Creates StoreInterface"""
- if self.config.use_plasma():
- return PlasmaStoreInterface(name, self.store_loc)
- else:
- return RedisStoreInterface(server_port_num=self.store_port)
+ return RedisStoreInterface(server_port_num=self.store_port)
- def _startStoreInterface(self, size, attempts=20):
- """Start a subprocess that runs the plasma store
- Raises a RuntimeError exception size is undefined
- Raises an Exception if the plasma store doesn't start
-
- #TODO: Generalize this to non-plasma stores
+ def _start_store_interface(self, size, attempts=20):
+ """Start a subprocess that runs the redis store
+ Raises a RuntimeError exception if size is undefined
+ Raises an Exception if the redis store doesn't start
Args:
size: in bytes
@@ -683,64 +761,27 @@ def _startStoreInterface(self, size, attempts=20):
"""
if size is None:
raise RuntimeError("Server size needs to be specified")
- self.use_plasma = False
- if self.config and self.config.use_plasma():
- self.use_plasma = True
- self.store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
- self.p_StoreInterface = subprocess.Popen(
- [
- "plasma_store",
- "-s",
- self.store_loc,
- "-m",
- str(size),
- "-e",
- "hashtable://test",
- ],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
- logger.info("StoreInterface start successful: {}".format(self.store_loc))
- else:
- logger.info("Setting up Redis store.")
- self.store_port = (
- self.config.get_redis_port()
- if self.config and self.config.redis_port_specified()
- else Config.get_default_redis_port()
+
+ logger.info("Setting up Redis store.")
+ self.store_port = self.config.redis_config["port"] if self.config else 6379
+ logger.info("Searching for open port starting at specified port.")
+ for attempt in range(attempts):
+ logger.info(
+ "Attempting to connect to Redis on port {}".format(self.store_port)
)
- if self.config and self.config.redis_port_specified():
- logger.info(
- "Attempting to connect to Redis on port {}".format(self.store_port)
- )
- # try with failure, incrementing port number
- self.p_StoreInterface = self.start_redis(size)
- time.sleep(3)
- if self.p_StoreInterface.poll():
- logger.error("Could not start Redis on specified port number.")
- raise Exception("Could not start Redis on specified port.")
+ # try with failure, incrementing port number
+ self.p_StoreInterface = self.start_redis(size)
+ time.sleep(1)
+ if self.p_StoreInterface.poll(): # Redis could not start
+ logger.info("Could not connect to port {}".format(self.store_port))
+ self.store_port = str(int(self.store_port) + 1)
else:
- logger.info("Redis port not specified. Searching for open port.")
- for attempt in range(attempts):
- logger.info(
- "Attempting to connect to Redis on port {}".format(
- self.store_port
- )
- )
- # try with failure, incrementing port number
- self.p_StoreInterface = self.start_redis(size)
- time.sleep(3)
- if self.p_StoreInterface.poll(): # Redis could not start
- logger.info(
- "Could not connect to port {}".format(self.store_port)
- )
- self.store_port = str(int(self.store_port) + 1)
- else:
- break
- else:
- logger.error("Could not start Redis on any tried port.")
- raise Exception("Could not start Redis on any tried ports.")
+ break
+ else:
+ logger.error("Could not start Redis on any tried port.")
+ raise Exception("Could not start Redis on any tried ports.")
- logger.info(f"StoreInterface start successful on port {self.store_port}")
+ logger.info(f"StoreInterface start successful on port {self.store_port}")
def start_redis(self, size):
subprocess_command = [
@@ -788,26 +829,27 @@ def start_redis(self, size):
stderr=subprocess.DEVNULL,
)
- def _closeStoreInterface(self):
+ def _close_store_interface(self):
"""Internal method to kill the subprocess
- running the store (plasma sever)
+ running the store
"""
if hasattr(self, "p_StoreInterface"):
try:
self.p_StoreInterface.send_signal(signal.SIGINT)
- self.p_StoreInterface.wait()
- logger.info(
- "StoreInterface close successful: {}".format(
- self.store_loc
- if self.config and self.config.use_plasma()
- else self.store_port
+ try:
+ self.p_StoreInterface.wait(timeout=30)
+ logger.info(
+ "StoreInterface close successful: {}".format(self.store_port)
)
- )
+ except subprocess.TimeoutExpired as e:
+ logger.error(e)
+ self.p_StoreInterface.send_signal(signal.SIGKILL)
+ logger.info("Killed datastore process")
except Exception as e:
logger.exception("Cannot close store {}".format(e))
- def createActor(self, name, actor):
+ def create_actor(self, name, actor):
"""Function to instantiate actor, add signal and comm Links,
and update self.actors dictionary
@@ -818,35 +860,46 @@ def createActor(self, name, actor):
# Instantiate selected class
mod = import_module(actor.packagename)
clss = getattr(mod, actor.classname)
- if self.config.use_plasma():
- instance = clss(actor.name, store_loc=self.store_loc, **actor.options)
- else:
- instance = clss(actor.name, store_port_num=self.store_port, **actor.options)
+ outgoing_links = (
+ self.outgoing_topics[actor.name]
+ if actor.name in self.outgoing_topics
+ else []
+ )
+ incoming_links = (
+ self.incoming_topics[actor.name]
+ if actor.name in self.incoming_topics
+ else []
+ )
+ instance = clss(
+ name=actor.name,
+ nexus_comm_port=self.actor_in_socket_port,
+ broker_sub_port=self.broker_sub_port,
+ broker_pub_port=self.broker_pub_port,
+ log_pull_port=self.logger_pull_port,
+ outgoing_links=outgoing_links,
+ incoming_links=incoming_links,
+ store_port_num=self.store_port,
+ **actor.options,
+ )
if "method" in actor.options.keys():
# check for spawn
if "fork" == actor.options["method"]:
# Add link to StoreInterface store
- store = self.createStoreInterface(actor.name)
- instance.setStoreInterface(store)
+ store = self.create_store_interface()
+ instance.set_store_interface(store)
else:
# spawn or forkserver; can't pickle plasma store
logger.info("No store for this actor yet {}".format(name))
else:
# Add link to StoreInterface store
- store = self.createStoreInterface(actor.name)
- instance.setStoreInterface(store)
-
- q_comm = Link(actor.name + "_comm", actor.name, self.name)
- q_sig = Link(actor.name + "_sig", self.name, actor.name)
- self.comm_queues.update({q_comm.name: q_comm})
- self.sig_queues.update({q_sig.name: q_sig})
- instance.setCommLinks(q_comm, q_sig)
+ store = self.create_store_interface()
+ instance.set_store_interface(store)
# Update information
self.actors.update({name: instance})
- def runActor(self, actor):
+ def run_actor(self, actor: Actor):
"""Run the actor continually; used for separate processes
#TODO: hook into monitoring here?
@@ -855,26 +908,45 @@ def runActor(self, actor):
"""
actor.run()
- def createConnections(self):
- """Assemble links (multi or other)
- for later assignment
- """
- for source, drain in self.config.connections.items():
- name = source.split(".")[0]
- # current assumption is connection goes from q_out to something(s) else
- if len(drain) > 1: # we need multiasyncqueue
- link, endLinks = MultiLink(name + "_multi", source, drain)
- self.data_queues.update({source: link})
- for i, e in enumerate(endLinks):
- self.data_queues.update({drain[i]: e})
- else: # single input, single output
- d = drain[0]
- d_name = d.split(".") # TODO: check if .anything, if not assume q_in
- link = Link(name + "_" + d_name[0], source, d)
- self.data_queues.update({source: link})
- self.data_queues.update({d: link})
-
- def assignLink(self, name, link):
+ def create_connections(self):
+ for name, connection in self.config.connections.items():
+ sources = connection["sources"]
+ if not isinstance(sources, list):
+ sources = [sources]
+ sinks = connection["sinks"]
+ if not isinstance(sinks, list):
+ sinks = [sinks]
+
+ name_input = tuple(
+ ["inputs:"]
+ + [name for name in sources]
+ + ["outputs:"]
+ + [name for name in sinks]
+ )
+ name = str(
+ hash(name_input)
+ ) # space-efficient key for uniquely referring to this connection
+ logger.info(f"Created link name {name} for link {name_input}")
+
+ for source in sources:
+ source_actor = source.split(".")[0]
+ source_link = source.split(".")[1]
+ if source_actor not in self.outgoing_topics.keys():
+ self.outgoing_topics[source_actor] = [LinkInfo(source_link, name)]
+ else:
+ self.outgoing_topics[source_actor].append(
+ LinkInfo(source_link, name)
+ )
+
+ for sink in sinks:
+ sink_actor = sink.split(".")[0]
+ sink_link = sink.split(".")[1]
+ if sink_actor not in self.incoming_topics.keys():
+ self.incoming_topics[sink_actor] = [LinkInfo(sink_link, name)]
+ else:
+ self.incoming_topics[sink_actor].append(LinkInfo(sink_link, name))
+
+ def assign_link(self, name, link):
"""Function to set up Links between actors
for data location passing
Actor must already be instantiated
@@ -886,24 +958,283 @@ def assignLink(self, name, link):
classname = name.split(".")[0]
linktype = name.split(".")[1]
if linktype == "q_out":
- self.actors[classname].setLinkOut(link)
+ self.actors[classname].set_link_out(link)
elif linktype == "q_in":
- self.actors[classname].setLinkIn(link)
+ self.actors[classname].set_link_in(link)
elif linktype == "watchout":
- self.actors[classname].setLinkWatch(link)
+ self.actors[classname].set_link_watch(link)
else:
- self.actors[classname].addLink(linktype, link)
+ self.actors[classname].add_link(linktype, link)
+
+ def start_logger(self, log_server_pub_port):
+ spawn_context = get_context("spawn")
+ self.p_logger = spawn_context.Process(
+ target=bootstrap_log_server,
+ args=(
+ "localhost",
+ self.logger_in_port,
+ self.logfile,
+ log_server_pub_port,
+ ),
+ )
+ logger.debug("logger created")
+ self.p_logger.start()
+ time.sleep(1)
+ logger.debug("logger started")
+ if not self.p_logger.is_alive():
+ logger.error(
+ "Logger process failed to start. "
+ "Please see the log server log file for more information. "
+ "The improv server will now exit."
+ )
+ self.quit()
+ raise Exception("Could not start log server.")
+ logger.debug("logger is alive")
+ poll_res = self.logger_in_socket.poll(timeout=5000)
+ if poll_res == 0:
+ logger.error(
+ "Never got reply from logger. Cannot proceed setting up Nexus."
+ )
+ try:
+ with open("log_server.log", "r") as file:
+ logger.debug(file.read())
+ except Exception as e:
+ logger.error(e)
+ self.destroy_nexus()
+ logger.error("exiting after destroy")
+ exit(1)
+ logger_info: LogInfoMsg = self.logger_in_socket.recv_pyobj()
+ self.logger_pull_port = logger_info.pull_port
+ self.logger_pub_port = logger_info.pub_port
+ self.logger_in_socket.send_pyobj(
+ LogInfoReplyMsg(logger_info.name, "OK", "registered logger information")
+ )
+ logger.debug("logger replied with setup message")
- # TODO: StoreInterface access here seems wrong, need to test
- def startWatcher(self):
- from improv.watcher import Watcher
+ def start_message_broker(self):
+ spawn_context = get_context("spawn")
+ self.p_broker = spawn_context.Process(
+ target=bootstrap_broker, args=("localhost", self.broker_in_port)
+ )
+ logger.debug("broker created")
+ self.p_broker.start()
+ time.sleep(1)
+ logger.debug("broker started")
+ if not self.p_broker.is_alive():
+ logger.error(
+ "Broker process failed to start. "
+ "Please see the log file for more information. "
+ "The improv server will now exit."
+ )
+ self.quit()
+ raise Exception("Could not start message broker server.")
+ logger.debug("broker is alive")
+ poll_res = self.broker_in_socket.poll(timeout=5000)
+ if poll_res == 0:
+ logger.error("Never got reply from broker. Cannot proceed.")
+ try:
+ with open("broker_server.log", "r") as file:
+ logger.debug(file.read())
+ except Exception as e:
+ logger.error(e)
+ self.destroy_nexus()
+ logger.debug("exiting after destroy")
+ exit(1)
+ broker_info: BrokerInfoMsg = self.broker_in_socket.recv_pyobj()
+ self.broker_sub_port = broker_info.sub_port
+ self.broker_pub_port = broker_info.pub_port
+ self.broker_in_socket.send_pyobj(
+ BrokerInfoReplyMsg(broker_info.name, "OK", "registered broker information")
+ )
+ logger.debug("broker replied with setup message")
- self.watcher = Watcher("watcher", self.createStoreInterface("watcher"))
- q_sig = Link("watcher_sig", self.name, "watcher")
- self.watcher.setLinks(q_sig)
- self.sig_queues.update({q_sig.name: q_sig})
+ def _shutdown_broker(self):
+ """Internal method to kill the subprocess
+ running the message broker
+ """
+ if self.p_broker:
+ try:
+ self.p_broker.terminate()
+ self.p_broker.join(timeout=5)
+ if self.p_broker.exitcode is None:
+ self.p_broker.kill()
+ logger.error("Killed broker process")
+ else:
+ logger.info(
+ "Broker shutdown successful with exit code {}".format(
+ self.p_broker.exitcode
+ )
+ )
+ except Exception as e:
+ logger.exception(f"Unable to close broker {e}")
+
+ def _shutdown_logger(self):
+ """Internal method to kill the subprocess
+ running the logger
+ """
+ if self.p_logger:
+ try:
+ self.p_logger.terminate()
+ self.p_logger.join(timeout=5)
+ if self.p_logger.exitcode is None:
+ self.p_logger.kill()
+ logger.error("Killed logger process")
+ else:
+ logger.info("Logger shutdown successful")
+ except Exception as e:
+ logger.exception(f"Unable to close logger: {e}")
- self.p_watch = Process(target=self.watcher.run, name="watcher_process")
- self.p_watch.daemon = True
- self.p_watch.start()
- self.processes.append(self.p_watch)
+ def _shutdown_harvester(self):
+ """Internal method to kill the subprocess
+ running the logger
+ """
+ if self.p_harvester:
+ try:
+ self.p_harvester.terminate()
+ self.p_harvester.join(timeout=5)
+ if self.p_harvester.exitcode is None:
+ self.p_harvester.kill()
+ logger.error("Killed harvester process")
+ else:
+ logger.info("Harvester shutdown successful")
+ except Exception as e:
+ logger.exception(f"Unable to close harvester: {e}")
+
+ def set_up_sockets(self, actor_in_port):
+
+ logger.debug("Connecting to output")
+ cfg = self.config.settings # this could be self.settings instead
+ self.zmq_context = zmq.Context()
+ self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+ self.out_socket = self.zmq_context.socket(PUB)
+ self.out_socket.bind("tcp://*:%s" % cfg["output_port"])
+ out_port_string = self.out_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ cfg["output_port"] = int(out_port_string.split(":")[-1])
+
+ logger.debug("Connecting to control")
+ self.in_socket = self.zmq_context.socket(REP)
+ self.in_socket.bind("tcp://*:%s" % cfg["control_port"])
+ in_port_string = self.in_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ cfg["control_port"] = int(in_port_string.split(":")[-1])
+
+ logger.debug("Connecting to actor comm socket")
+ self.actor_in_socket = self.zmq_context.socket(REP)
+ self.actor_in_socket.bind(f"tcp://*:{actor_in_port}")
+ in_port_string = self.actor_in_socket.getsockopt_string(
+ SocketOption.LAST_ENDPOINT
+ )
+ self.actor_in_socket_port = int(in_port_string.split(":")[-1])
+
+ logger.debug("Setting up sync server startup socket")
+ self.zmq_sync_context = zmq_sync.Context()
+ self.zmq_sync_context.setsockopt(SocketOption.LINGER, 0)
+
+ self.logger_in_socket = self.zmq_sync_context.socket(REP)
+ self.logger_in_socket.bind("tcp://*:0")
+ logger_in_port_string = self.logger_in_socket.getsockopt_string(
+ SocketOption.LAST_ENDPOINT
+ )
+ self.logger_in_port = int(logger_in_port_string.split(":")[-1])
+
+ self.broker_in_socket = self.zmq_sync_context.socket(REP)
+ self.broker_in_socket.bind("tcp://*:0")
+ broker_in_port_string = self.broker_in_socket.getsockopt_string(
+ SocketOption.LAST_ENDPOINT
+ )
+ self.broker_in_port = int(broker_in_port_string.split(":")[-1])
+
+ def start_improv_services(self, log_server_pub_port, store_size):
+ logger.debug("Starting logger")
+ self.start_logger(log_server_pub_port)
+ logger.addHandler(
+ log.ZmqLogHandler("localhost", self.logger_pull_port, self.zmq_sync_context)
+ )
+
+ logger.debug("starting broker")
+ self.start_message_broker()
+
+ logger.debug("Parsing redis persistence")
+ self.configure_redis_persistence()
+
+ logger.debug("Starting redis server")
+ # default size should be system-dependent
+ self._start_store_interface(store_size)
+ logger.info("Redis server started")
+
+ self.out_socket.send_string("StoreInterface started")
+
+ if self.config.settings["harvest_data_from_memory"]:
+ logger.debug("starting harvester")
+ self.start_harvester()
+
+ # connect to store and subscribe to notifications
+ logger.info("Create new store object")
+ self.store = StoreInterface(server_port_num=self.store_port)
+ logger.info(f"Redis server connected on port {self.store_port}")
+ logger.info("all services started")
+
+ def apply_cli_config_overrides(
+ self, store_size, control_port, output_port, actor_in_port
+ ):
+ if store_size is not None:
+ self.config.settings["store_size"] = store_size
+ if control_port is not None:
+ self.config.settings["control_port"] = control_port
+ if output_port is not None:
+ self.config.settings["output_port"] = output_port
+ if actor_in_port is not None:
+ self.config.settings["actor_in_port"] = actor_in_port
+
+ def start_harvester(self):
+ spawn_context = get_context("spawn")
+ self.p_harvester = spawn_context.Process(
+ target=bootstrap_harvester,
+ args=(
+ "localhost",
+ self.broker_in_port,
+ "localhost",
+ self.store_port,
+ "localhost",
+ self.broker_pub_port,
+ "localhost",
+ self.logger_pull_port,
+ ),
+ )
+ self.p_harvester.start()
+ time.sleep(1)
+ if not self.p_harvester.is_alive():
+ logger.error(
+ "Harvester process failed to start. "
+ "Please see the log file for more information. "
+ "The improv server will now exit."
+ )
+ self.quit()
+ raise Exception("Could not start harvester server.")
+
+ harvester_info: HarvesterInfoMsg = self.broker_in_socket.recv_pyobj()
+ self.broker_in_socket.send_pyobj(
+ HarvesterInfoReplyMsg(
+ harvester_info.name, "OK", "registered harvester information"
+ )
+ )
+ logger.info("Harvester server started")
+
+ async def serve(self, serve_function, *args, **kwargs):
+ await serve_function(*args, **kwargs)
+
+ async def poll_kernel(self):
+ try:
+ done, pending = await asyncio.wait(
+ self.tasks, return_when=concurrent.futures.FIRST_COMPLETED
+ )
+ except asyncio.CancelledError:
+ pass
+
+ # sort through tasks to see where we got input from
+ # (so we can choose a handler)
+ for i, t in enumerate(self.tasks):
+ if i == 0:
+ if t in done:
+ self.tasks[i] = asyncio.create_task(self.process_actor_message())
+ elif t in done:
+ self.tasks[i] = asyncio.create_task(self.remote_input())
diff --git a/improv/store.py b/improv/store.py
index 98f54a63..86e9cf30 100644
--- a/improv/store.py
+++ b/improv/store.py
@@ -3,21 +3,14 @@
import pickle
import logging
-import traceback
-
-import numpy as np
-import pyarrow.plasma as plasma
+import zlib
from redis import Redis
from redis.retry import Retry
from redis.backoff import ConstantBackoff
from redis.exceptions import BusyLoadingError, ConnectionError, TimeoutError
-from scipy.sparse import csc_matrix
-from pyarrow.lib import ArrowIOError
-from pyarrow._plasma import PlasmaObjectExists, ObjectNotAvailable
-
-REDIS_GLOBAL_TOPIC = "global_topic"
+ZLIB_COMPRESSION_LEVEL = -1
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -43,17 +36,21 @@ def subscribe(self):
class RedisStoreInterface(StoreInterface):
- def __init__(self, name="default", server_port_num=6379, hostname="localhost"):
+ def __init__(
+ self,
+ name="default",
+ server_port_num=6379,
+ hostname="localhost",
+ compression_level=ZLIB_COMPRESSION_LEVEL,
+ ):
self.name = name
self.server_port_num = server_port_num
self.hostname = hostname
self.client = self.connect_to_server()
+ self.compression_level = compression_level
def connect_to_server(self):
- # TODO this should scan for available ports, but only if configured to do so.
- # This happens when the config doesn't have Redis settings,
- # so we need to communicate this somehow to the StoreInterface here.
- """Connect to the store at store_loc, max 20 retries to connect
+ """Connect to the store, max 20 retries to connect
Raises exception if can't connect
Returns the Redis client if successful
@@ -107,18 +104,11 @@ def put(self, object):
object: the object that was a
"""
object_key = str(os.getpid()) + str(uuid.uuid4())
- try:
- # buffers would theoretically go here if we need to force out-of-band
- # serialization for single objects
- # TODO this will actually just silently fail if we use an existing
- # TODO key; not sure it's worth the network overhead to check every
- # TODO key twice every time. we still need a better solution for
- # TODO this, but it will work now singlethreaded most of the time.
-
- self.client.set(object_key, pickle.dumps(object, protocol=5), nx=True)
- except Exception:
- logger.error("Could not store object {}".format(object_key))
- logger.error(traceback.format_exc())
+ data = zlib.compress(
+ pickle.dumps(object, protocol=5), level=self.compression_level
+ )
+
+ self.client.set(object_key, data, nx=True)
return object_key
@@ -138,14 +128,10 @@ def get(self, object_key):
object_value = self.client.get(object_key)
if object_value:
# buffers would also go here to force out-of-band deserialization
- return pickle.loads(object_value)
+ return pickle.loads(zlib.decompress(object_value))
logger.warning("Object {} cannot be found.".format(object_key))
- raise ObjectNotFoundError
-
- def subscribe(self, topic=REDIS_GLOBAL_TOPIC):
- p = self.client.pubsub()
- p.subscribe(topic)
+ raise ObjectNotFoundError(object_key)
def get_list(self, ids):
"""Get multiple objects from the store
@@ -156,7 +142,10 @@ def get_list(self, ids):
Returns:
list of the objects
"""
- return self.client.mget(ids)
+ return [
+ pickle.loads(zlib.decompress(object_value))
+ for object_value in self.client.mget(ids)
+ ]
def get_all(self):
"""Get a listing of all objects in the store.
@@ -166,7 +155,7 @@ def get_all(self):
list of all the objects in the store
"""
all_keys = self.client.keys() # defaults to "*" pattern, so will fetch all
- return self.client.mget(all_keys)
+ return self.get_list(all_keys)
def reset(self):
"""Reset client connection"""
@@ -175,226 +164,6 @@ def reset(self):
"Reset local connection to store on port: {0}".format(self.server_port_num)
)
- def notify(self):
- pass # I don't see any call sites for this, so leaving it blank at the moment
-
-
-class PlasmaStoreInterface(StoreInterface):
- """Basic interface for our specific data store implemented with apache arrow plasma
- Objects are stored with object_ids
- References to objects are contained in a dict where key is shortname,
- value is object_id
- """
-
- def __init__(self, name="default", store_loc="/tmp/store"):
- """
- Constructor for the StoreInterface
-
- :param name:
- :param store_loc: Apache Arrow Plasma client location
- """
-
- self.name = name
- self.store_loc = store_loc
- self.client = self.connect_store(store_loc)
- self.stored = {}
-
- def connect_store(self, store_loc):
- """Connect to the store at store_loc, max 20 retries to connect
- Raises exception if can't connect
- Returns the plasmaclient if successful
- Updates the client internal
-
- Args:
- store_loc: store location
- """
- try:
- self.client = plasma.connect(store_loc, 20)
- logger.info("Successfully connected to store: {} ".format(store_loc))
- except Exception:
- logger.exception("Cannot connect to store: {}".format(store_loc))
- raise CannotConnectToStoreInterfaceError(store_loc)
- return self.client
-
- def put(self, object, object_name):
- """
- Put a single object referenced by its string name
- into the store
- Raises PlasmaObjectExists if we are overwriting
- Unknown error
-
- Args:
- object:
- object_name (str):
- flush_this_immediately (bool):
-
- Returns:
- class 'plasma.ObjectID': Plasma object ID
-
- Raises:
- PlasmaObjectExists: if we are overwriting \
- unknown error
- """
- object_id = None
- try:
- # Need to pickle if object is csc_matrix
- if isinstance(object, csc_matrix):
- prot = pickle.HIGHEST_PROTOCOL
- object_id = self.client.put(pickle.dumps(object, protocol=prot))
- else:
- object_id = self.client.put(object)
-
- except PlasmaObjectExists:
- logger.error("Object already exists. Meant to call replace?")
- except ArrowIOError:
- logger.error("Could not store object {}".format(object_name))
- logger.info("Refreshing connection and continuing")
- self.reset()
- except Exception:
- logger.error("Could not store object {}".format(object_name))
- logger.error(traceback.format_exc())
-
- return object_id
-
- def get(self, object_name):
- """Get a single object from the store by object name
- Checks to see if it knows the object first
- Otherwise throw CannotGetObject to request dict update
- TODO: update for lists of objects
- TODO: replace with getID
-
- Returns:
- Stored object
- """
- # print('trying to get ', object_name)
- # if self.stored.get(object_name) is None:
- # logger.error('Never recorded storing this object: '+object_name)
- # # Don't know anything about this object, treat as problematic
- # raise CannotGetObjectError(query = object_name)
- # else:
- return self.getID(object_name)
-
- def getID(self, obj_id):
- """
- Get object by object ID
-
- Args:
- obj_id (class 'plasma.ObjectID'): the id of the object
-
- Returns:
- Stored object
-
- Raises:
- ObjectNotFoundError: If the id is not found
- """
- res = self.client.get(obj_id, 0) # Timeout = 0 ms
- if res is not plasma.ObjectNotAvailable:
- return res if not isinstance(res, bytes) else pickle.loads(res)
-
- logger.warning("Object {} cannot be found.".format(obj_id))
- raise ObjectNotFoundError
-
- def getList(self, ids):
- """Get multiple objects from the store
-
- Args:
- ids (list): of type plasma.ObjectID
-
- Returns:
- list of the objects
- """
- # self._get()
- return self.client.get(ids)
-
- def get_all(self):
- """Get a listing of all objects in the store
-
- Returns:
- list of all the objects in the store
- """
- return self.client.list()
-
- def reset(self):
- """Reset client connection"""
- self.client = self.connect_store(self.store_loc)
- logger.debug("Reset local connection to store: {0}".format(self.store_loc))
-
- def release(self):
- self.client.disconnect()
-
- # Subscribe to notifications about sealed objects?
- def subscribe(self):
- """Subscribe to a section? of the ds for singals
-
- Raises:
- Exception: Unknown error
- """
- try:
- self.client.subscribe()
- except Exception as e:
- logger.error("Unknown error: {}".format(e))
- raise Exception
-
- def notify(self):
- try:
- notification_info = self.client.get_next_notification()
- # recv_objid, recv_dsize, recv_msize = notification_info
- except ArrowIOError:
- notification_info = None
- except Exception as e:
- logger.exception("Notification error: {}".format(e))
- raise Exception
-
- return notification_info
-
- # Necessary? plasma.ObjectID.from_random()
- def random_ObjectID(self, number=1):
- ids = []
- for i in range(number):
- ids.append(plasma.ObjectID(np.random.bytes(20)))
- return ids
-
- def updateStoreInterfaced(self, object_name, object_id):
- """Update local dict with info we need locally
- Report to Nexus that we updated the store
- (did a put or delete/replace)
-
- Args:
- object_name (str): the name of the object to update
- object_id (): the id of the object to update
- """
- self.stored.update({object_name: object_id})
-
- def getStored(self):
- """
- Returns:
- its info about what it has stored
- """
- return self.stored
-
- def _put(self, obj, id):
- """Internal put"""
- return self.client.put(obj, id)
-
- def _get(self, object_name):
- """Get an object from the store using its name
- Assumes we know the id for the object_name
-
- Raises:
- ObjectNotFound: if object_id returns no object from the store
- """
- # Most errors not shown to user.
- # Maintain separation between external and internal function calls.
- res = self.getID(self.stored.get(object_name))
- # Can also use contains() to check
-
- logger.warning("{}".format(object_name))
- if isinstance(res, ObjectNotAvailable):
- logger.warning("Object {} cannot be found.".format(object_name))
- raise ObjectNotFoundError(obj_id_or_name=object_name) # TODO: Don't raise?
- else:
- return res
-
StoreInterface = RedisStoreInterface
@@ -428,12 +197,12 @@ def __str__(self):
class CannotConnectToStoreInterfaceError(Exception):
"""Raised when failing to connect to store."""
- def __init__(self, store_loc):
+ def __init__(self, store_port):
super().__init__()
self.name = "CannotConnectToStoreInterfaceError"
- self.message = "Cannot connect to store at {}".format(str(store_loc))
+ self.message = "Cannot connect to store at {}".format(str(store_port))
def __str__(self):
return self.message
diff --git a/improv/utils/checks.py b/improv/utils/checks.py
index a72f4657..97c06f24 100644
--- a/improv/utils/checks.py
+++ b/improv/utils/checks.py
@@ -32,10 +32,17 @@ def check_if_connections_acyclic(path_to_yaml):
# Need to keep only module names
connections = {}
- for key, values in raw.items():
- new_key = key.split(".")[0]
- new_values = [value.split(".")[0] for value in values]
- connections[new_key] = new_values
+ for connection, values in raw.items():
+ for source in values["sources"]:
+ source_name = source.split(".")[0]
+ if source_name not in connections.keys():
+ connections[source_name] = [
+ sink_name.split(".")[0] for sink_name in values["sinks"]
+ ]
+ else:
+ connections[source_name] = connections[source_name] + [
+ sink_name.split(".")[0] for sink_name in values["sinks"]
+ ]
g = nx.DiGraph(connections)
dag = nx.is_directed_acyclic_graph(g)
diff --git a/improv/watcher.py b/improv/watcher.py
deleted file mode 100644
index ad2dbf1f..00000000
--- a/improv/watcher.py
+++ /dev/null
@@ -1,175 +0,0 @@
-import numpy as np
-import asyncio
-
-# import pyarrow.plasma as plasma
-# from multiprocessing import Process, Queue, Manager, cpu_count, set_start_method
-# import subprocess
-# import signal
-# import time
-from queue import Empty
-import logging
-
-import concurrent
-
-# from pyarrow.plasma import ObjectNotAvailable
-from improv.actor import Actor, Signal, RunManager
-from improv.store import ObjectNotFoundError
-import pickle
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class BasicWatcher(Actor):
- """
- Actor that monitors stored objects from the other actors
- and saves objects that have been flagged by those actors
- """
-
- def __init__(self, *args, inputs=None):
- super().__init__(*args)
-
- self.watchin = inputs
-
- def setup(self):
- """
- set up tasks and polling based on inputs which will
- be used for asynchronous polling of input queues
- """
- self.numSaved = 0
- self.tasks = []
- self.polling = self.watchin
- self.setUp = False
-
- def run(self):
- """
- continually run the watcher to check all of the
- input queues for objects to save
- """
-
- with RunManager(
- self.name, self.watchrun, self.setup, self.q_sig, self.q_comm
- ) as rm:
- logger.info(rm)
-
- print("watcher saved " + str(self.numSaved) + " objects")
-
- def watchrun(self):
- """
- set up async loop for polling
- """
- loop = asyncio.get_event_loop()
- loop.run_until_complete(self.watch())
-
- async def watch(self):
- """
- function for asynchronous polling of input queues
- loops through each of the queues in watchin and checks
- if an object is present and then saves the object if found
- """
-
- if self.setUp is False:
- for q in self.polling:
- self.tasks.append(asyncio.create_task(q.get_async()))
- self.setUp = True
-
- done, pending = await asyncio.wait(
- self.tasks, return_when=concurrent.futures.FIRST_COMPLETED
- )
-
- for i, t in enumerate(self.tasks):
- if t in done or self.polling[i].status == "done":
- r = self.polling[i].result # r is array with id and name of object
- actorID = self.polling[
- i
- ].getStart() # name of actor asking watcher to save the object
- try:
- obj = self.client.getID(r[0])
- np.save("output/saved/" + actorID + r[1], obj)
- except ObjectNotFoundError as e:
- logger.info(e.message)
- pass
- self.tasks[i] = asyncio.create_task(self.polling[i].get_async())
-
-
-class Watcher:
- """Monitors the store as separate process
- TODO: Facilitate Watcher being used in multiple processes (shared list)
- """
-
- # Related to subscribe - could be private, i.e., _subscribe
- def __init__(self, name, client):
- self.name = name
- self.client = client
- self.flag = False
- self.saved_ids = []
-
- self.client.subscribe()
- self.n = 0
-
- def setLinks(self, links):
- self.q_sig = links
-
- def run(self):
- while True:
- if self.flag:
- try:
- self.checkStoreInterface2()
- except Exception as e:
- logger.error("Watcher exception during run: {}".format(e))
- # break
- try:
- signal = self.q_sig.get(timeout=0.005)
- if signal == Signal.run():
- self.flag = True
- logger.warning("Received run signal, begin running")
- elif signal == Signal.quit():
- logger.warning("Received quit signal, aborting")
- break
- elif signal == Signal.pause():
- logger.warning("Received pause signal, pending...")
- self.flag = False
- elif signal == Signal.resume(): # currently treat as same as run
- logger.warning("Received resume signal, resuming")
- self.flag = True
- except Empty:
- pass # no signal from Nexus
-
- # def checkStoreInterface(self):
- # notification_info = self.client.notify()
- # recv_objid, recv_dsize, recv_msize = notification_info
- # obj = self.client.getID(recv_objid)
- # try:
- # self.saveObj(obj)
- # self.n += 1
- # except Exception as e:
- # logger.error('Watcher error: {}'.format(e))
-
- def saveObj(self, obj, name):
- with open(
- "/media/hawkwings/Ext Hard Drive/dump/dump" + name + ".pkl", "wb"
- ) as output:
- pickle.dump(obj, output)
-
- def checkStoreInterface2(self):
- objs = list(self.client.get_all().keys())
- ids_to_save = list(set(objs) - set(self.saved_ids))
-
- # with Pool() as pool:
- # saved_ids = pool.map(saveObjbyID, ids_to_save)
- # print('Saved :', len(saved_ids))
- # self.saved_ids.extend(saved_ids)
-
- for id in ids_to_save:
- self.saveObj(self.client.getID(id), str(id))
- self.saved_ids.append(id)
-
-
-# def saveObjbyID(id):
-# client = plasma.connect('/tmp/store')
-# obj = client.get(id)
-# with open(
-# '/media/hawkwings/Ext\ Hard\ Drive/dump/dump'+str(id)+'.pkl', 'wb'
-# ) as output:
-# pickle.dump(obj, output)
-# return id
diff --git a/pyproject.toml b/pyproject.toml
index bf5a42b7..fd185aad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,6 @@ dependencies = [
"numpy<=1.26",
"scipy",
"matplotlib",
- "pyarrow==9.0.0",
"PyQt5",
"pyyaml",
"textual==0.15.0",
@@ -56,6 +55,7 @@ exclude = ["test", "pytest", "env", "demos", "figures"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
filterwarnings = [ ]
+asyncio_default_fixture_loop_scope = "function"
log_cli = true
log_cli_level = "INFO"
diff --git a/test/actors/actor_for_bad_args.py b/test/actors/actor_for_bad_args.py
new file mode 100644
index 00000000..bac95922
--- /dev/null
+++ b/test/actors/actor_for_bad_args.py
@@ -0,0 +1,3 @@
+class BadActorArgs:
+ def __init__(self, existent_parameter, another_required_param):
+ pass
diff --git a/test/actors/sample_generator.py b/test/actors/sample_generator.py
index 9d1c3891..d57a5b28 100644
--- a/test/actors/sample_generator.py
+++ b/test/actors/sample_generator.py
@@ -1,4 +1,4 @@
-from improv.actor import Actor
+from improv.actor import ZmqActor
from datetime import date # used for saving
import numpy as np
import logging
@@ -7,7 +7,7 @@
logger.setLevel(logging.INFO)
-class Generator(Actor):
+class Generator(ZmqActor):
"""Sample actor to generate data to pass into a sample processor.
Intended for use along with sample_processor.py.
diff --git a/test/actors/sample_generator_wrong_init.py b/test/actors/sample_generator_wrong_init.py
index 9d1c3891..405ab92f 100644
--- a/test/actors/sample_generator_wrong_init.py
+++ b/test/actors/sample_generator_wrong_init.py
@@ -1,76 +1,2 @@
-from improv.actor import Actor
-from datetime import date # used for saving
-import numpy as np
-import logging
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class Generator(Actor):
- """Sample actor to generate data to pass into a sample processor.
-
- Intended for use along with sample_processor.py.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.data = None
- self.name = "Generator"
- self.frame_num = 0
-
- def __str__(self):
- return f"Name: {self.name}, Data: {self.data}"
-
- def setup(self):
- """Generates an array that serves as an initial source of data.
-
- Initial array is a 100 row, 5 column numpy matrix that contains
- integers from 1-99, inclusive.
- """
-
- self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
- logger.info("Completed setup for Generator")
-
- # def run(self):
- # """ Send array into the store.
- # """
- # self.fcns = {}
- # self.fcns['setup'] = self.setup
- # self.fcns['run'] = self.runStep
- # self.fcns['stop'] = self.stop
-
- # with RunManager(self.name, self.fcns, self.links) as rm:
- # logger.info(rm)
-
- def stop(self):
- """Save current randint vector to a file."""
-
- print("Generator stopping")
- np.save(f"sample_generator_data_{date.today()}", self.data)
- # This is not the best example of a save function,
- # will overwrite previous files with the same name.
- return 0
-
- def runStep(self):
- """Generates additional data after initial setup data is exhausted.
-
- Data is of a different form as the setup data in that although it is
- the same size (5x1 vector), it is uniformly distributed in [1, 10]
- instead of in [1, 100]. Therefore, the average over time should
- converge to 5.5.
- """
-
- if self.frame_num < np.shape(self.data)[0]:
- data_id = self.client.put(
- self.data[self.frame_num], str(f"Gen_raw: {self.frame_num}")
- )
- try:
- self.q_out.put([[data_id, str(self.frame_num)]])
- self.frame_num += 1
- except Exception as e:
- logger.error(f"Generator Exception: {e}")
- else:
- self.data = np.concatenate(
- (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0
- )
+class GeneratorActor:
+ pass
diff --git a/demos/minimal/actors/sample_generator.py b/test/actors/sample_generator_zmq.py
similarity index 62%
rename from demos/minimal/actors/sample_generator.py
rename to test/actors/sample_generator_zmq.py
index 31335833..0bc5b4f5 100644
--- a/demos/minimal/actors/sample_generator.py
+++ b/test/actors/sample_generator_zmq.py
@@ -1,4 +1,5 @@
-from improv.actor import Actor
+from improv.actor import ZmqActor
+from datetime import date # used for saving
import numpy as np
import logging
@@ -6,7 +7,7 @@
logger.setLevel(logging.INFO)
-class Generator(Actor):
+class Generator(ZmqActor):
"""Sample actor to generate data to pass into a sample processor.
Intended for use along with sample_processor.py.
@@ -28,18 +29,30 @@ def setup(self):
integers from 1-99, inclusive.
"""
- logger.info("Beginning setup for Generator")
self.data = np.asmatrix(np.random.randint(100, size=(100, 5)))
logger.info("Completed setup for Generator")
+ # def run(self):
+ # """ Send array into the store.
+ # """
+ # self.fcns = {}
+ # self.fcns['setup'] = self.setup
+ # self.fcns['run'] = self.runStep
+ # self.fcns['stop'] = self.stop
+
+ # with RunManager(self.name, self.fcns, self.links) as rm:
+ # logger.info(rm)
+
def stop(self):
"""Save current randint vector to a file."""
- logger.info("Generator stopping")
- np.save("sample_generator_data.npy", self.data)
+ print("Generator stopping")
+ np.save(f"sample_generator_data_{date.today()}", self.data)
+ # This is not the best example of a save function,
+ # will overwrite previous files with the same name.
return 0
- def runStep(self):
+ def run_step(self):
"""Generates additional data after initial setup data is exhausted.
Data is of a different form as the setup data in that although it is
@@ -49,25 +62,14 @@ def runStep(self):
"""
if self.frame_num < np.shape(self.data)[0]:
- if self.store_loc:
- data_id = self.client.put(
- self.data[self.frame_num], str(f"Gen_raw: {self.frame_num}")
- )
- else:
- data_id = self.client.put(self.data[self.frame_num])
- # logger.info('Put data in store')
+ data_id = self.client.put(self.data[self.frame_num])
try:
- if self.store_loc:
- self.q_out.put([[data_id, str(self.frame_num)]])
- else:
- self.q_out.put(data_id)
- # logger.info("Sent message on")
-
+ self.q_out.put(data_id)
+ # logger.info(f"Sent {self.data[self.frame_num]} with key {data_id}")
self.frame_num += 1
+
except Exception as e:
- logger.error(
- f"--------------------------------Generator Exception: {e}"
- )
+ logger.error(f"Generator Exception: {e}")
else:
self.data = np.concatenate(
(self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0
diff --git a/demos/zmq/actors/zmq_rr_sample_processor.py b/test/actors/sample_processor_zmq.py
similarity index 58%
rename from demos/zmq/actors/zmq_rr_sample_processor.py
rename to test/actors/sample_processor_zmq.py
index e4ed9cd3..1948e778 100644
--- a/demos/zmq/actors/zmq_rr_sample_processor.py
+++ b/test/actors/sample_processor_zmq.py
@@ -1,34 +1,33 @@
-from improv.actor import Actor, AsyncActor
+from improv.actor import ZmqActor
import numpy as np
import logging
-from demos.sample_actors.zmqActor import ZmqActor
-
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Processor(ZmqActor):
- """Sample processor used to calculate the average of an array of integers
- using async ZMQ to communicate.
+ """Sample processor used to calculate the average of an array of integers.
Intended for use with sample_generator.py.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ if "name" in kwargs:
+ self.name = kwargs["name"]
def setup(self):
"""Initializes all class variables.
- Sets up a ZmqRRActor to receive data from the generator.
self.name (string): name of the actor.
- self.frame (ObjectID): Store object id referencing data from the store.
+ self.frame (ObjectID): StoreInterface object id referencing data from the store.
self.avg_list (list): list that contains averages of individual vectors.
self.frame_num (int): index of current frame.
"""
- self.name = "Processor"
+ if not hasattr(self, "name"):
+ self.name = "Processor"
self.frame = None
self.avg_list = []
self.frame_num = 1
@@ -38,10 +37,21 @@ def stop(self):
"""Trivial stop function for testing purposes."""
logger.info("Processor stopping")
+ return 0
+
+ # def run(self):
+ # """ Send array into the store.
+ # """
+ # self.fcns = {}
+ # self.fcns['setup'] = self.setup
+ # self.fcns['run'] = self.runStep
+ # self.fcns['stop'] = self.stop
+
+ # with RunManager(self.name, self.fcns, self.links) as rm:
+ # logger.info(rm)
- def runStep(self):
+ def run_step(self):
"""Gets from the input queue and calculates the average.
- Receives data from the generator using a ZmqRRActor.
Receives an ObjectID, references data in the store using that
ObjectID, calculates the average of that data, and finally prints
@@ -49,21 +59,19 @@ def runStep(self):
"""
frame = None
-
try:
- frame = self.get(reply='received')
-
- except:
- logger.error("Could not get frame!")
+ frame = self.q_in.get(timeout=0.05)
+ except Exception as e:
+ logger.error(f"{self.name} could not get frame! At {self.frame_num}: {e}")
pass
if frame is not None and self.frame_num is not None:
self.done = False
- self.frame = self.client.getID(frame)
+ self.frame = self.client.get(frame)
avg = np.mean(self.frame[0])
-
+ # logger.info(f"{self.name} got frame {frame} with value {self.frame}")
# logger.info(f"Average: {avg}")
self.avg_list.append(avg)
- logger.info(f"Overall Average: {np.mean(self.avg_list)}")
+ # logger.info(f"Overall Average: {np.mean(self.avg_list)}")
# logger.info(f"Frame number: {self.frame_num}")
self.frame_num += 1
diff --git a/test/configs/bad_args.yaml b/test/configs/bad_args.yaml
index be94212c..ce96b6b6 100644
--- a/test/configs/bad_args.yaml
+++ b/test/configs/bad_args.yaml
@@ -1,20 +1,7 @@
actors:
Acquirer:
- package: demos.sample_actors.acquire
- class: FileAcquirer
- fiasdfe: data/Tolias_mesoscope_2.hdf5
- fraefawe: 30
-
- Processor:
- package: demos.sample_actors.process
- class: CaimanProcessor
- init_filename: data/tbif_ex_crop.h5
- config_file: eva_caiman_params.txt
-
- Analysis:
- package: demos.sample_actors.analysis
- class: MeanAnalysis
+ package: actors.actor_for_bad_args
+ class: BadActorArgs
+ existent_parameter: 42
connections:
- Acquirer.q_out: [Processor.q_in]
- Processor.q_out: [Analysis.q_in]
diff --git a/test/configs/complex_graph.yaml b/test/configs/complex_graph.yaml
index 3c82029a..16961825 100644
--- a/test/configs/complex_graph.yaml
+++ b/test/configs/complex_graph.yaml
@@ -14,8 +14,14 @@ actors:
class: BehaviorAcquirer
connections:
- Acquirer.q_out: [Analysis.q_in, InputStim.q_in]
- Analysis.q_out: [InputStim.q_in]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
+ Acquirer-out:
+ sources:
+ - Acquirer.q_out
+ sinks:
+ - Analysis.q_in
+ - InputStim.q_in
+ Analysis-InputStim:
+ sources:
+ - Analysis.q_out
+ sinks:
+ - InputStim.q_in
diff --git a/test/configs/cyclic_config.yaml b/test/configs/cyclic_config.yaml
index 61c91303..8e87b880 100644
--- a/test/configs/cyclic_config.yaml
+++ b/test/configs/cyclic_config.yaml
@@ -14,5 +14,14 @@ actors:
class: BehaviorAcquirer
connections:
- Acquirer.q_out: [Analysis.q_in, InputStim.q_in]
- Analysis.q_out: [Acquirer.q_in]
+ Acquirer-out:
+ sources:
+ - Acquirer.q_out
+ sinks:
+ - Analysis.q_in
+ - InputStim.q_in
+ Analysis-Acquirer:
+ sources:
+ - Analysis.q_out
+ sinks:
+ - Acquirer.q_in
diff --git a/test/configs/good_config.yaml b/test/configs/good_config.yaml
index b0c70b70..6f72da5c 100644
--- a/test/configs/good_config.yaml
+++ b/test/configs/good_config.yaml
@@ -10,4 +10,8 @@ actors:
class: SimpleAnalysis
connections:
- Acquirer.q_out: [Analysis.q_in]
+ Acquirer-Analysis:
+ sources:
+ - Acquirer.q_out
+ sinks:
+ - Analysis.q_in
diff --git a/test/configs/good_config_actors.yaml b/test/configs/good_config_actors.yaml
deleted file mode 100644
index e84fd262..00000000
--- a/test/configs/good_config_actors.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-Acquirer:
- class: FileAcquirer
- filename: data/Tolias_mesoscope_2.hdf5
- framerate: 30
- package: demos.sample_actors.acquire
-Analysis:
- class: SimpleAnalysis
- package: demos.sample_actors.simple_analysis
diff --git a/test/configs/good_config_plasma.yaml b/test/configs/good_config_plasma.yaml
deleted file mode 100644
index 8323017b..00000000
--- a/test/configs/good_config_plasma.yaml
+++ /dev/null
@@ -1,15 +0,0 @@
-actors:
- Acquirer:
- package: demos.sample_actors.acquire
- class: FileAcquirer
- filename: data/Tolias_mesoscope_2.hdf5
- framerate: 30
-
- Analysis:
- package: demos.sample_actors.simple_analysis
- class: SimpleAnalysis
-
-connections:
- Acquirer.q_out: [Analysis.q_in]
-
-plasma_config:
\ No newline at end of file
diff --git a/test/configs/minimal.yaml b/test/configs/minimal.yaml
index 230282d1..d752b415 100644
--- a/test/configs/minimal.yaml
+++ b/test/configs/minimal.yaml
@@ -1,11 +1,15 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
diff --git a/test/configs/minimal_gui.yaml b/test/configs/minimal_gui.yaml
new file mode 100644
index 00000000..8dd8bec5
--- /dev/null
+++ b/test/configs/minimal_gui.yaml
@@ -0,0 +1,15 @@
+actors:
+ GUI:
+ package: actors.sample_generator_zmq
+ class: Generator
+
+ Processor:
+ package: actors.sample_processor_zmq
+ class: Processor
+
+connections:
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
diff --git a/test/configs/minimal_harvester.yaml b/test/configs/minimal_harvester.yaml
new file mode 100644
index 00000000..74847433
--- /dev/null
+++ b/test/configs/minimal_harvester.yaml
@@ -0,0 +1,18 @@
+actors:
+ Generator:
+ package: actors.sample_generator_zmq
+ class: Generator
+
+ Processor:
+ package: actors.sample_processor_zmq
+ class: Processor
+
+connections:
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
+
+settings:
+ harvest_data_from_memory: True
\ No newline at end of file
diff --git a/test/configs/minimal_with_custom_aof_dirname.yaml b/test/configs/minimal_with_custom_aof_dirname.yaml
index 1bac5863..f07c0c44 100644
--- a/test/configs/minimal_with_custom_aof_dirname.yaml
+++ b/test/configs/minimal_with_custom_aof_dirname.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
aof_dirname: custom_aof_dirname
\ No newline at end of file
diff --git a/test/configs/minimal_with_ephemeral_aof_dirname.yaml b/test/configs/minimal_with_ephemeral_aof_dirname.yaml
index 9933896e..d4f5a635 100644
--- a/test/configs/minimal_with_ephemeral_aof_dirname.yaml
+++ b/test/configs/minimal_with_ephemeral_aof_dirname.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
generate_ephemeral_aof_dirname: True
\ No newline at end of file
diff --git a/test/configs/minimal_with_every_second_saving.yaml b/test/configs/minimal_with_every_second_saving.yaml
index 74d4b314..6d381f72 100644
--- a/test/configs/minimal_with_every_second_saving.yaml
+++ b/test/configs/minimal_with_every_second_saving.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
enable_saving: True
diff --git a/test/configs/minimal_with_every_write_saving.yaml b/test/configs/minimal_with_every_write_saving.yaml
index add0c74f..4f1e3af4 100644
--- a/test/configs/minimal_with_every_write_saving.yaml
+++ b/test/configs/minimal_with_every_write_saving.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
enable_saving: True
diff --git a/test/configs/minimal_with_fixed_default_redis_port.yaml b/test/configs/minimal_with_fixed_default_redis_port.yaml
index c8e0b24e..8d33ded9 100644
--- a/test/configs/minimal_with_fixed_default_redis_port.yaml
+++ b/test/configs/minimal_with_fixed_default_redis_port.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
port: 6379
\ No newline at end of file
diff --git a/test/configs/minimal_with_fixed_redis_port.yaml b/test/configs/minimal_with_fixed_redis_port.yaml
index f6f98253..b63471f4 100644
--- a/test/configs/minimal_with_fixed_redis_port.yaml
+++ b/test/configs/minimal_with_fixed_redis_port.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
port: 6378
\ No newline at end of file
diff --git a/test/configs/minimal_with_no_schedule_saving.yaml b/test/configs/minimal_with_no_schedule_saving.yaml
index 2e5f62f3..20731f3d 100644
--- a/test/configs/minimal_with_no_schedule_saving.yaml
+++ b/test/configs/minimal_with_no_schedule_saving.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
enable_saving: True
diff --git a/test/configs/minimal_with_redis_saving.yaml b/test/configs/minimal_with_redis_saving.yaml
index 15e95b6c..087f685d 100644
--- a/test/configs/minimal_with_redis_saving.yaml
+++ b/test/configs/minimal_with_redis_saving.yaml
@@ -1,14 +1,18 @@
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
redis_config:
enable_saving: True
\ No newline at end of file
diff --git a/test/configs/minimal_with_settings.yaml b/test/configs/minimal_with_settings.yaml
index a5aec2d8..daff183d 100644
--- a/test/configs/minimal_with_settings.yaml
+++ b/test/configs/minimal_with_settings.yaml
@@ -4,16 +4,19 @@ settings:
control_port: 6000
output_port: 6001
logging_port: 6002
- use_watcher: false
actors:
Generator:
- package: actors.sample_generator
+ package: actors.sample_generator_zmq
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
diff --git a/test/configs/minimal_wrong_import.yaml b/test/configs/minimal_wrong_import.yaml
index 6c5e51c5..3fbc7822 100644
--- a/test/configs/minimal_wrong_import.yaml
+++ b/test/configs/minimal_wrong_import.yaml
@@ -4,8 +4,12 @@ actors:
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
\ No newline at end of file
diff --git a/test/configs/minimal_wrong_init.yaml b/test/configs/minimal_wrong_init.yaml
index a59c954a..edb45b51 100644
--- a/test/configs/minimal_wrong_init.yaml
+++ b/test/configs/minimal_wrong_init.yaml
@@ -4,8 +4,12 @@ actors:
class: Generator
Processor:
- package: actors.sample_processor
+ package: actors.sample_processor_zmq
class: Processor
connections:
- Generator.q_out: [Processor.q_in]
\ No newline at end of file
+ Generator-Processor:
+ sources:
+ - Generator.q_out
+ sinks:
+ - Processor.q_in
\ No newline at end of file
diff --git a/test/configs/simple_graph.yaml b/test/configs/simple_graph.yaml
index 67086def..9238739c 100644
--- a/test/configs/simple_graph.yaml
+++ b/test/configs/simple_graph.yaml
@@ -10,7 +10,8 @@ actors:
class: SimpleAnalysis
connections:
- Acquirer.q_out: [Analysis.q_in]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
+ Acquirer-Analysis:
+ sources:
+ - Acquirer.q_out
+ sinks:
+ - Analysis.q_in
diff --git a/test/configs/single_actor.yaml b/test/configs/single_actor.yaml
index 812480b1..c9c6b0e2 100644
--- a/test/configs/single_actor.yaml
+++ b/test/configs/single_actor.yaml
@@ -6,5 +6,3 @@ actors:
framerate: 15
connections:
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/configs/single_actor_plasma.yaml b/test/configs/single_actor_plasma.yaml
deleted file mode 100644
index 812480b1..00000000
--- a/test/configs/single_actor_plasma.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-actors:
- Acquirer:
- package: demos.sample_actors.acquire
- class: FileAcquirer
- filename: data/Tolias_mesoscope_2.hdf5
- framerate: 15
-
-connections:
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/conftest.py b/test/conftest.py
index f9b80fa5..d123fcdc 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,17 +1,79 @@
+import logging
+import multiprocessing
import os
import signal
-import uuid
+import time
+
import pytest
import subprocess
-store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
+import zmq
+
+from improv.actor import ZmqActor
+from improv.harvester import bootstrap_harvester
+from improv.nexus import Nexus
+
redis_port_num = 6379
-WAIT_TIMEOUT = 120
+WAIT_TIMEOUT = 10
+
+SERVER_COUNTER = 0
+
+signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+
+
+@pytest.fixture
+def ports():
+ global SERVER_COUNTER
+ CONTROL_PORT = 30000
+ OUTPUT_PORT = 30001
+ LOGGING_PORT = 30002
+ ACTOR_IN_PORT = 30003
+ yield (
+ CONTROL_PORT + SERVER_COUNTER,
+ OUTPUT_PORT + SERVER_COUNTER,
+ LOGGING_PORT + SERVER_COUNTER,
+ ACTOR_IN_PORT + SERVER_COUNTER,
+ )
+ SERVER_COUNTER += 4
+
+
+@pytest.fixture
+def setdir():
+ prev = os.getcwd()
+ os.chdir(os.path.dirname(__file__) + "/configs")
+ yield None
+ os.chdir(prev)
+
+
+@pytest.fixture
+def set_dir_config_parent():
+ prev = os.getcwd()
+ os.chdir(os.path.dirname(__file__))
+ yield None
+ os.chdir(prev)
@pytest.fixture
-def set_store_loc():
- return store_loc
+def sample_nex(setdir, ports):
+ nex = Nexus("test")
+ try:
+ nex.create_nexus(
+ file="good_config.yaml",
+ store_size=40000000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
+ except Exception as e:
+ print(f"error caught in test harness during create_nexus step: {e}")
+ logging.error(f"error caught in test harness during create_nexus step: {e}")
+ raise e
+ yield nex
+ try:
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness during destroy_nexus step: {e}")
+ logging.error(f"error caught in test harness during destroy_nexus step: {e}")
+ raise e
@pytest.fixture
@@ -37,21 +99,96 @@ def setup_store(server_port_num):
stderr=subprocess.DEVNULL,
)
+ time.sleep(3)
+ if p.poll() is not None:
+ raise Exception("redis-server failed to start")
+
yield p
# kill the subprocess when the caller is done with it
p.send_signal(signal.SIGINT)
- p.wait(WAIT_TIMEOUT)
+ p.wait(timeout=WAIT_TIMEOUT)
-@pytest.fixture
-def setup_plasma_store(set_store_loc, scope="module"):
- """Fixture to set up the store subprocess with 10 mb."""
- p = subprocess.Popen(
- ["plasma_store", "-s", set_store_loc, "-m", str(10000000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
+def nex_startup(ports, filename):
+ nex = Nexus("test")
+ nex.create_nexus(
+ file=filename,
+ store_size=100000000,
+ control_port=ports[0],
+ output_port=ports[1],
+ actor_in_port=ports[3],
)
+ nex.start_nexus()
+
+
+@pytest.fixture
+def start_nexus_minimal_zmq(ports):
+ filename = "minimal.yaml"
+ p = multiprocessing.Process(target=nex_startup, args=(ports, filename))
+ p.start()
+ time.sleep(1)
+
yield p
- p.send_signal(signal.SIGINT)
- p.wait(WAIT_TIMEOUT)
+
+ p.terminate()
+ p.join(WAIT_TIMEOUT)
+ if p.exitcode is None:
+ logging.exception("Timed out waiting for nexus to stop")
+ p.kill()
+
+
+@pytest.fixture
+def zmq_actor(ports):
+ actor = ZmqActor(ports[3], None, None, None, None, None, name="test")
+
+ p = multiprocessing.Process(target=actor_startup, args=(actor,))
+
+ yield p
+
+ p.terminate()
+ p.join(WAIT_TIMEOUT)
+ if p.exitcode is None:
+ p.kill()
+
+
+def actor_startup(actor):
+ actor.register_with_nexus()
+
+
+@pytest.fixture
+def harvester(ports):
+ ctx = zmq.Context()
+ socket = ctx.socket(zmq.PUB)
+ socket.bind("tcp://*:1234")
+ p = multiprocessing.Process(
+ target=bootstrap_harvester,
+ args=(
+ "localhost",
+ ports[3],
+ "localhost",
+ 6379,
+ "localhost",
+ 1234,
+ "localhost",
+ 12345,
+ ),
+ )
+ p.start()
+ time.sleep(1)
+ yield ports, socket, p
+ socket.close(linger=0)
+ ctx.destroy(linger=0)
+
+
+class SignalManager:
+ def __init__(self):
+ self.signal_handlers = dict()
+
+ def __enter__(self):
+ for sig in signals:
+ self.signal_handlers[sig] = signal.getsignal(sig)
+
+ def __exit__(self, type, value, traceback):
+ for sig, handler in self.signal_handlers.items():
+ signal.signal(sig, handler)
diff --git a/test/conftest_with_errors.py b/test/conftest_with_errors.py
deleted file mode 100644
index c299959d..00000000
--- a/test/conftest_with_errors.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# import pytest
-import subprocess
-import asyncio
-from improv.actor import RunManager, AsyncRunManager
-import os
-import uuid
-
-
-class StoreInterfaceDependentTestCase:
- def set_up(self):
- """Start the server"""
- print("Setting up Plasma store.")
- store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
- self.p = subprocess.Popen(
- ["plasma_store", "-s", store_loc, "-m", str(10000000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
-
- def tear_down(self):
- """Kill the server"""
- print("Tearing down Plasma store.")
- self.p.kill()
- self.p.wait()
-
-
-class ActorDependentTestCase:
- def set_up(self):
- """Start the server"""
- print("Setting up Plasma store.")
- store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
- self.p = subprocess.Popen(
- ["plasma_store", "-s", store_loc, "-m", str(10000000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
-
- def tear_down(self):
- """Kill the server"""
- print("Tearing down Plasma store.")
- self.p.kill()
- self.p.wait()
-
- def run_setup(self):
- print("Set up = True.")
- self.is_set_up = True
-
- def run_method(self):
- # Accurate print statement?
- print("Running method.")
- self.run_num += 1
-
- def process_setup(self):
- print("Processing setup.")
- pass
-
- def process_run(self):
- print("Processing run - ran.")
- self.q_comm.put("ran")
-
- def create_process(self, q_sig, q_comm):
- print("Creating process.")
- with RunManager(
- "test", self.process_run, self.process_setup, q_sig, q_comm
- ) as rm:
- print(rm)
-
- async def createAsyncProcess(self, q_sig, q_comm):
- print("Creating asyn process.")
- async with AsyncRunManager(
- "test", self.process_run, self.process_setup, q_sig, q_comm
- ) as rm:
- print(rm)
-
- async def a_put(self, signal, time):
- print("Async put.")
- await asyncio.sleep(time)
- self.q_sig.put_async(signal)
diff --git a/test/nexus_analog.py b/test/nexus_analog.py
deleted file mode 100644
index 0fc698ec..00000000
--- a/test/nexus_analog.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import concurrent
-import time
-import asyncio
-import math
-import uuid
-import os
-from improv.link import Link
-from improv.store import StoreInterface
-import subprocess
-
-
-def clean_list_print(lst):
- print("\n=======================\n")
- for el in lst:
- print(el)
- print("\n")
- print("\n=======================\n")
-
-
-def setup_store():
- """Fixture to set up the store subprocess with 10 mb.
-
- This fixture runs a subprocess that instantiates the store with a
- memory of 10 megabytes. It specifies that "/tmp/store/" is the
- location of the store socket.
-
- Yields:
- StoreInterface: An instance of the store.
-
- TODO:
- Figure out the scope.
- """
- store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
- subprocess.Popen(
- ["plasma_store", "-s", store_loc, "-m", str(10000000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
- store = StoreInterface(store_loc=store_loc)
- return store
-
-
-async def pollQueues(links):
- tasks = []
- for link in links:
- tasks.append(asyncio.create_task(link.get_async()))
-
- links_cpy = links
- t_0 = time.perf_counter()
- t_1 = time.perf_counter()
- print("time get")
- cur_time = 0
- while t_1 - t_0 < 5:
- # need to add something to the queue such that asyncio.wait returns
-
- links[0].put("Message")
- done, pending = await asyncio.wait(
- tasks, return_when=concurrent.futures.FIRST_COMPLETED
- )
- for i, t in enumerate(tasks):
- if t in done:
- pass
- tasks[i] = asyncio.create_task(links_cpy[i].get_async())
-
- t_1 = time.perf_counter()
-
- if math.floor(t_1 - t_0) != cur_time:
- print(math.floor(t_1 - t_0))
- cur_time = math.floor(t_1 - t_0)
-
- print("All tasks prior to stop polling: \n")
- clean_list_print([task for task in tasks])
-
- asyncio.get_running_loop()
- return stop_polling(tasks, links)
-
-
-def start():
- links = [
- Link(f"Link {i}", f"start {i}", f"end {i}", setup_store()) for i in range(4)
- ]
- loop = asyncio.get_event_loop()
- print("RUC loop")
- res = loop.run_until_complete(pollQueues(links))
- print(f"RES: {res}")
-
- print("**********************\nAll tasks at the end of execution:")
- clean_list_print(res)
- print("**********************")
- print(f"Loop: {loop}")
- loop.close()
- print(f"Loop: {loop}")
-
-
-def stop_polling(tasks, links):
- # asyncio.gather(*tasks)
- print("Cancelling")
-
- [lnk.put("msg") for lnk in links]
-
- # [task.cancel() for task in tasks]
- # [task.cancel() for task in tasks]
-
- # [lnk.put("msg") for lnk in links]
-
- print("All tasks: \n")
- clean_list_print([task for task in tasks])
- print("Pending:\n")
- clean_list_print([task for task in tasks if not task.done()])
- print("Cancelled: \n")
- clean_list_print([task for task in tasks if task.cancelled()])
- print("Pending and cancelled: \n")
- clean_list_print([task for task in tasks if not task.done() and task.cancelled()])
-
- return [task for task in tasks]
-
-
-if __name__ == "__main__":
- start()
diff --git a/test/special_configs/basic_demo.yaml b/test/special_configs/basic_demo.yaml
index 1b84f7d8..6f113765 100644
--- a/test/special_configs/basic_demo.yaml
+++ b/test/special_configs/basic_demo.yaml
@@ -16,6 +16,3 @@ actors:
connections:
Acquirer.q_out: [Analysis.q_in]
InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/special_configs/basic_demo_with_GUI.yaml b/test/special_configs/basic_demo_with_GUI.yaml
index edfb09e1..0e5dfc27 100644
--- a/test/special_configs/basic_demo_with_GUI.yaml
+++ b/test/special_configs/basic_demo_with_GUI.yaml
@@ -21,6 +21,3 @@ actors:
connections:
Acquirer.q_out: [Analysis.q_in]
InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/test_actor.py b/test/test_actor.py
index 448fd34e..82251418 100644
--- a/test/test_actor.py
+++ b/test/test_actor.py
@@ -1,9 +1,12 @@
import os
import psutil
import pytest
-from improv.link import Link # , AsyncQueue
+import zmq
+
from improv.actor import AbstractActor as Actor
-from improv.store import StoreInterface, PlasmaStoreInterface
+from improv.messaging import ActorStateMsg, ActorStateReplyMsg
+from improv.store import StoreInterface
+from improv.link import ZmqLink
# set global_variables
@@ -12,10 +15,10 @@
@pytest.fixture
-def init_actor(set_store_loc):
+def init_actor():
"""Fixture to initialize and teardown an instance of actor."""
- act = Actor("Test", set_store_loc)
+ act = Actor("Test")
yield act
act = None
@@ -33,33 +36,15 @@ def example_links(setup_store, server_port_num):
"""Fixture to provide link objects as test input and setup store."""
StoreInterface(server_port_num=server_port_num)
- acts = [
- Actor("act" + str(i), server_port_num) for i in range(1, 5)
- ] # range must be even
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.PUSH)
- links = [
- Link("L" + str(i + 1), acts[i], acts[i + 1]) for i in range(len(acts) // 2)
- ]
+ links = [ZmqLink(s, "test") for i in range(2)]
link_dict = {links[i].name: links[i] for i, l in enumerate(links)}
pytest.example_links = link_dict
- return pytest.example_links
-
-
-@pytest.fixture
-def example_links_plasma(setup_store, set_store_loc):
- """Fixture to provide link objects as test input and setup store."""
- PlasmaStoreInterface(store_loc=set_store_loc)
-
- acts = [
- Actor("act" + str(i), set_store_loc) for i in range(1, 5)
- ] # range must be even
-
- links = [
- Link("L" + str(i + 1), acts[i], acts[i + 1]) for i in range(len(acts) // 2)
- ]
- link_dict = {links[i].name: links[i] for i, l in enumerate(links)}
- pytest.example_links = link_dict
- return pytest.example_links
+ yield pytest.example_links
+ s.close(linger=0)
+ ctx.destroy(linger=0)
@pytest.mark.parametrize(
@@ -89,66 +74,34 @@ def test_repr_default_initialization(init_actor):
assert rep == "Test: dict_keys([])"
-def test_repr(example_string_links, set_store_loc):
+def test_repr(example_string_links):
"""Test if the actor representation has the right, nonempty, dict."""
- act = Actor("Test", set_store_loc)
- act.setLinks(example_string_links)
+ act = Actor("Test")
+ act.set_links(example_string_links)
assert act.__repr__() == "Test: dict_keys(['1', '2', '3'])"
-def test_setStoreInterface(setup_store, server_port_num):
+def test_set_store_interface(setup_store, server_port_num):
"""Tests if the store is started and linked with the actor."""
act = Actor("Acquirer", server_port_num)
store = StoreInterface(server_port_num=server_port_num)
- act.setStoreInterface(store.client)
- assert act.client is store.client
-
-
-def test_plasma_setStoreInterface(setup_plasma_store, set_store_loc):
- """Tests if the store is started and linked with the actor."""
-
- act = Actor("Acquirer", set_store_loc)
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- act.setStoreInterface(store.client)
+ act.set_store_interface(store.client)
assert act.client is store.client
@pytest.mark.parametrize(
- "links", [(pytest.example_string_links), ({}), (pytest.example_links), (None)]
+ "links", [pytest.example_string_links, ({}), pytest.example_links, None]
)
-def test_setLinks(links, set_store_loc):
+def test_set_links(links):
"""Tests if the actors links can be set to certain values."""
- act = Actor("test", set_store_loc)
- act.setLinks(links)
+ act = Actor("test")
+ act.set_links(links)
assert act.links == links
-@pytest.mark.parametrize(
- ("qc", "qs"),
- [
- ("comm", "sig"),
- (None, None),
- ("", ""),
- ("LINK", "LINK"), # these are placeholder names (store is not setup)
- ],
-)
-def test_setCommLinks(example_links, qc, qs, init_actor, setup_store, set_store_loc):
- """Tests if commLinks can be added to the actor"s links."""
-
- if qc == "LINK" and qs == "LINK":
- qc = Link("L1", Actor("1", set_store_loc), Actor("2", set_store_loc))
- qs = Link("L2", Actor("3", set_store_loc), Actor("4", set_store_loc))
- act = init_actor
- act.setLinks(example_links)
- act.setCommLinks(qc, qs)
-
- example_links.update({"q_comm": qc, "q_sig": qs})
- assert act.links == example_links
-
-
@pytest.mark.parametrize(
("links", "expected"),
[
@@ -158,18 +111,18 @@ def test_setCommLinks(example_links, qc, qs, init_actor, setup_store, set_store_
(None, TypeError),
],
)
-def test_setLinkIn(init_actor, example_string_links, example_links, links, expected):
+def test_set_link_in(init_actor, example_string_links, example_links, links, expected):
"""Tests if we can set the input queue."""
act = init_actor
- act.setLinks(links)
+ act.set_links(links)
if links is not None:
- act.setLinkIn("input_q")
+ act.set_link_in("input_q")
expected.update({"q_in": "input_q"})
assert act.links == expected
else:
with pytest.raises(AttributeError):
- act.setLinkIn("input_queue")
+ act.set_link_in("input_queue")
@pytest.mark.parametrize(
@@ -181,18 +134,18 @@ def test_setLinkIn(init_actor, example_string_links, example_links, links, expec
(None, TypeError),
],
)
-def test_setLinkOut(init_actor, example_string_links, example_links, links, expected):
+def test_set_link_out(init_actor, example_string_links, example_links, links, expected):
"""Tests if we can set the output queue."""
act = init_actor
- act.setLinks(links)
+ act.set_links(links)
if links is not None:
- act.setLinkOut("output_q")
+ act.set_link_out("output_q")
expected.update({"q_out": "output_q"})
assert act.links == expected
else:
with pytest.raises(AttributeError):
- act.setLinkIn("output_queue")
+ act.set_link_in("output_queue")
@pytest.mark.parametrize(
@@ -204,29 +157,31 @@ def test_setLinkOut(init_actor, example_string_links, example_links, links, expe
(None, TypeError),
],
)
-def test_setLinkWatch(init_actor, example_string_links, example_links, links, expected):
+def test_set_link_watch(
+ init_actor, example_string_links, example_links, links, expected
+):
"""Tests if we can set the watch queue."""
act = init_actor
- act.setLinks(links)
+ act.set_links(links)
if links is not None:
- act.setLinkWatch("watch_q")
+ act.set_link_watch("watch_q")
expected.update({"q_watchout": "watch_q"})
assert act.links == expected
else:
with pytest.raises(AttributeError):
- act.setLinkIn("input_queue")
+ act.set_link_in("input_queue")
-def test_addLink(setup_store, set_store_loc):
+def test_add_link(setup_store):
"""Tests if a link can be added to the dictionary of links."""
- act = Actor("test", set_store_loc)
+ act = Actor("test")
links = {"1": "one", "2": "two"}
- act.setLinks(links)
+ act.set_links(links)
newName = "3"
newLink = "three"
- act.addLink(newName, newLink)
+ act.add_link(newName, newLink)
links.update({"3": "three"})
# trying to check for two separate conditions while being able to
@@ -234,7 +189,7 @@ def test_addLink(setup_store, set_store_loc):
passes = []
err_messages = []
- if act.getLinks()["3"] == "three":
+ if act.get_links()["3"] == "three":
passes.append(True)
else:
passes.append(False)
@@ -243,7 +198,7 @@ def test_addLink(setup_store, set_store_loc):
actor.getLinks()['3'] is not equal to \"three\""
)
- if act.getLinks() == links:
+ if act.get_links() == links:
passes.append(True)
else:
passes.append("False")
@@ -256,7 +211,7 @@ def test_addLink(setup_store, set_store_loc):
assert all(passes), f"The following errors occurred: {err_out}"
-def test_getLinks(init_actor, example_string_links):
+def test_get_links(init_actor, example_string_links):
"""Tests if we can access the dictionary of links.
TODO:
@@ -265,25 +220,9 @@ def test_getLinks(init_actor, example_string_links):
act = init_actor
links = example_string_links
- act.setLinks(links)
-
- assert act.getLinks() == {"1": "one", "2": "two", "3": "three"}
-
+ act.set_links(links)
-@pytest.mark.skip(
- reason="this is something we'll do later because\
- we will subclass actor w/ watcher later"
-)
-def test_put(init_actor):
- """Tests if data keys can be put to output links.
-
- TODO:
- Ask Anne to take a look.
- """
-
- act = init_actor
- act.put()
- assert True
+ assert act.get_links() == {"1": "one", "2": "two", "3": "three"}
def test_run(init_actor):
@@ -293,12 +232,12 @@ def test_run(init_actor):
act.run()
-def test_changePriority(init_actor):
+def test_change_priority(init_actor):
"""Tests if we are able to change the priority of an actor."""
act = init_actor
act.lower_priority = True
- act.changePriority()
+ act.change_priority()
assert psutil.Process(os.getpid()).nice() == 19
@@ -311,39 +250,23 @@ def test_actor_connection(setup_store, server_port_num):
one actor. Then, in the other actor, it is removed from the queue, and
checked to verify it matches the original message.
"""
- act1 = Actor("a1", server_port_num)
- act2 = Actor("a2", server_port_num)
-
- StoreInterface(server_port_num=server_port_num)
- link = Link("L12", act1, act2)
- act1.setLinkIn(link)
- act2.setLinkOut(link)
-
- msg = "message"
-
- act1.q_in.put(msg)
+ assert True
- assert act2.q_out.get() == msg
+def test_actor_registration_with_nexus(ports, zmq_actor):
+ context = zmq.Context()
+ nex_socket = context.socket(zmq.REP)
+ nex_socket.bind(f"tcp://*:{ports[3]}") # actor in port
-def test_plasma_actor_connection(setup_plasma_store, set_store_loc):
- """Test if the links between actors are established correctly.
+ zmq_actor.start()
- This test instantiates two actors with different names, then instantiates
- a Link object linking the two actors. A string is put to the input queue of
- one actor. Then, in the other actor, it is removed from the queue, and
- checked to verify it matches the original message.
- """
- act1 = Actor("a1", set_store_loc)
- act2 = Actor("a2", set_store_loc)
+ res = nex_socket.recv_pyobj()
+ assert isinstance(res, ActorStateMsg)
- PlasmaStoreInterface(store_loc=set_store_loc)
- link = Link("L12", act1, act2)
- act1.setLinkIn(link)
- act2.setLinkOut(link)
+ nex_socket.send_pyobj(ActorStateReplyMsg("test", "OK", ""))
- msg = "message"
+ zmq_actor.terminate()
+ zmq_actor.join(10)
- act1.q_in.put(msg)
- assert act2.q_out.get() == msg
+# TODO: register with broker test
diff --git a/test/test_cli.py b/test/test_cli.py
index c4b34bff..beba7db3 100644
--- a/test/test_cli.py
+++ b/test/test_cli.py
@@ -1,36 +1,29 @@
+import time
+
import pytest
import os
import datetime
from collections import namedtuple
import subprocess
-import asyncio
import signal
-import improv.cli as cli
-
-from test_nexus import ports
-SERVER_WARMUP = 16
-SERVER_TIMEOUT = 16
+import improv.cli as cli
+from conftest import ports
-@pytest.fixture
-def setdir():
- prev = os.getcwd()
- os.chdir(os.path.dirname(__file__))
- yield None
- os.chdir(prev)
+SERVER_WARMUP = 10
+SERVER_TIMEOUT = 15
@pytest.fixture
-async def server(setdir, ports):
+def server(setdir, ports):
"""
Sets up a server using minimal.yaml in the configs folder.
Requires the actor path command line argument and so implicitly
tests that as well.
"""
- os.chdir("configs")
- control_port, output_port, logging_port = ports
+ control_port, output_port, logging_port, actor_in_port = ports
# start server
server_opts = [
@@ -49,23 +42,18 @@ async def server(setdir, ports):
"minimal.yaml",
]
- server = subprocess.Popen(
- server_opts, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
- )
- await asyncio.sleep(SERVER_WARMUP)
+ server = subprocess.Popen(server_opts)
+ time.sleep(SERVER_WARMUP)
yield server
- server.wait(SERVER_TIMEOUT)
- try:
- os.remove("testlog")
- except FileNotFoundError:
- pass
+ if server.poll() is None:
+ pytest.fail("Server did not shut down correctly.")
@pytest.fixture
-async def cli_args(setdir, ports):
+def cli_args(setdir, ports):
logfile = "tmp.log"
- control_port, output_port, logging_port = ports
- config_file = "configs/minimal.yaml"
+ control_port, output_port, logging_port, actor_in_port = ports
+ config_file = "minimal.yaml"
Args = namedtuple(
"cli_args",
"control_port output_port logging_port logfile configfile actor_path",
@@ -86,7 +74,7 @@ def test_configfile_required(setdir):
cli.parse_cli_args(["server", "does_not_exist.yaml"])
-def test_multiple_actor_path(setdir):
+def test_multiple_actor_path(set_dir_config_parent):
args = cli.parse_cli_args(
["run", "-a", "actors", "-a", "configs", "configs/blank_file.yaml"]
)
@@ -113,7 +101,7 @@ def test_multiple_actor_path(setdir):
],
)
def test_can_override_ports(mode, flag, expected, setdir):
- file = "configs/blank_file.yaml"
+ file = "blank_file.yaml"
localhost = "127.0.0.1:"
params = {
"-c": "control_port",
@@ -142,7 +130,7 @@ def test_can_override_ports(mode, flag, expected, setdir):
],
)
def test_non_port_is_error(mode, flag, expected):
- file = "configs/blank_file.yaml"
+ file = "blank_file.yaml"
with pytest.raises(SystemExit):
cli.parse_cli_args([mode, flag, expected, file])
@@ -182,27 +170,29 @@ def test_can_override_ip(mode, flag, expected):
assert vars(args)[params[flag]] == expected
-async def test_sigint_kills_server(server):
+def test_sigint_kills_server(server):
server.send_signal(signal.SIGINT)
+ server.wait(SERVER_TIMEOUT)
-async def test_improv_list_nonempty(server):
+def test_improv_list_nonempty(server):
proc_list = cli.run_list("", printit=False)
assert len(proc_list) > 0
server.send_signal(signal.SIGINT)
+ server.wait(SERVER_TIMEOUT)
-async def test_improv_kill_empties_list(server):
+def test_improv_kill_empties_list(server):
proc_list = cli.run_list("", printit=False)
assert len(proc_list) > 0
cli.run_cleanup("", headless=True)
proc_list = cli.run_list("", printit=False)
assert len(proc_list) == 0
+ server.wait(SERVER_TIMEOUT)
-async def test_improv_run_writes_stderr_to_log(setdir, ports):
- os.chdir("configs")
- control_port, output_port, logging_port = ports
+def test_improv_run_writes_stderr_to_log(setdir, ports):
+ control_port, output_port, logging_port, actor_in_port = ports
# start server
server_opts = [
@@ -223,17 +213,19 @@ async def test_improv_run_writes_stderr_to_log(setdir, ports):
server = subprocess.Popen(
server_opts, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
- await asyncio.sleep(SERVER_WARMUP)
+ time.sleep(SERVER_WARMUP)
server.kill()
server.wait(SERVER_TIMEOUT)
- with open("testlog") as log:
+ with open("improv-debug.log") as log:
contents = log.read()
+ print(contents)
assert "Traceback" in contents
+ os.remove("improv-debug.log")
os.remove("testlog")
cli.run_cleanup("", headless=True)
-async def test_get_ports_from_logfile(setdir):
+def test_get_ports_from_logfile(setdir):
test_control_port = 53349
test_output_port = 53350
test_logging_port = 53351
@@ -258,8 +250,8 @@ async def test_get_ports_from_logfile(setdir):
assert logging_port == test_logging_port
-async def test_no_server_start_in_logfile_raises_error(setdir, cli_args, capsys):
- with open(cli_args.logfile, mode="w") as f:
+def test_no_server_start_in_logfile_raises_error(setdir, cli_args, capsys):
+ with open("improv-debug.log", mode="w") as f:
f.write("this is some placeholder text")
cli.get_server_ports(cli_args, timeout=1)
@@ -267,18 +259,18 @@ async def test_no_server_start_in_logfile_raises_error(setdir, cli_args, capsys)
captured = capsys.readouterr()
assert "Unable to read server start time" in captured.out
- os.remove(cli_args.logfile)
+ os.remove("improv-debug.log")
cli.run_cleanup("", headless=True)
-async def test_no_ports_in_logfile_raises_error(setdir, cli_args, capsys):
+def test_no_ports_in_logfile_raises_error(setdir, cli_args, capsys):
curr_dt = datetime.datetime.now().replace(microsecond=0)
- with open(cli_args.logfile, mode="w") as f:
+ with open("improv-debug.log", mode="w") as f:
f.write(f"{curr_dt} Server running on (control, output, log) ports XXX\n")
cli.get_server_ports(cli_args, timeout=1)
captured = capsys.readouterr()
- assert f"Unable to read ports from {cli_args.logfile}." in captured.out
+ assert f"Unable to read ports from {'improv-debug.log'}." in captured.out
- os.remove(cli_args.logfile)
+ os.remove("improv-debug.log")
cli.run_cleanup("", headless=True)
diff --git a/test/test_config.py b/test/test_config.py
index 12cc79bf..f0230dcc 100644
--- a/test/test_config.py
+++ b/test/test_config.py
@@ -6,13 +6,14 @@
# from importlib import import_module
# from improv.config import RepeatedActorError
-from improv.config import Config
+from improv.config import Config, CannotCreateConfigException
from improv.utils import checks
import logging
logger = logging.getLogger(__name__)
+
# set global variables
@@ -34,61 +35,42 @@ def test_init(test_input, set_configdir):
"""
cfg = Config(test_input)
- assert cfg.configFile == test_input
-
-
-# def test_init_attributes():
-# """ Tests if config has correct default attributes on initialization.
-
-# Checks if actors, connection, and hasGUI are all empty or
-# nonexistent. Detects errors by maintaining a list of errors, and
-# then adding to it every time an unexpected behavior is encountered.
-
-# Asserts:
-# If the default attributes are empty or nonexistent.
+ assert cfg.config_file == test_input
-# """
-# cfg = config()
-# errors = []
-
-# if(cfg.actors != {}):
-# errors.append("config.actors is not empty! ")
-# if(cfg.connections != {}):
-# errors.append("config.connections is not empty! ")
-# if(cfg.hasGUI):
-# errors.append("config.hasGUI already exists! ")
-
-# assert not errors, "The following errors occurred:\n{}".format(
-# "\n".join(errors))
-
-
-def test_createConfig_settings(set_configdir):
+def test_create_config_settings(set_configdir):
"""Check if the default way config creates config.settings is correct.
Asserts:
- If the default setting is the dictionary {"use_watcher": "None"}
+ If the default setting is the config as a dict.
"""
cfg = Config("good_config.yaml")
- cfg.createConfig()
- assert cfg.settings == {"use_watcher": False}
+ cfg.parse_config()
+ cfg.create_config()
+ assert cfg.settings == {
+ "control_port": 5555,
+ "output_port": 5556,
+ "store_size": 250_000_000,
+ "actor_in_port": 0,
+ "harvest_data_from_memory": None,
+ }
-# File with syntax error cannot pass the format check
-# def test_createConfig_init_typo(set_configdir):
-# """Tests if createConfig can catch actors with errors in init function.
+def test_create_config_init_typo(set_configdir):
+ """Tests if createConfig can catch actors with errors in init function.
-# Asserts:
-# If createConfig raise any errors.
-# """
+ Asserts:
+ If createConfig raise any errors.
+ """
-# cfg = config("minimal_wrong_init.yaml")
-# res = cfg.createConfig()
-# assert res == -1
+ cfg = Config("minimal_wrong_init.yaml")
+ cfg.parse_config()
+ res = cfg.create_config()
+ assert res == -1
-def test_createConfig_wrong_import(set_configdir):
+def test_create_config_wrong_import(set_configdir):
"""Tests if createConfig can catch actors with errors during import.
Asserts:
@@ -96,11 +78,12 @@ def test_createConfig_wrong_import(set_configdir):
"""
cfg = Config("minimal_wrong_import.yaml")
- res = cfg.createConfig()
+ cfg.parse_config()
+ res = cfg.create_config()
assert res == -1
-def test_createConfig_clean(set_configdir):
+def test_create_config_clean(set_configdir):
"""Tests if createConfig runs without error given a good config.
Asserts:
@@ -108,52 +91,57 @@ def test_createConfig_clean(set_configdir):
"""
cfg = Config("good_config.yaml")
+ cfg.parse_config()
try:
- cfg.createConfig()
+ cfg.create_config()
except Exception as exc:
pytest.fail(f"createConfig() raised an exception {exc}")
-def test_createConfig_noActor(set_configdir):
+def test_create_config_no_actor(set_configdir):
"""Tests if AttributeError is raised when there are no actors."""
cfg = Config("no_actor.yaml")
+ cfg.parse_config()
with pytest.raises(AttributeError):
- cfg.createConfig()
+ cfg.create_config()
-def test_createConfig_ModuleNotFound(set_configdir):
+def test_create_config_module_not_found(set_configdir):
"""Tests if an error is raised when the package can"t be found."""
cfg = Config("bad_package.yaml")
- res = cfg.createConfig()
+ cfg.parse_config()
+ res = cfg.create_config()
assert res == -1
-def test_createConfig_class_ImportError(set_configdir):
+def test_create_config_class_import_error(set_configdir):
"""Tests if an error is raised when the class name is invalid."""
cfg = Config("bad_class.yaml")
- res = cfg.createConfig()
+ cfg.parse_config()
+ res = cfg.create_config()
assert res == -1
-def test_createConfig_AttributeError(set_configdir):
+def test_create_config_attribute_error(set_configdir):
"""Tests if AttributeError is raised."""
cfg = Config("bad_class.yaml")
- res = cfg.createConfig()
+ cfg.parse_config()
+ res = cfg.create_config()
assert res == -1
-def test_createConfig_blank_file(set_configdir):
+def test_create_config_blank_file(set_configdir):
"""Tests if a blank config file raises an error."""
- with pytest.raises(TypeError):
+ with pytest.raises(CannotCreateConfigException):
Config("blank_file.yaml")
-def test_createConfig_nonsense_file(set_configdir, caplog):
+def test_create_config_nonsense_file(set_configdir, caplog):
"""Tests if an improperly formatted config raises an error."""
with pytest.raises(TypeError):
@@ -170,12 +158,13 @@ def test_cyclic_graph(set_configdir):
assert not checks.check_if_connections_acyclic(path)
-def test_saveActors_clean(set_configdir):
+def test_save_actors_clean(set_configdir):
"""Compares internal actor representation to what was saved in the file."""
cfg = Config("good_config.yaml")
- cfg.createConfig()
- cfg.saveActors()
+ cfg.parse_config()
+ cfg.create_config()
+ cfg.save_actors()
with open("good_config_actors.yaml") as savedConfig:
data = yaml.safe_load(savedConfig)
@@ -185,9 +174,116 @@ def test_saveActors_clean(set_configdir):
assert savedKeys == originalKeys
+ os.remove("good_config_actors.yaml")
+
def test_config_settings_read(set_configdir):
cfg = Config("minimal_with_settings.yaml")
- cfg.createConfig()
+ cfg.parse_config()
+ cfg.create_config()
assert "store_size" in cfg.settings
+
+
+def test_config_bad_actor_args(set_configdir):
+ cfg = Config("bad_args.yaml")
+ cfg.parse_config()
+ res = cfg.create_config()
+ assert res == -1
+
+
+def test_config_harvester_disabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["settings"] = dict()
+ cfg.parse_config()
+ assert cfg.settings["harvest_data_from_memory"] is None
+
+
+def test_config_harvester_enabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["settings"] = dict()
+ cfg.config["settings"]["harvest_data_from_memory"] = True
+ cfg.parse_config()
+ assert cfg.settings["harvest_data_from_memory"]
+
+
+def test_config_redis_aof_enabled_saving_not_specified(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["aof_dirname"] = "test"
+ cfg.parse_config()
+ assert cfg.redis_config["aof_dirname"] == "test"
+ assert cfg.redis_config["enable_saving"] is True
+
+
+def test_config_redis_ephemeral_dirname_enabled_saving_not_specified(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["generate_ephemeral_aof_dirname"] = True
+ cfg.parse_config()
+ assert cfg.redis_config["generate_ephemeral_aof_dirname"] is True
+ assert cfg.redis_config["enable_saving"] is True
+
+
+def test_config_redis_ephemeral_dirname_and_aof_dirname_specified(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["generate_ephemeral_aof_dirname"] = True
+ cfg.config["redis_config"]["aof_dirname"] = "test"
+ with pytest.raises(
+ Exception, match="Cannot use unique dirname and use the one provided."
+ ):
+ cfg.parse_config()
+
+
+def test_config_redis_ephemeral_dirname_enabled_saving_disabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["generate_ephemeral_aof_dirname"] = True
+ cfg.config["redis_config"]["enable_saving"] = False
+ with pytest.raises(Exception, match="Cannot persist to disk with saving disabled."):
+ cfg.parse_config()
+
+
+def test_config_redis_aof_dirname_enabled_saving_disabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["aof_dirname"] = "test"
+ cfg.config["redis_config"]["enable_saving"] = False
+ with pytest.raises(Exception, match="Cannot persist to disk with saving disabled."):
+ cfg.parse_config()
+
+
+def test_config_redis_fsync_enabled_saving_disabled(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["fsync_frequency"] = "always"
+ cfg.config["redis_config"]["enable_saving"] = False
+ with pytest.raises(Exception, match="Cannot persist to disk with saving disabled."):
+ cfg.parse_config()
+
+
+def test_config_redis_unknown_fsync_freq(set_configdir):
+ cfg = Config("minimal.yaml")
+ cfg.config = dict()
+ cfg.config["redis_config"] = dict()
+ cfg.config["redis_config"]["fsync_frequency"] = "unknown"
+ with pytest.raises(Exception, match="Cannot use unknown fsync frequency unknown"):
+ cfg.parse_config()
+
+
+def test_config_gui(set_configdir):
+ cfg = Config("minimal_gui.yaml")
+ cfg.parse_config()
+ cfg.create_config()
+
+ assert cfg.hasGUI is True
+ assert cfg.gui.classname == "Generator"
diff --git a/test/test_demos.py b/test/test_demos.py
index bc6e96c6..97f16f0e 100644
--- a/test/test_demos.py
+++ b/test/test_demos.py
@@ -4,18 +4,15 @@
import os
import asyncio
import subprocess
+
import improv.tui as tui
-import concurrent.futures
import logging
-from demos.sample_actors.zmqActor import ZmqActor
-from test_nexus import ports
-
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)
-
SERVER_WARMUP = 10
+UNUSED_TCP_PORT = 10567
@pytest.fixture
@@ -35,17 +32,25 @@ def ip():
return pytest.ip
+@pytest.fixture
+def unused_tcp_port():
+ global UNUSED_TCP_PORT
+ pytest.unused_tcp_port = UNUSED_TCP_PORT
+ yield pytest.unused_tcp_port
+ UNUSED_TCP_PORT += 1
+
+
+@pytest.mark.asyncio
@pytest.mark.parametrize(
("dir", "configfile", "logfile"),
[
("minimal", "minimal.yaml", "testlog"),
- ("minimal", "minimal_plasma.yaml", "testlog"),
],
)
async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports):
os.chdir(dir)
- control_port, output_port, logging_port = ports
+ control_port, output_port, logging_port, actor_in_port = ports
# start server
server_opts = [
@@ -81,21 +86,22 @@ async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports):
assert not pilot.app._running
# wait on server to fully shut down
- server.wait(10)
- os.remove(logfile) # later, might want to read this file and check for messages
+ server.wait(15)
+ # os.remove(logfile) # later, might want to read this file and check for messages
+@pytest.mark.asyncio
@pytest.mark.parametrize(
("dir", "configfile", "logfile", "datafile"),
[
("minimal", "minimal.yaml", "testlog", "sample_generator_data.npy"),
- ("minimal", "minimal_spawn.yaml", "testlog", "sample_generator_data.npy"),
+ ("minimal", "minimal_persistence.yaml", "testlog", "test_persistence.csv"),
],
)
async def test_stop_output(dir, configfile, logfile, datafile, setdir, ports):
os.chdir(dir)
- control_port, output_port, logging_port = ports
+ control_port, output_port, logging_port, actor_in_port = ports
# start server
server_opts = [
@@ -143,61 +149,60 @@ async def test_stop_output(dir, configfile, logfile, datafile, setdir, ports):
os.remove(logfile) # later, might want to read this file and check for messages
-def test_zmq_ps(ip, unused_tcp_port):
- """Tests if we can set the zmq PUB/SUB socket and send message."""
- port = unused_tcp_port
- LOGGER.info("beginning test")
- act1 = ZmqActor("act1", type="PUB", ip=ip, port=port)
- act2 = ZmqActor("act2", type="SUB", ip=ip, port=port)
- LOGGER.info("ZMQ Actors constructed")
- # Note these sockets must be set up for testing
- # this is not needed for running in improv
- act1.setSendSocket()
- act2.setRecvSocket()
-
- msg = "hello"
- act1.put(msg)
- LOGGER.info("sent message")
- recvmsg = act2.get()
- LOGGER.info("received message")
- assert recvmsg == msg
-
-
-def test_zmq_rr(ip, unused_tcp_port):
- """Tests if we can set the zmq REQ/REP socket and send message."""
- port = unused_tcp_port
- act1 = ZmqActor("act1", "/tmp/store", type="REQ", ip=ip, port=port)
- act2 = ZmqActor("act2", "/tmp/store", type="REP", ip=ip, port=port)
- msg = "hello"
- reply = "world"
-
- def handle_request():
- return act1.put(msg)
-
- def handle_reply():
- return act2.get(reply)
-
- # Use a ThreadPoolExecutor to run handle_request()
- # and handle_reply() in separate threads.
-
- with concurrent.futures.ThreadPoolExecutor() as executor:
- future1 = executor.submit(handle_request)
- future2 = executor.submit(handle_reply)
-
- # Ensure the request is sent before the reply.
- request_result = future1.result()
- reply_result = future2.result()
-
- # Check if the received message is equal to the original message.
- assert reply_result == msg
- # Check if the reply is correct.
- assert request_result == reply
-
-
-def test_zmq_rr_timeout(ip, unused_tcp_port):
- """Test for requestMsg where we timeout or fail to send"""
- port = unused_tcp_port
- act1 = ZmqActor("act1", "/tmp/store", type="REQ", ip=ip, port=port)
- msg = "hello"
- replymsg = act1.put(msg)
- assert replymsg is None
+@pytest.mark.skip
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("dir", "configfile", "logfile", "datafile"),
+ [
+ ("minimal", "minimal_spawn.yaml", "testlog", "sample_generator_data.npy"),
+ ],
+)
+async def test_stop_output_spawn(dir, configfile, logfile, datafile, setdir, ports):
+ os.chdir(dir)
+
+ control_port, output_port, logging_port, actor_in_port = ports
+
+ # start server
+ server_opts = [
+ "improv",
+ "server",
+ "-c",
+ str(control_port),
+ "-o",
+ str(output_port),
+ "-l",
+ str(logging_port),
+ "-f",
+ logfile,
+ configfile,
+ ]
+
+ with open(logfile, mode="a+") as log:
+ server = subprocess.Popen(server_opts, stdout=log, stderr=log)
+ await asyncio.sleep(SERVER_WARMUP)
+
+ # initialize client
+ app = tui.TUI(control_port, output_port, logging_port)
+
+ # run client
+ async with app.run_test() as pilot:
+ print("running pilot")
+ await pilot.press(*"setup", "enter")
+ await pilot.pause(0.5)
+ await pilot.press(*"run", "enter")
+ await pilot.pause(1)
+ await pilot.press(*"stop", "enter")
+ await pilot.pause(2)
+ await pilot.press(*"quit", "enter")
+ await pilot.pause(3)
+ assert not pilot.app._running
+
+ # wait on server to fully shut down
+ server.wait(10)
+
+ # check that the file written by Generator's stop function got written
+ os.path.isfile(datafile)
+
+ # then remove that file and logile
+ os.remove(datafile)
+ os.remove(logfile) # later, might want to read this file and check for messages
diff --git a/test/test_harvester.py b/test/test_harvester.py
new file mode 100644
index 00000000..7f078573
--- /dev/null
+++ b/test/test_harvester.py
@@ -0,0 +1,224 @@
+import logging
+import signal
+import time
+
+import pytest
+import zmq
+from zmq.backend.cython._zmq import SocketOption
+
+from improv.harvester import RedisHarvester
+from improv.store import RedisStoreInterface
+
+from improv.link import ZmqLink
+from improv.messaging import HarvesterInfoReplyMsg
+from conftest import SignalManager
+
+
+def test_harvester_shuts_down_on_sigint(setup_store, harvester):
+ harvester_ports, broker_socket, p = harvester
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.REP)
+ s.bind(f"tcp://*:{harvester_ports[3]}")
+ s.recv_pyobj()
+ reply = HarvesterInfoReplyMsg("harvester", "OK", "")
+ s.send_pyobj(reply)
+ time.sleep(2)
+ p.terminate()
+ p.join(5)
+ if p.exitcode is None:
+ p.kill()
+ pytest.fail("Harvester did not exit in time")
+ else:
+ assert True
+
+ s.close(linger=0)
+ ctx.destroy(linger=0)
+
+
+def test_harvester_relieves_memory_pressure(setup_store, harvester):
+ store_interface = RedisStoreInterface()
+ harvester_ports, broker_socket, p = harvester
+ broker_link = ZmqLink(broker_socket, "test", "test topic")
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.REP)
+ s.bind(f"tcp://*:{harvester_ports[3]}")
+ s.recv_pyobj()
+ reply = HarvesterInfoReplyMsg("harvester", "OK", "")
+ s.send_pyobj(reply)
+ client = RedisStoreInterface()
+ for i in range(10):
+ message = [i for i in range(500000)]
+ key = client.put(message)
+ broker_link.put(key)
+ time.sleep(2)
+
+ db_info = store_interface.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+ assert used_max_ratio <= 0.75
+
+ p.terminate()
+ p.join(5)
+ if p.exitcode is None:
+ p.kill()
+ pytest.fail("Harvester did not exit in time")
+
+ s.close(linger=0)
+ ctx.destroy(linger=0)
+
+
+def test_harvester_stop_logs_and_halts_running():
+ with SignalManager():
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.PULL)
+ s.bind("tcp://*:0")
+ pull_port_string = s.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ pull_port = int(pull_port_string.split(":")[-1])
+ harvester = RedisHarvester(
+ nexus_hostname="localhost",
+ nexus_comm_port=10000, # never gets called in this test
+ redis_hostname="localhost",
+ redis_port=10000, # never gets called in this test
+ broker_hostname="localhost",
+ broker_port=10000, # never gets called in this test
+ logger_hostname="localhost",
+ logger_port=pull_port,
+ )
+ harvester.stop(signal.SIGINT, None)
+ assert not harvester.running
+ msg_available = s.poll(timeout=1000)
+ assert msg_available
+ record = s.recv_json()
+ assert (
+ record["message"]
+ == f"Harvester shutting down due to signal {signal.SIGINT}"
+ )
+
+ s.close(linger=0)
+ ctx.destroy(linger=0)
+
+
+# @pytest.mark.skip
+def test_harvester_relieves_memory_pressure_one_loop(ports, setup_store):
+ def harvest_and_quit(harvester_instance: RedisHarvester):
+ harvester_instance.collect()
+ harvester_instance.stop(signal.SIGINT, None)
+
+ with SignalManager():
+ try:
+ ctx = zmq.Context()
+ nex_s = ctx.socket(zmq.REP)
+ nex_s.bind(f"tcp://*:{ports[3]}")
+
+ broker_s = ctx.socket(zmq.PUB)
+ broker_s.bind("tcp://*:1234")
+ broker_link = ZmqLink(broker_s, "test", "test topic")
+
+ log_s = ctx.socket(zmq.PULL)
+ log_s.bind("tcp://*:0")
+ pull_port_string = log_s.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ pull_port = int(pull_port_string.split(":")[-1])
+
+ harvester = RedisHarvester(
+ nexus_hostname="localhost",
+ nexus_comm_port=ports[3], # never gets called in this test
+ redis_hostname="localhost",
+ redis_port=6379, # never gets called in this test
+ broker_hostname="localhost",
+ broker_port=1234, # never gets called in this test
+ logger_hostname="localhost",
+ logger_port=pull_port,
+ )
+
+ harvester.establish_connections()
+
+ client = RedisStoreInterface()
+ for i in range(9):
+ message = [i for i in range(500000)]
+ try:
+ key = client.put(message)
+ broker_link.put(key)
+ except Exception as e:
+ logging.warning(e)
+ logging.warning("Proceeding under the assumption Redis is full.")
+ break
+ time.sleep(2)
+
+ harvester.serve(harvest_and_quit, harvester_instance=harvester)
+
+ db_info = client.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+
+ assert used_max_ratio <= 0.5
+ assert not harvester.running
+ assert harvester.nexus_socket.closed
+ assert harvester.sub_socket.closed
+ assert harvester.zmq_context.closed
+ logging.info("Passed harvester one loop test")
+
+ except Exception as e:
+ logging.error(f"Encountered exception {e} during execution of test")
+ print(e)
+ pass
+
+ try:
+
+ nex_s.close(linger=0)
+ broker_s.close(linger=0)
+ log_s.close(linger=0)
+ ctx.destroy(linger=0)
+ except Exception as e:
+ logging.error(f"Encountered exception {e} during teardown")
+ print(e)
+ pass
+
+
+def test_harvester_loops_with_no_memory_pressure(ports, setup_store):
+ def harvest_and_quit(harvester_instance: RedisHarvester):
+ harvester_instance.collect()
+ harvester_instance.stop(signal.SIGINT, None)
+
+ with SignalManager():
+
+ ctx = zmq.Context()
+ nex_s = ctx.socket(zmq.REP)
+ nex_s.bind(f"tcp://*:{ports[3]}")
+
+ log_s = ctx.socket(zmq.PULL)
+ log_s.bind("tcp://*:0")
+ pull_port_string = log_s.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ pull_port = int(pull_port_string.split(":")[-1])
+
+ harvester = RedisHarvester(
+ nexus_hostname="localhost",
+ nexus_comm_port=ports[3], # never gets called in this test
+ redis_hostname="localhost",
+ redis_port=6379, # never gets called in this test
+ broker_hostname="localhost",
+ broker_port=1234, # never gets called in this test
+ logger_hostname="localhost",
+ logger_port=pull_port,
+ )
+
+ harvester.establish_connections()
+
+ client = RedisStoreInterface()
+
+ harvester.serve(harvest_and_quit, harvester_instance=harvester)
+
+ db_info = client.client.info()
+ max_memory = db_info["maxmemory"]
+ used_memory = db_info["used_memory"]
+ used_max_ratio = used_memory / max_memory
+ assert used_max_ratio <= 0.5
+ assert not harvester.running
+ assert harvester.nexus_socket.closed
+ assert harvester.sub_socket.closed
+ assert harvester.zmq_context.closed
+
+ nex_s.close(linger=0)
+ log_s.close(linger=0)
+ ctx.destroy(linger=0)
diff --git a/test/test_link.py b/test/test_link.py
index eec9a521..90056c20 100644
--- a/test/test_link.py
+++ b/test/test_link.py
@@ -1,170 +1,137 @@
-import asyncio
-import queue
-import subprocess
+import json
import time
import pytest
-
+import zmq
from improv.actor import Actor
+from zmq import SocketOption
-from improv.link import Link
+from improv.link import ZmqLink
-def init_actors(n=1):
- """Function to return n unique actors.
+@pytest.fixture
+def test_sub_link():
+ """Fixture to provide a commonly used Link object."""
+ ctx = zmq.Context()
+ link_socket = ctx.socket(zmq.SUB)
+ link_socket.bind("tcp://*:0")
+ link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ link_socket_port = int(link_socket_string.split(":")[-1])
- Returns:
- list: A list of n actors, each being named after its index.
- """
+ link_socket.poll(timeout=0)
- # the links must be specified as an empty dictionary to avoid
- # actors sharing a dictionary of links
+ link_pub_socket = ctx.socket(zmq.PUB)
+ link_pub_socket.connect(f"tcp://localhost:{link_socket_port}")
- return [Actor("test " + str(i), "/tmp/store", links={}) for i in range(n)]
+ link_socket.poll(timeout=0) # open libzmq bug (not pyzmq)
+ link_socket.subscribe("test_topic")
+ time.sleep(0.5)
+ link_socket.poll(timeout=0)
+ link = ZmqLink(link_socket, "test_link", "test_topic")
-@pytest.fixture
-def example_link(setup_store):
- """Fixture to provide a commonly used Link object."""
- act = init_actors(2)
- lnk = Link("Example", act[0].name, act[1].name)
- yield lnk
- lnk = None
+ yield link, link_pub_socket
+ link.socket.close(linger=0)
+ link_pub_socket.close(linger=0)
+ ctx.destroy(linger=0)
@pytest.fixture
-def example_actor_system(setup_store):
- """Fixture to provide a list of 4 connected actors."""
-
- # store = setup_store
- acts = init_actors(4)
-
- L01 = Link("L01", acts[0].name, acts[1].name)
- L13 = Link("L13", acts[1].name, acts[3].name)
- L12 = Link("L12", acts[1].name, acts[2].name)
- L23 = Link("L23", acts[2].name, acts[3].name)
-
- links = [L01, L13, L12, L23]
+def test_pub_link():
+ """Fixture to provide a commonly used Link object."""
+ ctx = zmq.Context()
+ link_socket = ctx.socket(zmq.PUB)
+ link_socket.bind("tcp://*:0")
+ link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ link_socket_port = int(link_socket_string.split(":")[-1])
- acts[0].addLink("q_out_1", L01)
- acts[1].addLink("q_out_1", L13)
- acts[1].addLink("q_out_2", L12)
- acts[2].addLink("q_out_1", L23)
+ link_sub_socket = ctx.socket(zmq.SUB)
+ link_sub_socket.connect(f"tcp://localhost:{link_socket_port}")
+ link_sub_socket.poll(timeout=0)
+ link_sub_socket.subscribe("test_topic")
+ time.sleep(0.5)
+ link_sub_socket.poll(timeout=0)
- acts[1].addLink("q_in_1", L01)
- acts[2].addLink("q_in_1", L12)
- acts[3].addLink("q_in_1", L13)
- acts[3].addLink("q_in_2", L23)
+ link = ZmqLink(link_socket, "test_link", "test_topic")
- yield [acts, links] # also yield Links
- acts = None
+ yield link, link_sub_socket
+ link.socket.close(linger=0)
+ link_sub_socket.close(linger=0)
+ ctx.destroy(linger=0)
@pytest.fixture
-def _kill_pytest_processes():
- """Kills all processes with "pytest" in their name.
+def test_req_link():
+ """Fixture to provide a commonly used Link object."""
+ ctx = zmq.Context()
+ link_socket = ctx.socket(zmq.REQ)
+ link_socket.bind("tcp://*:0")
+ link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ link_socket_port = int(link_socket_string.split(":")[-1])
- NOTE:
- This fixture should only be used at the end of testing.
- """
+ link_rep_socket = ctx.socket(zmq.REP)
+ link_rep_socket.connect(f"tcp://localhost:{link_socket_port}")
- subprocess.Popen(
- ["kill", "`pgrep pytest`"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
- )
+ link = ZmqLink(link_socket, "test_link")
-
-@pytest.mark.parametrize(
- ("attribute", "expected"),
- [
- ("name", "Example"),
- ("real_executor", None),
- ("cancelled_join", False),
- ("status", "pending"),
- ("result", None),
- ],
-)
-def test_Link_init(setup_store, example_link, attribute, expected):
- """Tests if the default initialization attributes are set."""
-
- lnk = example_link
- atr = getattr(lnk, attribute)
- assert atr == expected
+ yield link, link_rep_socket
+ link.socket.close(linger=0)
+ link_rep_socket.close(linger=0)
+ ctx.destroy(linger=0)
-def test_Link_init_start_end(setup_store):
- """Tests if the initialization has the right actors."""
+@pytest.fixture
+def test_rep_link():
+ """Fixture to provide a commonly used Link object."""
+ ctx = zmq.Context()
+ link_socket = ctx.socket(zmq.REP)
+ link_socket.bind("tcp://*:0")
+ link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+ link_socket_port = int(link_socket_string.split(":")[-1])
- act = init_actors(2)
- lnk = Link("example_link", act[0].name, act[1].name)
+ link_req_socket = ctx.socket(zmq.REQ)
+ link_req_socket.connect(f"tcp://localhost:{link_socket_port}")
- assert lnk.start == act[0].name
- assert lnk.end == act[1].name
+ link = ZmqLink(link_socket, "test_link")
+ yield link, link_req_socket
+ link.socket.close(linger=0)
+ link_req_socket.close(linger=0)
+ ctx.destroy(linger=0)
-def test_getstate(example_link):
- """Tests if __getstate__ has the right values on initialization.
- Gets the dictionary of the link, then compares them against known
- default values. Does not compare store and actors.
+def test_pub_put(test_pub_link):
+ """Tests if messages can be put into the link.
TODO:
- Compare store and actors.
+ Parametrize multiple test input types.
"""
- res = example_link.__getstate__()
- errors = []
- errors.append(res["_real_executor"] is None)
- errors.append(res["cancelled_join"] is False)
-
- assert all(errors)
-
-
-@pytest.mark.parametrize(
- "input",
- [([None]), ([1]), ([i for i in range(5)]), ([str(i**i) for i in range(10)])],
-)
-def test_qsize_empty(example_link, input):
- """Tests that the queue has the number of elements in "input"."""
-
- lnk = example_link
- for i in input:
- lnk.put(i)
-
- qsize = lnk.queue.qsize()
- assert qsize == len(input)
-
-
-def test_getStart(example_link):
- """Tests if getStart returns the starting actor."""
-
- lnk = example_link
-
- assert lnk.getStart() == Actor("test 0", "/tmp/store").name
-
-
-def test_getEnd(example_link):
- """Tests if getEnd returns the ending actor."""
-
- lnk = example_link
+ link, link_sub_socket = test_pub_link
+ msg = "message"
- assert lnk.getEnd() == Actor("test 1", "/tmp/store").name
+ link.put(msg)
+ res = link_sub_socket.recv_multipart()
+ assert res[0].decode("utf-8") == "test_topic"
+ assert json.loads(res[1].decode("utf-8")) == msg
-def test_put(example_link):
+def test_req_put(test_req_link):
"""Tests if messages can be put into the link.
TODO:
Parametrize multiple test input types.
"""
- lnk = example_link
+ link, link_rep_socket = test_req_link
msg = "message"
- lnk.put(msg)
- assert lnk.get() == "message"
+ link.put(msg)
+ res = link_rep_socket.recv_pyobj()
+ assert res == msg
-def test_put_unserializable(example_link, caplog, setup_store):
+def test_put_unserializable(test_pub_link):
"""Tests if an unserializable object raises an error.
Instantiates an actor, which is unserializable, and passes it into
@@ -173,185 +140,199 @@ def test_put_unserializable(example_link, caplog, setup_store):
Raises:
SerializationCallbackError: Actor objects are unserializable.
"""
- # store = setup_store
act = Actor("test", "/tmp/store")
- lnk = example_link
+ link, link_sub_socket = test_pub_link
sentinel = True
try:
- lnk.put(act)
+ link.put(act)
except Exception:
sentinel = False
- assert sentinel, "Unable to put"
- assert str(lnk.get()) == str(act)
+ assert not sentinel
-def test_put_irreducible(example_link, setup_store):
+def test_put_irreducible(test_pub_link, setup_store):
"""Tests if an irreducible object raises an error."""
- lnk = example_link
+ link, link_sub_socket = test_pub_link
store = setup_store
with pytest.raises(TypeError):
- lnk.put(store)
+ link.put(store)
-def test_put_nowait(example_link):
+def test_put_nowait(test_pub_link):
"""Tests if messages can be put into the link without blocking.
TODO:
Parametrize multiple test input types.
"""
- lnk = example_link
+ link, link_sub_socket = test_pub_link
msg = "message"
t_0 = time.perf_counter()
- lnk.put_nowait(msg)
+ link.put(msg) # put is already async even in synchronous zmq
t_1 = time.perf_counter()
t_net = t_1 - t_0
assert t_net < 0.005 # 5 ms
-@pytest.mark.asyncio
-async def test_put_async_success(example_link):
- """Tests if put_async returns None.
-
- TODO:
- Parametrize test input.
- """
-
- lnk = example_link
- msg = "message"
- res = await lnk.put_async(msg)
- assert res is None
-
-
-@pytest.mark.asyncio
-async def test_put_async_multiple(example_link):
+def test_put_multiple(test_pub_link):
"""Tests if async putting multiple objects preserves their order."""
+ link, link_sub_socket = test_pub_link
+
messages = [str(i) for i in range(10)]
messages_out = []
for msg in messages:
- await example_link.put_async(msg)
+ link.put(msg)
for msg in messages:
- messages_out.append(example_link.get())
+ messages_out.append(
+ json.loads(link_sub_socket.recv_multipart()[1].decode("utf-8"))
+ )
assert messages_out == messages
@pytest.mark.asyncio
-async def test_put_and_get_async(example_link):
+async def test_put_and_get_async(test_pub_link):
"""Tests if async get preserves order after async put."""
messages = [str(i) for i in range(10)]
messages_out = []
+ link, link_sub_socket = test_pub_link
+
for msg in messages:
- await example_link.put_async(msg)
+ await link.put_async(msg)
for msg in messages:
- messages_out.append(await example_link.get_async())
+ messages_out.append(
+ json.loads(link_sub_socket.recv_multipart()[1].decode("utf-8"))
+ )
assert messages_out == messages
-@pytest.mark.skip(
- reason="This test needs additional work to cause an overflow in the datastore."
+@pytest.mark.parametrize(
+ "message",
+ [
+ "message",
+ "",
+ None,
+ [str(i) for i in range(5)],
+ ],
)
-def test_put_overflow(setup_store, server_port_num, caplog):
- """Tests if putting too large of an object raises an error."""
-
- p = subprocess.Popen(
- ["redis-server", "--port", str(server_port_num), "--maxmemory", str(1000)],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
-
- acts = init_actors(2)
- lnk = Link("L1", acts[0], acts[1])
-
- message = [i for i in range(10**6)] # 24000 bytes
+@pytest.mark.parametrize(
+ "timeout",
+ [
+ None,
+ 5,
+ ],
+)
+def test_sub_get(test_sub_link, message, timeout):
+ """Tests if get gets the correct element from the queue."""
- lnk.put(message)
+ link, link_pub_socket = test_sub_link
- p.kill()
- p.wait()
+ time.sleep(1)
- if caplog.records:
- for record in caplog.records:
- if "PlasmaStoreInterfaceFull" in record.msg:
- assert True
+ if type(message) is list:
+ for i in message:
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(i).encode("utf-8")]
+ )
+ expected = message[0]
else:
- pytest.fail("expected an error!")
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(message).encode("utf-8")]
+ )
+ expected = message
+
+ assert link.get(timeout=timeout) == expected
@pytest.mark.parametrize(
"message",
[
- ("message"),
- (""),
- (None),
- ([str(i) for i in range(5)]),
+ "message",
+ "",
+ None,
+ [str(i) for i in range(5)],
+ ],
+)
+@pytest.mark.parametrize(
+ "timeout",
+ [
+ None,
+ 5,
],
)
-def test_get(example_link, message):
+def test_rep_get(test_rep_link, message, timeout):
"""Tests if get gets the correct element from the queue."""
- lnk = example_link
+ link, link_req_socket = test_rep_link
- if type(message) is list:
- for i in message:
- lnk.put(i)
- expected = message[0]
- else:
- lnk.put(message)
- expected = message
+ link_req_socket.send_pyobj(message)
+
+ expected = message
- assert lnk.get() == expected
+ assert link.get(timeout=timeout) == expected
-def test_get_empty(example_link):
+def test_get_empty(test_sub_link):
"""Tests if get blocks if the queue is empty."""
- lnk = example_link
- if lnk.queue.empty:
- with pytest.raises(queue.Empty):
- lnk.get(timeout=5.0)
- else:
- pytest.fail("expected a timeout!")
+ link, unused = test_sub_link
+
+ time.sleep(0.1)
+
+ with pytest.raises(TimeoutError):
+ link.get(timeout=0.5)
@pytest.mark.parametrize(
"message",
[
- ("message"),
- (""),
+ "message",
+ "",
([str(i) for i in range(5)]),
],
)
-def test_get_nowait(example_link, message):
+def test_get_nowait(test_sub_link, message):
"""Tests if get_nowait gets the correct element from the queue."""
- lnk = example_link
+ link, link_pub_socket = test_sub_link
if type(message) is list:
for i in message:
- lnk.put(i)
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(i).encode("utf-8")]
+ )
expected = message[0]
else:
- lnk.put(message)
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(message).encode("utf-8")]
+ )
expected = message
+ for i in range(20):
+ time.sleep(0.1)
+ avail = link.socket.poll(timeout=100)
+ if avail:
+ break
+ else:
+ pytest.fail("Message was not sent to link after 4s")
+
t_0 = time.perf_counter()
- res = lnk.get_nowait()
+ res = link.get_nowait()
t_1 = time.perf_counter()
@@ -359,95 +340,31 @@ def test_get_nowait(example_link, message):
assert t_1 - t_0 < 0.005 # 5 msg
-def test_get_nowait_empty(example_link):
+def test_get_nowait_empty(test_sub_link):
"""Tests if get_nowait raises an error when the queue is empty."""
- lnk = example_link
- if lnk.queue.empty():
- with pytest.raises(queue.Empty):
- lnk.get_nowait()
- else:
- pytest.fail("the queue is not empty")
+ link, unused = test_sub_link
+ with pytest.raises(TimeoutError):
+ link.get_nowait()
@pytest.mark.asyncio
-async def test_get_async_success(example_link):
+async def test_get_async_success(test_sub_link):
"""Tests if async_get gets the correct element from the queue."""
- lnk = example_link
+ link, link_pub_socket = test_sub_link
msg = "message"
- await lnk.put_async(msg)
- res = await lnk.get_async()
+ link_pub_socket.send_multipart(
+ [link.topic.encode("utf-8"), json.dumps(msg).encode("utf-8")]
+ )
+ res = await link.get_async()
assert res == "message"
-@pytest.mark.asyncio
-async def test_get_async_empty(example_link):
- """Tests if get_async times out given an empty queue.
-
- TODO:
- Implement a way to kill the task after execution (subprocess)?
- """
-
- lnk = example_link
- timeout = 5.0
-
- with pytest.raises(asyncio.TimeoutError):
- task = asyncio.create_task(lnk.get_async())
- await asyncio.wait_for(task, timeout)
- task.cancel()
-
- lnk.put("exit") # this is here to break out of get_async()
-
-
-@pytest.mark.skip(reason="unfinished")
-def test_cancel_join_thread(example_link):
- """Tests cancel_join_thread. This test is unfinished
-
- TODO:
- Identify where and when cancel_join_thread is being called.
- """
-
- lnk = example_link
- lnk.cancel_join_thread()
-
- assert lnk._cancelled_join is True
-
-
-@pytest.mark.skip(reason="unfinished")
-@pytest.mark.asyncio
-async def test_join_thread(example_link):
- """Tests join_thread. This test is unfinished
-
- TODO:
- Identify where and when join_thread is being called.
- """
- lnk = example_link
- await lnk.put_async("message")
- # msg = await lnk.get_async()
- lnk.join_thread()
- assert True
-
-
-@pytest.mark.asyncio
-async def test_multi_actor_system(example_actor_system, setup_store):
- """Tests if async puts/gets with many actors have good messages."""
-
- setup_store
-
- graph = example_actor_system
-
- acts = graph[0]
-
- heavy_msg = [str(i) for i in range(10**6)]
- light_msgs = ["message" + str(i) for i in range(3)]
-
- await acts[0].links["q_out_1"].put_async(heavy_msg)
- await acts[1].links["q_out_1"].put_async(light_msgs[0])
- await acts[1].links["q_out_2"].put_async(light_msgs[1])
- await acts[2].links["q_out_1"].put_async(light_msgs[2])
-
- assert await acts[1].links["q_in_1"].get_async() == heavy_msg
- assert await acts[2].links["q_in_1"].get_async() == light_msgs[1]
- assert await acts[3].links["q_in_1"].get_async() == light_msgs[0]
- assert await acts[3].links["q_in_2"].get_async() == light_msgs[2]
+def test_pub_put_no_topic():
+ ctx = zmq.Context()
+ s = ctx.socket(zmq.PUB)
+ with pytest.raises(Exception, match="Cannot open PUB link without topic"):
+ ZmqLink(s, "test")
+ s.close(linger=0)
+ ctx.destroy(linger=0)
diff --git a/test/test_messaging.py b/test/test_messaging.py
new file mode 100644
index 00000000..a8efff3f
--- /dev/null
+++ b/test/test_messaging.py
@@ -0,0 +1,107 @@
+import improv.messaging
+
+
+def test_actor_state_msg():
+ name = "test name"
+ status = "test status"
+ port = 12345
+ info = "test info"
+ msg = improv.messaging.ActorStateMsg(name, status, port, info)
+ assert msg.actor_name == name
+ assert msg.status == status
+ assert msg.nexus_in_port == port
+ assert msg.info == info
+
+
+def test_actor_state_reply_msg():
+ name = "test name"
+ status = "test status"
+ info = "test info"
+ msg = improv.messaging.ActorStateReplyMsg(name, status, info)
+ assert msg.actor_name == name
+ assert msg.status == status
+ assert msg.info == info
+
+
+def test_actor_signal_msg():
+ name = "test name"
+ signal = "test signal"
+ info = "test info"
+ msg = improv.messaging.ActorSignalMsg(name, signal, info)
+ assert msg.actor_name == name
+ assert msg.signal == signal
+ assert msg.info == info
+
+
+def test_actor_signal_reply_msg():
+ name = "test name"
+ status = "test status"
+ signal = "test_signal"
+ info = "test info"
+ msg = improv.messaging.ActorSignalReplyMsg(name, signal, status, info)
+ assert msg.actor_name == name
+ assert msg.status == status
+ assert msg.info == info
+ assert msg.signal == signal
+
+
+def test_broker_info_msg():
+ name = "test name"
+ pub_port = 54321
+ sub_port = 12345
+ info = "test info"
+ msg = improv.messaging.BrokerInfoMsg(name, pub_port, sub_port, info)
+ assert msg.name == name
+ assert msg.pub_port == pub_port
+ assert msg.sub_port == sub_port
+ assert msg.info == info
+
+
+def test_broker_info_reply_msg():
+ name = "test name"
+ status = "test status"
+ info = "test info"
+ msg = improv.messaging.BrokerInfoReplyMsg(name, status, info)
+ assert msg.name == name
+ assert msg.status == status
+ assert msg.info == info
+
+
+def test_log_info_msg():
+ name = "test name"
+ pull_port = 54321
+ pub_port = 12345
+ info = "test info"
+ msg = improv.messaging.LogInfoMsg(name, pull_port, pub_port, info)
+ assert msg.name == name
+ assert msg.pub_port == pub_port
+ assert msg.pull_port == pull_port
+ assert msg.info == info
+
+
+def test_log_info_reply_msg():
+ name = "test name"
+ status = "test status"
+ info = "test info"
+ msg = improv.messaging.LogInfoReplyMsg(name, status, info)
+ assert msg.name == name
+ assert msg.status == status
+ assert msg.info == info
+
+
+def test_harvester_info_msg():
+ name = "test name"
+ info = "test info"
+ msg = improv.messaging.HarvesterInfoMsg(name, info)
+ assert msg.name == name
+ assert msg.info == info
+
+
+def test_harvester_info_reply_msg():
+ name = "test name"
+ status = "test status"
+ info = "test info"
+ msg = improv.messaging.HarvesterInfoReplyMsg(name, status, info)
+ assert msg.name == name
+ assert msg.status == status
+ assert msg.info == info
diff --git a/test/test_nexus.py b/test/test_nexus.py
index e1a0e5f6..f4479f33 100644
--- a/test/test_nexus.py
+++ b/test/test_nexus.py
@@ -2,77 +2,21 @@
import shutil
import time
import os
+
import pytest
import logging
-import subprocess
-import signal
-import yaml
-
-from improv.nexus import Nexus
-from improv.store import StoreInterface
-
-# from improv.actor import Actor
-# from improv.store import StoreInterface
-
-SERVER_COUNTER = 0
-
-
-@pytest.fixture
-def ports():
- global SERVER_COUNTER
- CONTROL_PORT = 5555
- OUTPUT_PORT = 5556
- LOGGING_PORT = 5557
- yield (
- CONTROL_PORT + SERVER_COUNTER,
- OUTPUT_PORT + SERVER_COUNTER,
- LOGGING_PORT + SERVER_COUNTER,
- )
- SERVER_COUNTER += 3
-
-
-@pytest.fixture
-def setdir():
- prev = os.getcwd()
- os.chdir(os.path.dirname(__file__) + "/configs")
- yield None
- os.chdir(prev)
-
-
-@pytest.fixture
-def sample_nex(setdir, ports):
- nex = Nexus("test")
- nex.createNexus(
- file="good_config.yaml",
- store_size=40000000,
- control_port=ports[0],
- output_port=ports[1],
- )
- yield nex
- nex.destroyNexus()
-
-# @pytest.fixture
-# def setup_store(setdir):
-# """ Fixture to set up the store subprocess with 10 mb.
+import zmq
-# This fixture runs a subprocess that instantiates the store with a
-# memory of 10 megabytes. It specifies that "/tmp/store/" is the
-# location of the store socket.
-
-# Yields:
-# StoreInterface: An instance of the store.
-
-# TODO:
-# Figure out the scope.
-# """
-# setdir
-# p = subprocess.Popen(
-# ['plasma_store', '-s', '/tmp/store/', '-m', str(10000000)],\
-# stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-# store = StoreInterface(store_loc = "/tmp/store/")
-# yield store
-# p.kill()
+from improv.config import CannotCreateConfigException
+from improv.messaging import ActorStateMsg
+from improv.nexus import (
+ Nexus,
+ ConfigFileNotProvidedException,
+ ConfigFileNotValidException,
+)
+from improv.store import StoreInterface
+from conftest import SignalManager
def test_init(setdir):
@@ -85,32 +29,24 @@ def test_init(setdir):
"cfg_name",
[
"good_config.yaml",
- "good_config_plasma.yaml",
],
)
-def test_createNexus(setdir, ports, cfg_name):
+def test_create_nexus(setdir, ports, cfg_name):
nex = Nexus("test")
- nex.createNexus(file=cfg_name, control_port=ports[0], output_port=ports[1])
- assert list(nex.comm_queues.keys()) == [
- "GUI_comm",
- "Acquirer_comm",
- "Analysis_comm",
- ]
- assert list(nex.sig_queues.keys()) == ["Acquirer_sig", "Analysis_sig"]
- assert list(nex.data_queues.keys()) == ["Acquirer.q_out", "Analysis.q_in"]
+ nex.create_nexus(file=cfg_name, control_port=ports[0], output_port=ports[1])
assert list(nex.actors.keys()) == ["Acquirer", "Analysis"]
assert list(nex.flags.keys()) == ["quit", "run", "load"]
assert nex.processes == []
- nex.destroyNexus()
+ nex.destroy_nexus()
assert True
def test_config_logged(setdir, ports, caplog):
nex = Nexus("test")
- nex.createNexus(
+ nex.create_nexus(
file="minimal_with_settings.yaml", control_port=ports[0], output_port=ports[1]
)
- nex.destroyNexus()
+ nex.destroy_nexus()
assert any(
[
"not_relevant: for testing purposes" in record.msg
@@ -119,54 +55,49 @@ def test_config_logged(setdir, ports, caplog):
)
-def test_loadConfig(sample_nex):
+def test_load_config(sample_nex):
nex = sample_nex
- nex.loadConfig("good_config.yaml")
- assert set(nex.comm_queues.keys()) == set(
- ["Acquirer_comm", "Analysis_comm", "GUI_comm"]
+ assert any(
+ [
+ link_info.link_name == "q_out"
+ for link_info in nex.outgoing_topics["Acquirer"]
+ ]
+ )
+ assert any(
+ [link_info.link_name == "q_in" for link_info in nex.incoming_topics["Analysis"]]
)
def test_argument_config_precedence(setdir, ports):
nex = Nexus("test")
- nex.createNexus(
+ nex.create_nexus(
file="minimal_with_settings.yaml",
control_port=ports[0],
output_port=ports[1],
store_size=11_000_000,
- use_watcher=True,
)
cfg = nex.config.settings
- nex.destroyNexus()
+ nex.destroy_nexus()
assert cfg["control_port"] == ports[0]
assert cfg["output_port"] == ports[1]
- assert cfg["store_size"] == 20_000_000
- assert not cfg["use_watcher"]
+ assert cfg["store_size"] == 11_000_000
-def test_settings_override_random_ports(setdir, ports):
- config_file = "minimal_with_settings.yaml"
- nex = Nexus("test")
- with open(config_file, "r") as ymlfile:
- cfg = yaml.safe_load(ymlfile)["settings"]
- control_port, output_port = nex.createNexus(
- file=config_file, control_port=0, output_port=0
- )
- nex.destroyNexus()
- assert control_port == cfg["control_port"]
- assert output_port == cfg["output_port"]
+# delete this comment later
+def test_start_nexus(sample_nex):
+ with SignalManager():
+ async def set_quit_flag(test_nex):
+ test_nex.flags["quit"] = True
-# delete this comment later
-@pytest.mark.skip(reason="unfinished")
-def test_startNexus(sample_nex):
- nex = sample_nex
- nex.startNexus()
- assert [p.name for p in nex.processes] == ["Acquirer", "Analysis"]
- nex.destroyNexus()
+ nex = sample_nex
+ nex.start_nexus(nex.poll_queues, poll_function=set_quit_flag, test_nex=nex)
+ assert [p.name for p in nex.processes] == ["Acquirer", "Analysis"]
-# @pytest.mark.skip(reason="This test is unfinished")
+@pytest.mark.skip(
+ reason="This test is unfinished - it does not validate link structure"
+)
@pytest.mark.parametrize(
("cfg_name", "actor_list", "link_list"),
[
@@ -197,20 +128,17 @@ def test_config_construction(cfg_name, actor_list, link_list, setdir, ports):
"""
nex = Nexus("test")
- nex.createNexus(file=cfg_name, control_port=ports[0], output_port=ports[1])
+ nex.create_nexus(file=cfg_name, control_port=ports[0], output_port=ports[1])
logging.info(cfg_name)
# Check for actors
act_lst = list(nex.actors)
- lnk_lst = list(nex.sig_queues)
- nex.destroyNexus()
+ nex.destroy_nexus()
assert actor_list == act_lst
- assert link_list == lnk_lst
act_lst = []
- lnk_lst = []
assert True
@@ -218,165 +146,247 @@ def test_config_construction(cfg_name, actor_list, link_list, setdir, ports):
"cfg_name",
[
"single_actor.yaml",
- "single_actor_plasma.yaml",
],
)
def test_single_actor(setdir, ports, cfg_name):
nex = Nexus("test")
with pytest.raises(AttributeError):
- nex.createNexus(
+ nex.create_nexus(
file="single_actor.yaml", control_port=ports[0], output_port=ports[1]
)
- nex.destroyNexus()
+ nex.destroy_nexus()
def test_cyclic_graph(setdir, ports):
nex = Nexus("test")
- nex.createNexus(
+ nex.create_nexus(
file="cyclic_config.yaml", control_port=ports[0], output_port=ports[1]
)
assert True
- nex.destroyNexus()
+ nex.destroy_nexus()
def test_blank_cfg(setdir, caplog, ports):
nex = Nexus("test")
- with pytest.raises(TypeError):
- nex.createNexus(
+ with pytest.raises(CannotCreateConfigException):
+ nex.create_nexus(
file="blank_file.yaml", control_port=ports[0], output_port=ports[1]
)
assert any(
["The config file is empty" in record.msg for record in list(caplog.records)]
)
- nex.destroyNexus()
+ nex.destroy_nexus()
-# def test_hasGUI_True(setdir):
-# setdir
-# nex = Nexus("test")
-# nex.createNexus(file="basic_demo_with_GUI.yaml")
+def test_start_store(caplog):
+ nex = Nexus("test")
+ nex._start_store_interface(100_000_000) # 100 MB store
-# assert True
-# nex.destroyNexus()
+ assert any(
+ "StoreInterface start successful" in record.msg for record in caplog.records
+ )
-# @pytest.mark.skip(reason="This test is unfinished.")
-# def test_hasGUI_False():
-# assert True
+ nex._close_store_interface()
+ nex.destroy_nexus()
+ assert True
-@pytest.mark.skip(reason="unfinished")
-def test_queue_message(setdir, sample_nex):
- nex = sample_nex
- nex.startNexus()
- time.sleep(20)
- nex.setup()
- time.sleep(20)
- nex.run()
- time.sleep(10)
- acq_comm = nex.comm_queues["Acquirer_comm"]
- acq_comm.put("Test Message")
-
- assert nex.comm_queues is None
- nex.destroyNexus()
- assert True
+def test_close_store(caplog):
+ nex = Nexus("test")
+ nex._start_store_interface(10000)
+ nex._close_store_interface()
+ assert any(
+ "StoreInterface close successful" in record.msg for record in caplog.records
+ )
-@pytest.mark.asyncio
-@pytest.mark.skip(reason="This test is unfinished.")
-async def test_queue_readin(sample_nex, caplog):
- nex = sample_nex
- nex.startNexus()
- # cqs = nex.comm_queues
- # assert cqs == None
- assert [record.msg for record in caplog.records] is None
- # cqs["Acquirer_comm"].put('quit')
- # assert "quit" == cqs["Acquirer_comm"].get()
- # await nex.pollQueues()
- assert True
+ # write to store
+ with pytest.raises(AttributeError):
+ nex.p_StoreInterface.put("Message in", "Message in Label")
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_queue_sendout():
+ nex.destroy_nexus()
assert True
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_run_sig():
- assert True
+def test_start_harvester(caplog, setdir, ports):
+ nex = Nexus("test")
+ try:
+ nex.create_nexus(
+ file="minimal_harvester.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
+ time.sleep(3)
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_setup_sig():
- assert True
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
+ assert any("Harvester server started" in record.msg for record in caplog.records)
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_quit_sig():
- assert True
+def test_process_actor_state_update(caplog, setdir, ports):
+ nex = Nexus("test")
+ try:
+ nex.create_nexus(
+ file="minimal_harvester.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_usehdd_True():
- assert True
+ time.sleep(3)
+ new_actor_message = ActorStateMsg("test actor", "waiting", 1234, "test info")
-@pytest.mark.skip(reason="This test is unfinished.")
-def test_usehdd_False():
- assert True
+ nex.process_actor_state_update(new_actor_message)
+ assert "test actor" in nex.actor_states
+ assert nex.actor_states["test actor"].actor_name == new_actor_message.actor_name
+ assert (
+ nex.actor_states["test actor"].nexus_in_port
+ == new_actor_message.nexus_in_port
+ )
+ assert nex.actor_states["test actor"].status == new_actor_message.status
+ update_actor_message = ActorStateMsg("test actor", "waiting", 1234, "test info")
-def test_startstore(caplog):
- nex = Nexus("test")
- nex._startStoreInterface(10000000) # 10 MB store
+ nex.process_actor_state_update(update_actor_message)
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(
- "StoreInterface start successful" in record.msg for record in caplog.records
+ "Received state message from new actor test actor" in record.msg
+ for record in caplog.records
+ )
+ assert any(
+ "Received state message from actor test actor" in record.msg
+ for record in caplog.records
)
- nex._closeStoreInterface()
- nex.destroyNexus()
- assert True
+
+def test_process_actor_state_update_allows_run(caplog, setdir, ports):
+ nex = Nexus("test")
+ try:
+ nex.create_nexus(
+ file="minimal_harvester.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
+
+ time.sleep(3)
+
+ nex.actor_states["test actor1"] = None
+ nex.actor_states["test actor2"] = None
+
+ actor1_message = ActorStateMsg("test actor1", "ready", 1234, "test info")
+
+ nex.process_actor_state_update(actor1_message)
+ assert "test actor1" in nex.actor_states
+ assert nex.actor_states["test actor1"].actor_name == actor1_message.actor_name
+ assert (
+ nex.actor_states["test actor1"].nexus_in_port
+ == actor1_message.nexus_in_port
+ )
+ assert nex.actor_states["test actor1"].status == actor1_message.status
+
+ assert not nex.allowStart
+
+ actor2_message = ActorStateMsg("test actor2", "ready", 5678, "test info2")
+
+ nex.process_actor_state_update(actor2_message)
+ assert "test actor2" in nex.actor_states
+ assert nex.actor_states["test actor2"].actor_name == actor2_message.actor_name
+ assert (
+ nex.actor_states["test actor2"].nexus_in_port
+ == actor2_message.nexus_in_port
+ )
+ assert nex.actor_states["test actor2"].status == actor2_message.status
+
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
+ assert nex.allowStart
-def test_closestore(caplog):
+@pytest.mark.asyncio
+async def test_process_actor_message(caplog, setdir, ports):
nex = Nexus("test")
- nex._startStoreInterface(10000)
- nex._closeStoreInterface()
+ try:
+ nex.create_nexus(
+ file="minimal_harvester.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
- assert any(
- "StoreInterface close successful" in record.msg for record in caplog.records
- )
+ time.sleep(3)
- # write to store
+ nex.actor_states["test actor1"] = None
+ nex.actor_states["test actor2"] = None
- with pytest.raises(AttributeError):
- nex.p_StoreInterface.put("Message in", "Message in Label")
+ actor1_message = ActorStateMsg("test actor1", "ready", 1234, "test info")
- nex.destroyNexus()
- assert True
+ ctx = nex.zmq_context
+ s = ctx.socket(zmq.REQ)
+ s.connect(f"tcp://localhost:{nex.actor_in_socket_port}")
+
+ s.send_pyobj(actor1_message)
+
+ await nex.process_actor_message()
+
+ nex.process_actor_state_update(actor1_message)
+ assert "test actor1" in nex.actor_states
+ assert nex.actor_states["test actor1"].actor_name == actor1_message.actor_name
+ assert (
+ nex.actor_states["test actor1"].nexus_in_port
+ == actor1_message.nexus_in_port
+ )
+ assert nex.actor_states["test actor1"].status == actor1_message.status
+
+ s.close(linger=0)
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
+
+ assert not nex.allowStart
def test_specified_free_port(caplog, setdir, ports):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_fixed_redis_port.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_fixed_redis_port.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
- store = StoreInterface(server_port_num=6378)
- store.connect_to_server()
- key = store.put("port 6378")
- assert store.get(key) == "port 6378"
+ store = StoreInterface(server_port_num=6378)
+ store.connect_to_server()
+ key = store.put("port 6378")
+ assert store.get(key) == "port 6378"
- assert any(
- "Successfully connected to redis datastore on port 6378" in record.msg
- for record in caplog.records
- )
+ assert any(
+ "Successfully connected to redis datastore on port 6378" in record.msg
+ for record in caplog.records
+ )
+
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(
"StoreInterface start successful on port 6378" in record.msg
@@ -386,32 +396,47 @@ def test_specified_free_port(caplog, setdir, ports):
def test_specified_busy_port(caplog, setdir, ports, setup_store):
nex = Nexus("test")
- with pytest.raises(Exception, match="Could not start Redis on specified port."):
- nex.createNexus(
+ try:
+ nex.create_nexus(
file="minimal_with_fixed_default_redis_port.yaml",
- store_size=10000000,
+ store_size=100_000_000,
control_port=ports[0],
output_port=ports[1],
)
- nex.destroyNexus()
+ time.sleep(3)
+
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(
- "Could not start Redis on specified port number." in record.msg
+ "Could not connect to port 6379" in record.msg for record in caplog.records
+ )
+
+ assert any(
+ "StoreInterface start successful on port 6380" in record.msg
for record in caplog.records
)
def test_unspecified_port_default_free(caplog, setdir, ports):
nex = Nexus("test")
- nex.createNexus(
- file="minimal.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
- nex.destroyNexus()
+ time.sleep(3)
+
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(
"StoreInterface start successful on port 6379" in record.msg
@@ -421,14 +446,20 @@ def test_unspecified_port_default_free(caplog, setdir, ports):
def test_unspecified_port_default_busy(caplog, setdir, ports, setup_store):
nex = Nexus("test")
- nex.createNexus(
- file="minimal.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
+
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(
"StoreInterface start successful on port 6380" in record.msg
for record in caplog.records
@@ -436,15 +467,27 @@ def test_unspecified_port_default_busy(caplog, setdir, ports, setup_store):
def test_no_aof_dir_by_default(caplog, setdir, ports):
- nex = Nexus("test")
- nex.createNexus(
- file="minimal.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ if "appendonlydir" in os.listdir("."):
+ shutil.rmtree("appendonlydir")
+ else:
+ logging.info("didn't find dbfilename")
+
+ nex = Nexus("test")
+
+ nex.create_nexus(
+ file="minimal.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
- nex.destroyNexus()
+ time.sleep(3)
+
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" not in os.listdir(".")
assert all(["improv_persistence_" not in name for name in os.listdir(".")])
@@ -452,19 +495,23 @@ def test_no_aof_dir_by_default(caplog, setdir, ports):
def test_default_aof_dir_if_none_specified(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_redis_saving.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_redis_saving.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
- store = StoreInterface(server_port_num=server_port_num)
- store.put(1)
+ store = StoreInterface(server_port_num=server_port_num)
+ store.put(1)
- time.sleep(3)
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" in os.listdir(".")
@@ -478,19 +525,23 @@ def test_default_aof_dir_if_none_specified(caplog, setdir, ports, server_port_nu
def test_specify_static_aof_dir(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_custom_aof_dirname.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_custom_aof_dirname.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
- store = StoreInterface(server_port_num=server_port_num)
- store.put(1)
+ store = StoreInterface(server_port_num=server_port_num)
+ store.put(1)
- time.sleep(3)
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "custom_aof_dirname" in os.listdir(".")
@@ -504,19 +555,23 @@ def test_specify_static_aof_dir(caplog, setdir, ports, server_port_num):
def test_use_ephemeral_aof_dir(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_ephemeral_aof_dirname.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_ephemeral_aof_dirname.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
- store = StoreInterface(server_port_num=server_port_num)
- store.put(1)
+ store = StoreInterface(server_port_num=server_port_num)
+ store.put(1)
- time.sleep(3)
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert any(["improv_persistence_" in name for name in os.listdir(".")])
@@ -527,18 +582,24 @@ def test_use_ephemeral_aof_dir(caplog, setdir, ports, server_port_num):
def test_save_no_schedule(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_no_schedule_saving.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_no_schedule_saving.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
- store = StoreInterface(server_port_num=server_port_num)
+ store = StoreInterface(server_port_num=server_port_num)
- fsync_schedule = store.client.config_get("appendfsync")
+ fsync_schedule = store.client.config_get("appendfsync")
- nex.destroyNexus()
+ time.sleep(3)
+
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" in os.listdir(".")
shutil.rmtree("appendonlydir")
@@ -548,18 +609,24 @@ def test_save_no_schedule(caplog, setdir, ports, server_port_num):
def test_save_every_second(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_every_second_saving.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_every_second_saving.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
+
+ store = StoreInterface(server_port_num=server_port_num)
- store = StoreInterface(server_port_num=server_port_num)
+ fsync_schedule = store.client.config_get("appendfsync")
- fsync_schedule = store.client.config_get("appendfsync")
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" in os.listdir(".")
shutil.rmtree("appendonlydir")
@@ -569,18 +636,24 @@ def test_save_every_second(caplog, setdir, ports, server_port_num):
def test_save_every_write(caplog, setdir, ports, server_port_num):
nex = Nexus("test")
- nex.createNexus(
- file="minimal_with_every_write_saving.yaml",
- store_size=10000000,
- control_port=ports[0],
- output_port=ports[1],
- )
+ try:
+ nex.create_nexus(
+ file="minimal_with_every_write_saving.yaml",
+ store_size=100_000_000,
+ control_port=ports[0],
+ output_port=ports[1],
+ )
+
+ store = StoreInterface(server_port_num=server_port_num)
- store = StoreInterface(server_port_num=server_port_num)
+ fsync_schedule = store.client.config_get("appendfsync")
- fsync_schedule = store.client.config_get("appendfsync")
+ time.sleep(3)
- nex.destroyNexus()
+ nex.destroy_nexus()
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
assert "appendonlydir" in os.listdir(".")
shutil.rmtree("appendonlydir")
@@ -588,65 +661,121 @@ def test_save_every_write(caplog, setdir, ports, server_port_num):
assert fsync_schedule["appendfsync"] == "always"
-@pytest.mark.skip(reason="Nexus no longer deletes files on shutdown. Nothing to test.")
-def test_store_already_deleted_issues_warning(caplog):
- nex = Nexus("test")
- nex._startStoreInterface(10000)
- store_location = nex.store_loc
- StoreInterface(store_loc=nex.store_loc)
- os.remove(nex.store_loc)
- nex.destroyNexus()
- assert any(
- "StoreInterface file {} is already deleted".format(store_location) in record.msg
- for record in caplog.records
- )
+# def test_sigint_exits_cleanly(ports, set_dir_config_parent):
+# server_opts = [
+# "improv",
+# "server",
+# "-c",
+# str(ports[0]),
+# "-o",
+# str(ports[1]),
+# "-f",
+# "global.log",
+# "configs/minimal.yaml",
+# ]
+#
+# env = os.environ.copy()
+# env["PYTHONPATH"] += ":" + os.getcwd()
+#
+# server = subprocess.Popen(
+# server_opts, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env
+# )
+#
+# time.sleep(5)
+#
+# server.send_signal(signal.SIGINT)
+#
+# server.wait(10)
+# assert True
-@pytest.mark.skip(reason="unfinished")
-def test_actor_sub(setdir, capsys, monkeypatch, ports):
- monkeypatch.setattr("improv.nexus.input", lambda: "setup\n")
- cfg_file = "sample_config.yaml"
+# def test_nexus_actor_in_port(ports, setdir, start_nexus_minimal_zmq):
+# context = zmq.Context()
+# nex_socket = context.socket(zmq.REQ)
+# nex_socket.connect(f"tcp://localhost:{ports[3]}") # actor in port
+#
+# test_socket = context.socket(zmq.REP)
+# test_socket.bind("tcp://*:0")
+# in_port_string = test_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+# test_socket_port = int(in_port_string.split(":")[-1])
+# logging.info(f"Using port {test_socket_port}")
+#
+# logging.info("waiting to send")
+# actor_state = ActorStateMsg(
+# "test_actor", "test_status", test_socket_port, "test info string"
+# )
+# nex_socket.send_pyobj(actor_state)
+# logging.info("Sent")
+# out = nex_socket.recv_pyobj()
+# assert isinstance(out, ActorStateReplyMsg)
+# assert out.actor_name == actor_state.actor_name
+# assert out.status == "OK"
+
+
+def test_nexus_create_nexus_no_cfg_file(ports):
nex = Nexus("test")
+ with pytest.raises(ConfigFileNotProvidedException):
+ nex.create_nexus()
- nex.createNexus(
- file=cfg_file, store_size=4000, control_port=ports[0], output_port=ports[1]
- )
- print("Nexus Created")
-
- nex.startNexus()
- print("Nexus Started")
- # time.sleep(5)
- # print("Printing...")
- # subprocess.Popen(["setup"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
- # time.sleep(2)
- # subprocess.Popen(["run"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
- # time.sleep(5)
- # subprocess.Popen(["quit"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-
- nex.destroyNexus()
- assert True
+#
+# @pytest.mark.skip(reason="Blocking comms so this won't work as-is")
+# def test_nexus_actor_comm_setup(ports, setdir):
+# filename = "minimal_zmq.yaml"
+# nex = Nexus("test")
+# nex.create_nexus(
+# file=filename,
+# store_size=10000000,
+# control_port=ports[0],
+# output_port=ports[1],
+# actor_in_port=ports[2],
+# )
+#
+# actor = nex.actors["Generator"]
+# actor.register_with_nexus()
+#
+# nex.process_actor_message()
+#
+#
+# @pytest.mark.skip(reason="Test isn't meant to be used for coverage")
+# def test_debug_nex(ports, setdir):
+# filename = "minimal_zmq.yaml"
+# conftest.nex_startup(ports, filename)
+#
+#
+# @pytest.mark.skip(reason="Test isn't meant to be used for coverage")
+# def test_nex_cfg(ports, setdir):
+# filename = "minimal_zmq.yaml"
+# nex = Nexus("test")
+# nex.create_nexus(
+# file=filename,
+# store_size=100000000,
+# control_port=ports[0],
+# output_port=ports[1],
+# actor_in_port=ports[2],
+# )
+# nex.start_nexus()
-@pytest.mark.skip(
- reason="skipping to prevent issues with orphaned stores. TODO fix this"
-)
-def test_sigint_exits_cleanly(ports, tmp_path):
- server_opts = [
- "improv",
- "server",
- "-c",
- str(ports[0]),
- "-o",
- str(ports[1]),
- "-f",
- tmp_path / "global.log",
- ]
-
- server = subprocess.Popen(
- server_opts, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
- )
- server.send_signal(signal.SIGINT)
+def test_nexus_bad_config_actor_args(setdir):
+ nex = Nexus("test")
+ # with pytest.raises(ConfigFileNotValidException):
+ try:
+ nex.create_nexus("bad_args.yaml")
+ except ConfigFileNotValidException:
+ assert True
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
+
- server.wait(10)
- assert True
+def test_nexus_no_config_file():
+ nex = Nexus("test")
+ # with pytest.raises(ConfigFileNotProvidedException):
+ try:
+ nex.create_nexus()
+ except ConfigFileNotProvidedException:
+ assert True
+ except Exception as e:
+ print(f"error caught in test harness: {e}")
+ logging.error(f"error caught in test harness: {e}")
diff --git a/test/test_store_with_errors.py b/test/test_store_with_errors.py
index b95781a0..36d1d0be 100644
--- a/test/test_store_with_errors.py
+++ b/test/test_store_with_errors.py
@@ -1,9 +1,7 @@
import pytest
-from pyarrow import plasma
-from improv.store import StoreInterface, RedisStoreInterface, PlasmaStoreInterface
+from improv.store import StoreInterface, RedisStoreInterface, ObjectNotFoundError
-from pyarrow._plasma import PlasmaObjectExists
from scipy.sparse import csc_matrix
import numpy as np
import redis
@@ -17,91 +15,30 @@
logger.setLevel(logging.DEBUG)
-# TODO: add docstrings!!!
-# TODO: clean up syntax - consistent capitalization, function names, etc.
-# TODO: decide to keep classes
-# TODO: increase coverage!!! SEE store.py
-
-# Separate each class as individual file - individual tests???
-
-
def test_connect(setup_store, server_port_num):
store = StoreInterface(server_port_num=server_port_num)
assert isinstance(store.client, redis.Redis)
-def test_plasma_connect(setup_plasma_store, set_store_loc):
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- assert isinstance(store.client, plasma.PlasmaClient)
-
-
def test_redis_connect(setup_store, server_port_num):
store = RedisStoreInterface(server_port_num=server_port_num)
- assert isinstance(store.client, redis.Redis)
assert store.client.ping()
-def test_connect_incorrect_path(setup_plasma_store, set_store_loc):
- # TODO: shorter name???
- # TODO: passes, but refactor --- see comments
- store_loc = "asdf"
- # Handle exception thrown - assert name == 'CannotConnectToStoreInterfaceError'
- # and message == 'Cannot connect to store at {}'.format(str(store_loc))
- # with pytest.raises(Exception, match='CannotConnectToStoreInterfaceError') as cm:
- # store.connect_store(store_loc)
- # # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- # raise Exception('Cannot connect to store: {0}'.format(e))
- with pytest.raises(CannotConnectToStoreInterfaceError) as e:
- store = PlasmaStoreInterface(store_loc=store_loc)
- store.connect_store(store_loc)
- # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- assert e.value.message == "Cannot connect to store at {}".format(str(store_loc))
-
-
-def test_redis_connect_wrong_port(setup_store, server_port_num):
+def test_redis_connect_wrong_port(server_port_num):
bad_port_num = 1234
with pytest.raises(CannotConnectToStoreInterfaceError) as e:
RedisStoreInterface(server_port_num=bad_port_num)
assert e.value.message == "Cannot connect to store at {}".format(str(bad_port_num))
-def test_connect_none_path(setup_plasma_store):
- # BUT default should be store_loc = '/tmp/store' if not entered?
- store_loc = None
- # Handle exception thrown - assert name == 'CannotConnectToStoreInterfaceError'
- # and message == 'Cannot connect to store at {}'.format(str(store_loc))
- # with pytest.raises(Exception) as cm:
- # store.connnect_store(store_loc)
- # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- # assert cm.exception.name == 'CannotConnectToStoreInterfaceError'
- # with pytest.raises(Exception, match='CannotConnectToStoreInterfaceError') as cm:
- # store.connect_store(store_loc)
- # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- # raise Exception('Cannot connect to store: {0}'.format(e))
- with pytest.raises(CannotConnectToStoreInterfaceError) as e:
- store = PlasmaStoreInterface(store_loc=store_loc)
- store.connect_store(store_loc)
- # Check that the exception thrown is a CannotConnectToStoreInterfaceError
- assert e.value.message == "Cannot connect to store at {}".format(str(store_loc))
-
-
-# class StoreInterfaceGet(self):
-
-
# TODO: @pytest.parameterize...store.get and store.getID for diff datatypes,
-# pickleable and not, etc.
-# Check raises...CannotGetObjectError (object never stored)
def test_init_empty(setup_store, server_port_num):
store = StoreInterface(server_port_num=server_port_num)
# logger.info(store.client.config_get())
assert store.get_all() == []
-def test_plasma_init_empty(setup_plasma_store, set_store_loc):
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- assert store.get_all() == {}
-
-
def test_is_csc_matrix_and_put(setup_store, server_port_num):
mat = csc_matrix((3, 4), dtype=np.int8)
store = StoreInterface(server_port_num=server_port_num)
@@ -109,21 +46,13 @@ def test_is_csc_matrix_and_put(setup_store, server_port_num):
assert isinstance(store.get(x), csc_matrix)
-def test_plasma_is_csc_matrix_and_put(setup_plasma_store, set_store_loc):
- mat = csc_matrix((3, 4), dtype=np.int8)
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- x = store.put(mat, "matrix")
- assert isinstance(store.getID(x), csc_matrix)
-
-
-@pytest.mark.skip
def test_get_list_and_all(setup_store, server_port_num):
store = StoreInterface(server_port_num=server_port_num)
- # id = store.put(1, "one")
- # id2 = store.put(2, "two")
- # id3 = store.put(3, "three")
- assert [1, 2] == store.getList(["one", "two"])
- assert [1, 2, 3] == store.get_all()
+ id = store.put(1)
+ id2 = store.put(2)
+ store.put(3)
+ assert [1, 2] == store.get_list([id, id2])
+ assert [1, 2, 3] == sorted(store.get_all())
def test_reset(setup_store, server_port_num):
@@ -133,13 +62,6 @@ def test_reset(setup_store, server_port_num):
assert store.get(id) == 1
-def test_plasma_reset(setup_plasma_store, set_store_loc):
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- store.reset()
- id = store.put(1, "one")
- assert store.get(id) == 1
-
-
def test_put_one(setup_store, server_port_num):
store = StoreInterface(server_port_num=server_port_num)
id = store.put(1)
@@ -152,27 +74,10 @@ def test_redis_put_one(setup_store, server_port_num):
assert 1 == store.get(key)
-def test_plasma_put_one(setup_plasma_store, set_store_loc):
- store = PlasmaStoreInterface(store_loc=set_store_loc)
- id = store.put(1, "one")
- assert 1 == store.get(id)
-
-
-@pytest.mark.skip(reason="Error not being raised")
-def test_put_twice(setup_store):
- # store = StoreInterface()
- with pytest.raises(PlasmaObjectExists) as e:
- # id = store.put(2, "two")
- # id2 = store.put(2, "two")
- pass
- # Check that the exception thrown is an PlasmaObjectExists
- assert e.value.message == "Object already exists. Meant to call replace?"
-
-
-def test_getOne(setup_store, server_port_num):
- store = StoreInterface(server_port_num=server_port_num)
- id = store.put(1)
- assert 1 == store.get(id)
+def test_redis_get_unknown_object(setup_store, server_port_num):
+ store = RedisStoreInterface(server_port_num=server_port_num)
+ with pytest.raises(ObjectNotFoundError):
+ store.get("unknown")
def test_redis_get_one(setup_store, server_port_num):
diff --git a/test/test_tui.py b/test/test_tui.py
index d45bccfe..64b44ee5 100644
--- a/test/test_tui.py
+++ b/test/test_tui.py
@@ -6,8 +6,6 @@
from zmq import PUB, REP
from zmq.log.handlers import PUBHandler
-from test_nexus import ports
-
@pytest.fixture
def logger(ports):
@@ -31,7 +29,7 @@ async def sockets(ports):
@pytest.fixture
async def app(ports):
- mock = tui.TUI(*ports)
+ mock = tui.TUI(*ports[:-1])
yield mock
time.sleep(0.5)
@@ -60,8 +58,13 @@ async def test_log_panel_receives_logging(app, logger):
async def test_input_box_echoed_to_console(app):
+ ctx = zmq.Context()
+ mock_server_socket = ctx.socket(REP)
+ mock_server_socket.bind(f"tcp://*:{app.control_port.split(':')[1]}")
async with app.run_test() as pilot:
await pilot.press(*"foo", "enter")
+ await mock_server_socket.recv_string()
+ await mock_server_socket.send_string("test reply")
console = pilot.app.get_widget_by_id("console")
assert console.history[0] == "foo"