diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 0c80283d..2d5fb25b 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -20,7 +20,7 @@ jobs: max-parallel: 1 fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] os: [ubuntu-latest, macos-latest] # [ubuntu-latest, macos-latest, windows-latest] steps: @@ -65,7 +65,7 @@ jobs: - name: Test with pytest run: | - python -m pytest --cov=improv + python -m pytest -x -s -l --cov=improv - name: Coveralls uses: coverallsapp/github-action@v2 diff --git a/.gitignore b/.gitignore index b76c5666..bedf9d58 100644 --- a/.gitignore +++ b/.gitignore @@ -128,6 +128,9 @@ dmypy.json arrow .vscode/ +.idea/ *.code-workspace improv/_version.py + +*venv diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 00000000..13566b81 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 00000000..105ce2da --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 00000000..69d3d2cc --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 00000000..35eb1ddf --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/demos/1p_caiman/1p_demo.yaml b/demos/1p_caiman/1p_demo.yaml index 6f754860..e60e469a 100644 --- a/demos/1p_caiman/1p_demo.yaml +++ b/demos/1p_caiman/1p_demo.yaml @@ -34,6 +34,3 @@ connections: Processor.q_out: [Analysis.q_in] Analysis.q_out: [Visual.q_in] InputStim.q_out: [Analysis.input_stim_queue] - -# settings: -# use_watcher: [Acquirer, Processor, Visual, Analysis] \ No newline at end of file diff --git a/demos/basic/Behavior_demo.py b/demos/basic/Behavior_demo.py index 8117189c..4a26aea4 100644 --- a/demos/basic/Behavior_demo.py +++ b/demos/basic/Behavior_demo.py @@ -7,7 +7,7 @@ loadFile = "./Behavior_demo.yaml" nexus = Nexus("Nexus") -nexus.createNexus(file=loadFile) +nexus.create_nexus(file=loadFile) # All modules needed have been imported # so we can change the level of logging here @@ -20,4 +20,4 @@ # logger = logging.getLogger("improv") # logger.setLevel(logging.INFO) -nexus.startNexus() +nexus.start_nexus() diff --git a/demos/basic/Behavior_demo.yaml b/demos/basic/Behavior_demo.yaml index 8f105507..992f1528 100644 --- a/demos/basic/Behavior_demo.yaml +++ b/demos/basic/Behavior_demo.yaml @@ -1,6 +1,3 @@ -settings: - use_watcher: [Acquirer, Processor, Visual, Analysis, Behavior, Motion] - actors: GUI: package: actors.visual diff --git a/demos/basic/basic_demo.py b/demos/basic/basic_demo.py index 953948ef..971c42fd 100644 --- a/demos/basic/basic_demo.py +++ b/demos/basic/basic_demo.py @@ -7,7 +7,7 @@ loadFile = "./basic_demo.yaml" nexus = Nexus("Nexus") -nexus.createNexus(file=loadFile) +nexus.create_nexus(file=loadFile) # All modules needed have been imported # so we can change the level of logging here @@ -20,4 +20,4 @@ # logger = logging.getLogger("improv") # logger.setLevel(logging.INFO) -nexus.startNexus() 
+nexus.start_nexus() diff --git a/demos/basic/basic_demo.yaml b/demos/basic/basic_demo.yaml index b007147c..1736bd64 100644 --- a/demos/basic/basic_demo.yaml +++ b/demos/basic/basic_demo.yaml @@ -33,6 +33,3 @@ connections: Processor.q_out: [Analysis.q_in] Analysis.q_out: [Visual.q_in] InputStim.q_out: [Analysis.input_stim_queue] - -# settings: -# use_watcher: [Acquirer, Processor, Visual, Analysis] \ No newline at end of file diff --git a/demos/bubblewrap/actors/bubble.py b/demos/bubblewrap/actors/bubble.py index c63c1e3c..5f4a33ea 100644 --- a/demos/bubblewrap/actors/bubble.py +++ b/demos/bubblewrap/actors/bubble.py @@ -55,7 +55,7 @@ def setup(self): self.bw.init_nodes() logger.info("Nodes initialized") - self._getStoreInterface() + self._get_store_interface() def runStep(self): """Observe new data from dim reduction and update bubblewrap""" diff --git a/demos/live/live_demo.py b/demos/live/live_demo.py index fff6839f..52f39fb8 100644 --- a/demos/live/live_demo.py +++ b/demos/live/live_demo.py @@ -7,7 +7,7 @@ loadFile = "./live_demo.yaml" nexus = Nexus("Nexus") -nexus.createNexus(file=loadFile) +nexus.create_nexus(file=loadFile) # All modules needed have been imported # so we can change the level of logging here @@ -20,4 +20,4 @@ # logger = logging.getLogger("improv") # logger.setLevel(logging.INFO) -nexus.startNexus() +nexus.start_nexus() diff --git a/demos/zmq/actors/zmq_ps_sample_generator.py b/demos/minimal/actors/sample_generator_zmq.py similarity index 54% rename from demos/zmq/actors/zmq_ps_sample_generator.py rename to demos/minimal/actors/sample_generator_zmq.py index 42b0afc4..585423f9 100644 --- a/demos/zmq/actors/zmq_ps_sample_generator.py +++ b/demos/minimal/actors/sample_generator_zmq.py @@ -1,15 +1,14 @@ +from improv.actor import ZmqActor +from datetime import date # used for saving import numpy as np import logging -from demos.sample_actors.zmqActor import ZmqActor - logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class Generator(ZmqActor): - """Sample actor to generate data to pass into a sample processor - using sync ZMQ to communicate. + """Sample actor to generate data to pass into a sample processor. Intended for use along with sample_processor.py. """ @@ -25,39 +24,53 @@ def __str__(self): def setup(self): """Generates an array that serves as an initial source of data. - Sets up a ZmqPSActor to send data to the processor. Initial array is a 100 row, 5 column numpy matrix that contains integers from 0-99, inclusive. """ - logger.info("Beginning setup for Generator") self.data = np.asmatrix(np.random.randint(100, size=(100, 5))) - logger.info("Completed setup for Generator") + self.improv_logger.info("Completed setup for Generator") + + # def run(self): + # """ Send array into the store. + # """ + # self.fcns = {} + # self.fcns['setup'] = self.setup + # self.fcns['run'] = self.runStep + # self.fcns['stop'] = self.stop + + # with RunManager(self.name, self.fcns, self.links) as rm: + # logger.info(rm) def stop(self): """Save current randint vector to a file.""" - logger.info("Generator stopping") - np.save("sample_generator_data.npy", self.data) + self.improv_logger.info("Generator stopping") + np.save("sample_generator_data", self.data) + # This is not the best example of a save function; + # it will overwrite previous files with the same name. return 0 - def runStep(self): + def run_step(self): """Generates additional data after initial setup data is exhausted. - Sends data to the processor using a ZmqPSActor. Data is of a different form than the setup data in that although it is the same size (a 1x5 vector), it is uniformly distributed in [0, 9] instead of in [0, 99]. Therefore, the average over time should converge to 4.5. """ + if self.frame_num < np.shape(self.data)[0]: - data_id = self.client.put(self.data[self.frame_num], str(f"Gen_raw: {self.frame_num}")) + data_id = self.client.put(self.data[self.frame_num]) try: - self.put(data_id) #[data_id, str(self.frame_num)]) + self.q_out.put(data_id) + # self.improv_logger.info(f"Sent {self.data[self.frame_num]} with key {data_id}") self.frame_num += 1 + except Exception as e: - logger.error(f"---------Generator Exception: {e}") + self.improv_logger.error(f"Generator Exception: {e}") else: - new_data = np.asmatrix(np.random.randint(10, size=(1, 5))) - self.data = np.concatenate((self.data, new_data), axis=0) + self.data = np.concatenate( + (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0 + )
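A quick standalone check of the long-run average claimed in the docstring above: rows appended by run_step come from np.random.randint(10, ...), which yields integers in [0, 9] with mean 4.5. The snippet below is illustrative only and not part of this patch:

import numpy as np

# Same distribution as the rows appended by run_step: integers in [0, 9].
rows = np.random.randint(10, size=(100_000, 5))

# Mean of the per-row averages; prints a value close to 4.5.
print(rows.mean(axis=1).mean())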
diff --git a/demos/minimal/actors/sample_persistence_generator.py b/demos/minimal/actors/sample_persistence_generator.py new file mode 100644 index 00000000..5f6afbf1 --- /dev/null +++ b/demos/minimal/actors/sample_persistence_generator.py @@ -0,0 +1,78 @@ +import random +import time + +from improv.actor import ZmqActor +from datetime import date # used for saving +import numpy as np +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Generator(ZmqActor): + """Sample actor to generate data to pass into a sample processor. + + Intended for use along with sample_processor.py. + """ + + def __init__(self, output_filename, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data = None + self.name = "Generator" + self.frame_num = 0 + self.output_filename = output_filename + + def __str__(self): + return f"Name: {self.name}, Data: {self.data}" + + def setup(self): + """Generates an array that serves as an initial source of data. + + Initial array is a 100 row, 5 column numpy matrix that contains + integers from 0-99, inclusive. + """ + + self.data = np.asmatrix(np.random.randint(100, size=(100, 5))) + self.improv_logger.info("Completed setup for Generator") + + def stop(self): + """Stop the generator; data is already saved row by row in run_step.""" + + self.improv_logger.info("Generator stopping") + return 0 + + def run_step(self): + """Generates additional data after initial setup data is exhausted. + + Data is of a different form than the setup data in that although it is + the same size (a 1x5 vector), it is uniformly distributed in [0, 9] + instead of in [0, 99]. Therefore, the average over time should + converge to 4.5.
+ """ + + + device_time = time.time_ns() + time.sleep(0.0003) + acquired_time = time.time_ns() # mock the time the generator "actually received" the data + + if self.frame_num < np.shape(self.data)[0]: + + with open(self.output_filename, "a+") as f: + device_data = self.data[self.frame_num] + packaged_data = (acquired_time, (device_time, device_data)) + # save the data to the flat file before sending it downstream + f.write(f"{packaged_data[0]}, {packaged_data[1][0]}, {packaged_data[1][1]}\n") + + data_id = self.client.put(packaged_data) + try: + self.q_out.put(data_id) + # self.improv_logger.info(f"Sent {self.data[self.frame_num]} with key {data_id}") + self.frame_num += 1 + + except Exception as e: + self.improv_logger.error(f"Generator Exception: {e}") + else: + self.data = np.concatenate( + (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0 + ) diff --git a/demos/zmq/actors/zmq_ps_sample_processor.py b/demos/minimal/actors/sample_persistence_processor.py similarity index 53% rename from demos/zmq/actors/zmq_ps_sample_processor.py rename to demos/minimal/actors/sample_persistence_processor.py index f54c00c8..3a38d3fc 100644 --- a/demos/zmq/actors/zmq_ps_sample_processor.py +++ b/demos/minimal/actors/sample_persistence_processor.py @@ -1,44 +1,46 @@ +from improv.actor import ZmqActor import numpy as np import logging -from demos.sample_actors.zmqActor import ZmqActor - logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class Processor(ZmqActor): - """Sample processor used to calculate the average of an array of integers - using sync ZMQ to communicate. + """Sample processor used to calculate the average of an array of integers. Intended for use with sample_generator.py. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + if "name" in kwargs: + self.name = kwargs["name"] def setup(self): """Initializes all class variables. - Sets up a ZmqRRActor to receive data from the generator. self.name (string): name of the actor. - self.frame (ObjectID): Store object id referencing data from the store. + self.frame (ObjectID): StoreInterface object id referencing data from the store. self.avg_list (list): list that contains averages of individual vectors. self.frame_num (int): index of current frame. """ - self.name = "Processor" + + if not hasattr(self, "name"): + self.name = "Processor" self.frame = None self.avg_list = [] self.frame_num = 1 - logger.info("Completed setup for Processor") + self.improv_logger.info("Completed setup for Processor") def stop(self): """Trivial stop function for testing purposes.""" - logger.info("Processor stopping; have received {} frames so far".format(self.frame_num)) - def runStep(self): + self.improv_logger.info("Processor stopping") + return 0 + + def run_step(self): """Gets from the input queue and calculates the average. - Receives data from the generator using a ZmqRRActor. Receives an ObjectID, references data in the store using that ObjectID, calculates the average of that data, and finally prints @@ -47,19 +49,18 @@ def runStep(self): frame = None try: - frame = self.get() - - except: - logger.error("Could not get frame!") + frame = self.q_in.get(timeout=0.05) + except Exception as e: + # logger.error(f"{self.name} could not get frame! 
At {self.frame_num}: {e}") pass - if frame is not None: + if frame is not None and self.frame_num is not None: self.done = False - self.frame = self.client.getID(frame) - avg = np.mean(self.frame[0]) - - # logger.info(f"Average: {avg}") + self.frame = self.client.get(frame) + device_data = self.frame[1][1] + avg = np.mean(device_data) + # self.improv_logger.info(f"Average: {avg}") self.avg_list.append(avg) - logger.info(f"Overall Average: {np.mean(self.avg_list)}") - # logger.info(f"Frame number: {self.frame_num}") + # self.improv_logger.info(f"Overall Average: {np.mean(self.avg_list)}") + # self.improv_logger.info(f"Frame number: {self.frame_num}") self.frame_num += 1 diff --git a/demos/minimal/actors/sample_processor.py b/demos/minimal/actors/sample_processor_zmq.py similarity index 64% rename from demos/minimal/actors/sample_processor.py rename to demos/minimal/actors/sample_processor_zmq.py index 40fbf4d5..3ff9cd25 100644 --- a/demos/minimal/actors/sample_processor.py +++ b/demos/minimal/actors/sample_processor_zmq.py @@ -1,4 +1,4 @@ -from improv.actor import Actor +from improv.actor import ZmqActor import numpy as np import logging @@ -6,7 +6,7 @@ logger.setLevel(logging.INFO) -class Processor(Actor): +class Processor(ZmqActor): """Sample processor used to calculate the average of an array of integers. Intended for use with sample_generator.py. @@ -14,6 +14,8 @@ class Processor(Actor): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + if "name" in kwargs: + self.name = kwargs["name"] def setup(self): """Initializes all class variables. @@ -24,18 +26,20 @@ def setup(self): self.frame_num (int): index of current frame. """ - self.name = "Processor" + if not hasattr(self, "name"): + self.name = "Processor" self.frame = None self.avg_list = [] self.frame_num = 1 - logger.info("Completed setup for Processor") + self.improv_logger.info("Completed setup for Processor") def stop(self): """Trivial stop function for testing purposes.""" - logger.info("Processor stopping") + self.improv_logger.info("Processor stopping") + return 0 - def runStep(self): + def run_step(self): """Gets from the input queue and calculates the average. Receives an ObjectID, references data in the store using that @@ -46,22 +50,16 @@ def runStep(self): frame = None try: frame = self.q_in.get(timeout=0.05) - - except Exception: - logger.error("Could not get frame!") + except Exception as e: + # logger.error(f"{self.name} could not get frame! 
At {self.frame_num}: {e}") pass if frame is not None and self.frame_num is not None: self.done = False - if self.store_loc: - self.frame = self.client.getID(frame[0][0]) - else: - self.frame = self.client.get(frame) + self.frame = self.client.get(frame) avg = np.mean(self.frame[0]) - - logger.info(f"Average: {avg}") + # self.improv_logger.info(f"Average: {avg}") self.avg_list.append(avg) - logger.info(f"Overall Average: {np.mean(self.avg_list)}") - logger.info(f"Frame number: {self.frame_num}") - + # self.improv_logger.info(f"Overall Average: {np.mean(self.avg_list)}") + # self.improv_logger.info(f"Frame number: {self.frame_num}") self.frame_num += 1 diff --git a/demos/minimal/actors/sample_spawn_processor.py b/demos/minimal/actors/sample_spawn_processor.py index cdc8bbb7..0f8962ee 100644 --- a/demos/minimal/actors/sample_spawn_processor.py +++ b/demos/minimal/actors/sample_spawn_processor.py @@ -1,4 +1,4 @@ -from improv.actor import Actor +from improv.actor import ZmqActor import numpy as np from queue import Empty import logging @@ -7,7 +7,7 @@ logger.setLevel(logging.INFO) -class Processor(Actor): +class Processor(ZmqActor): """Sample processor used to calculate the average of an array of integers. Intended for use with sample_generator.py. @@ -31,7 +31,7 @@ def setup(self): self.frame_num = 1 logger.info("Completed setup for Processor") - self._getStoreInterface() + self._get_store_interface() def stop(self): """Trivial stop function for testing purposes.""" diff --git a/demos/minimal/minimal.yaml b/demos/minimal/minimal.yaml index c8e0b24e..d752b415 100644 --- a/demos/minimal/minimal.yaml +++ b/demos/minimal/minimal.yaml @@ -1,14 +1,15 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: [Processor.q_in] - -redis_config: - port: 6379 \ No newline at end of file + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in diff --git a/demos/minimal/minimal_persistence.yaml b/demos/minimal/minimal_persistence.yaml new file mode 100644 index 00000000..65226762 --- /dev/null +++ b/demos/minimal/minimal_persistence.yaml @@ -0,0 +1,16 @@ +actors: + Generator: + package: actors.sample_persistence_generator + class: Generator + output_filename: test_persistence.csv + + Processor: + package: actors.sample_persistence_processor + class: Processor + +connections: + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in diff --git a/demos/minimal/minimal_plasma.yaml b/demos/minimal/minimal_plasma.yaml deleted file mode 100644 index ab869678..00000000 --- a/demos/minimal/minimal_plasma.yaml +++ /dev/null @@ -1,13 +0,0 @@ -actors: - Generator: - package: actors.sample_generator - class: Generator - - Processor: - package: actors.sample_processor - class: Processor - -connections: - Generator.q_out: [Processor.q_in] - -plasma_config: \ No newline at end of file diff --git a/demos/minimal/minimal_spawn.yaml b/demos/minimal/minimal_spawn.yaml index 6f22e653..7baab190 100644 --- a/demos/minimal/minimal_spawn.yaml +++ b/demos/minimal/minimal_spawn.yaml @@ -1,6 +1,6 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: @@ -9,7 +9,8 @@ actors: method: spawn connections: - Generator.q_out: [Processor.q_in] - -redis_config: - port: 6378 \ No newline at end of file + Generator-Processor: + sources: + - 
Generator.q_out + sinks: + - Processor.q_in \ No newline at end of file diff --git a/demos/neurofinder/neurofind_demo.py b/demos/neurofinder/neurofind_demo.py index 5f49399c..ad327d73 100644 --- a/demos/neurofinder/neurofind_demo.py +++ b/demos/neurofinder/neurofind_demo.py @@ -7,7 +7,7 @@ loadFile = "./neurofind_demo.yaml" nexus = Nexus("Nexus") -nexus.createNexus(file=loadFile) +nexus.create_nexus(file=loadFile) # All modules needed have been imported # so we can change the level of logging here @@ -20,4 +20,4 @@ # logger = logging.getLogger("improv") # logger.setLevel(logging.INFO) -nexus.startNexus() +nexus.start_nexus() diff --git a/demos/neurofinder/neurofind_demo.yaml b/demos/neurofinder/neurofind_demo.yaml index eb7796bc..483f6dc8 100644 --- a/demos/neurofinder/neurofind_demo.yaml +++ b/demos/neurofinder/neurofind_demo.yaml @@ -1,6 +1,3 @@ -settings: - use_watcher: [Acquirer, Processor, Visual, Analysis] - actors: GUI: package: demos.basic.actors.visual diff --git a/demos/sample_actors/acquire.py b/demos/sample_actors/acquire.py index 28dfb4f7..c355ed43 100644 --- a/demos/sample_actors/acquire.py +++ b/demos/sample_actors/acquire.py @@ -5,7 +5,7 @@ import numpy as np from skimage.io import imread -from improv.actor import Actor +from improv.actor import ZmqActor import logging @@ -15,7 +15,7 @@ # Classes: File Acquirer, Stim, Behavior, Tiff -class FileAcquirer(Actor): +class FileAcquirer(ZmqActor): """Class to import data from files and output frames in a buffer, or discrete. """ @@ -109,7 +109,7 @@ def saveFrame(self, frame): self.f.flush() -class StimAcquirer(Actor): +class StimAcquirer(ZmqActor): """Class to load visual stimuli data from file and stream into the pipeline """ @@ -149,7 +149,7 @@ def runStep(self): self.n += 1 -class BehaviorAcquirer(Actor): +class BehaviorAcquirer(ZmqActor): """Actor that acquires information of behavioral stimulus during the experiment @@ -188,7 +188,7 @@ def runStep(self): self.n += 1 -class FileStim(Actor): +class FileStim(ZmqActor): """Actor that acquires information of behavioral stimulus during the experiment from a file """ @@ -216,7 +216,7 @@ def runStep(self): self.n += 1 -class TiffAcquirer(Actor): +class TiffAcquirer(ZmqActor): """Loops through a TIF file.""" def __init__(self, *args, filename=None, framerate=30, **kwargs): diff --git a/demos/sample_actors/analysis.py b/demos/sample_actors/analysis.py index 664ddaf3..854e7a50 100644 --- a/demos/sample_actors/analysis.py +++ b/demos/sample_actors/analysis.py @@ -4,7 +4,7 @@ from queue import Empty import os -from improv.actor import Actor, RunManager +from improv.actor import ZmqActor from improv.store import ObjectNotFoundError import logging @@ -13,7 +13,7 @@ logger.setLevel(logging.INFO) -class MeanAnalysis(Actor): +class MeanAnalysis(ZmqActor): # TODO: Add additional error handling # TODO: this is too complex for a sample actor? 
def __init__(self, *args, **kwargs): diff --git a/demos/sample_actors/simple_analysis.py b/demos/sample_actors/simple_analysis.py index ff350b58..78890a43 100644 --- a/demos/sample_actors/simple_analysis.py +++ b/demos/sample_actors/simple_analysis.py @@ -3,7 +3,7 @@ from queue import Empty import os -from improv.actor import Actor, RunManager +from improv.actor import ZmqActor from improv.store import ObjectNotFoundError import logging @@ -12,7 +12,7 @@ logger.setLevel(logging.INFO) -class SimpleAnalysis(Actor): +class SimpleAnalysis(ZmqActor): # TODO: Add additional error handling def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/demos/sample_actors/zmqActor.py b/demos/sample_actors/zmqActor.py deleted file mode 100644 index 6455b0b4..00000000 --- a/demos/sample_actors/zmqActor.py +++ /dev/null @@ -1,206 +0,0 @@ -import asyncio -import time - -import zmq -from zmq import ( - PUB, - SUB, - SUBSCRIBE, - REQ, - REP, - LINGER, - Again, - NOBLOCK, - ZMQError, - EAGAIN, - ETERM, -) -from zmq.log.handlers import PUBHandler -import zmq.asyncio - -from improv.actor import Actor - -import logging - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class ZmqActor(Actor): - """ - Zmq actor with pub/sub or rep/req pattern. - """ - - def __init__(self, *args, type="PUB", ip="127.0.0.1", port=5555, **kwargs): - super().__init__(*args, **kwargs) - logger.info("Constructed Zmq Actor") - if str(type) in "PUB" or str(type) in "SUB": - self.pub_sub_flag = True # default - else: - self.pub_sub_flag = False - self.rep_req_flag = not self.pub_sub_flag - self.ip = ip - self.port = port - self.address = "tcp://{}:{}".format(self.ip, self.port) - - self.send_socket = None - self.recv_socket = None - self.req_socket = None - self.rep_socket = None - - self.context = zmq.Context.instance() - - def sendMsg(self, msg, msg_type="pyobj"): - """ - Sends a message to the controller. - """ - if not self.send_socket: - self.setSendSocket() - - if msg_type == "multipart": - self.send_socket.send_multipart(msg) - if msg_type == "pyobj": - self.send_socket.send_pyobj(msg) - elif msg_type == "single": - self.send_socket.send(msg) - - def recvMsg(self, msg_type="pyobj", flags=0): - """ - Receives a message from the controller. - - NOTE: default flag=0 instead of flag=NOBLOCK - """ - if not self.recv_socket: - self.setRecvSocket() - - while True: - try: - if msg_type == "multipart": - recv_msg = self.recv_socket.recv_multipart(flags=flags) - elif msg_type == "pyobj": - recv_msg = self.recv_socket.recv_pyobj(flags=flags) - elif msg_type == "single": - recv_msg = self.recv_socket.recv(flags=flags) - break - except Again: - pass - except ZMQError as e: - logger.info(f"ZMQ error: {e}") - if e.errno == ETERM: - pass # interrupted - pass or break if in try loop - if e.errno == EAGAIN: - pass # no message was ready (yet!) - - # self.recv_socket.close() - return recv_msg - - def requestMsg(self, msg): - """Safe version of send/receive with controller. 
- Based on the Lazy Pirate pattern [here] - (https://zguide.zeromq.org/docs/chapter4/#Client-Side-Reliability-Lazy-Pirate-Pattern) - """ - REQUEST_TIMEOUT = 500 - REQUEST_RETRIES = 3 - retries_left = REQUEST_RETRIES - - self.setReqSocket() - - reply = None - try: - logger.debug(f"Sending {msg} to controller.") - self.req_socket.send_pyobj(msg) - - while True: - ready = self.req_socket.poll(REQUEST_TIMEOUT) - - if ready: - reply = self.req_socket.recv_pyobj() - logger.debug(f"Received {reply} from controller.") - break - else: - retries_left -= 1 - logger.debug("No response from server.") - - # try to close and reconnect - self.req_socket.setsockopt(LINGER, 0) - self.req_socket.close() - if retries_left == 0: - logger.debug("Server seems to be offline. Giving up.") - break - - logger.debug("Attempting to reconnect to server...") - - self.setReqSocket() - - logger.debug(f"Resending {msg} to controller.") - self.req_socket.send_pyobj(msg) - - except asyncio.CancelledError: - pass - - self.req_socket.close() - return reply - - def replyMsg(self, reply, delay=0.0001): - """ - Safe version of receive/reply with controller. - """ - - self.setRepSocket() - - msg = self.rep_socket.recv_pyobj() - time.sleep(delay) - self.rep_socket.send_pyobj(reply) - self.rep_socket.close() - - return msg - - def put(self, msg=None): - logger.debug(f"Putting message {msg}") - if self.pub_sub_flag: - logger.debug(f"putting message {msg} using pub/sub") - return self.sendMsg(msg) - else: - logger.debug(f"putting message {msg} using rep/req") - return self.requestMsg(msg) - - def get(self, reply=None): - if self.pub_sub_flag: - logger.debug(f"getting message with pub/sub") - return self.recvMsg() - else: - logger.debug(f"getting message using reply {reply} with pub/sub") - return self.replyMsg(reply) - - def setSendSocket(self, timeout=1.001): - """ - Sets up the send socket for the actor. - """ - self.send_socket = self.context.socket(PUB) - self.send_socket.bind(self.address) - time.sleep(timeout) - - def setRecvSocket(self, timeout=1.001): - """ - Sets up the receive socket for the actor. - """ - self.recv_socket = self.context.socket(SUB) - self.recv_socket.connect(self.address) - self.recv_socket.setsockopt(SUBSCRIBE, b"") - time.sleep(timeout) - - def setReqSocket(self, timeout=0.0001): - """ - Sets up the request socket for the actor. - """ - self.req_socket = self.context.socket(REQ) - self.req_socket.connect(self.address) - time.sleep(timeout) - - def setRepSocket(self, timeout=0.0001): - """ - Sets up the reply socket for the actor. 
- """ - self.rep_socket = self.context.socket(REP) - self.rep_socket.bind(self.address) - time.sleep(timeout) diff --git a/demos/spike/spike_demo.py b/demos/spike/spike_demo.py index ca4f110d..24a37b04 100644 --- a/demos/spike/spike_demo.py +++ b/demos/spike/spike_demo.py @@ -7,7 +7,7 @@ loadFile = "./spike_demo.yaml" nexus = Nexus("Nexus") -nexus.createNexus(file=loadFile) +nexus.create_nexus(file=loadFile) # All modules needed have been imported # so we can change the level of logging here @@ -20,4 +20,4 @@ # logger = logging.getLogger("improv") # logger.setLevel(logging.INFO) -nexus.startNexus() +nexus.start_nexus() diff --git a/demos/spike/spike_demo.yaml b/demos/spike/spike_demo.yaml index 6aad9992..8b46f7a7 100644 --- a/demos/spike/spike_demo.yaml +++ b/demos/spike/spike_demo.yaml @@ -1,6 +1,3 @@ -# settings: -# use_watcher: [Acquirer, Analysis] - actors: # GUI: # package: actors.visual diff --git a/demos/zmq/README.md b/demos/zmq/README.md deleted file mode 100644 index f02edc26..00000000 --- a/demos/zmq/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# ZMQ Demo - -This demo is intended to show how actors can communicate using zmq. There are two options for the demo, each corresponding to a different config file. The difference between these two options is in how they send and receive messages. One uses the publish/send concept, while the other uses the request/reply concept. - -## Instructions - -Just like any other demo, you can run `improv run ` in order to run the demo. For example, from within the `improv/demos/zmq` directory, running `improv run zmq_rr_demo.yaml` will run the zmq_rr demo. - -After running the demo, a tui (text user interface) should show up. From here, we can type `setup` followed by `run` to run the demo. After we are done, we can type `stop` to pause, or `quit` to exit. - -## Expected Output - -Currently, only logging output is supported. There will be no live output during the run. - diff --git a/demos/zmq/actors/zmq_rr_sample_generator.py b/demos/zmq/actors/zmq_rr_sample_generator.py deleted file mode 100644 index ef941ae6..00000000 --- a/demos/zmq/actors/zmq_rr_sample_generator.py +++ /dev/null @@ -1,59 +0,0 @@ -import numpy as np -import logging - -from demos.sample_actors.zmqActor import ZmqActor - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - -class Generator(ZmqActor): - """Sample actor to generate data to pass into a sample processor - using async ZMQ to communicate. - - Intended for use along with sample_processor.py. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.data = None - self.name = "Generator" - self.frame_num = 0 - - def __str__(self): - return f"Name: {self.name}, Data: {self.data}" - - def setup(self): - """Generates an array that serves as an initial source of data. - Sets up a ZmqRRActor to send data to the processor. - - Initial array is a 100 row, 5 column numpy matrix that contains - integers from 1-99, inclusive. - """ - logger.info("Beginning setup for Generator") - self.data = np.asmatrix(np.random.randint(100, size=(100, 5))) - logger.info("Completed setup for Generator") - - def stop(self): - """Save current randint vector to a file.""" - logger.info("Generator stopping") - np.save("sample_generator_data.npy", self.data) - return 0 - - def runStep(self): - """Generates additional data after initial setup data is exhausted. - Sends data to the processor using a ZmqRRActor. 
- - Data is of a different form as the setup data in that although it is - the same size (5x1 vector), it is uniformly distributed in [1, 10] - instead of in [1, 100]. Therefore, the average over time should - converge to 5.5. - """ - if self.frame_num < np.shape(self.data)[0]: - data_id = self.client.put(self.data[self.frame_num], str(f"Gen_raw_{self.frame_num}")) - try: - self.put(data_id) - self.frame_num += 1 - except Exception as e: - logger.error(f"Generator Exception: {e}") - else: - self.data = np.concatenate((self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0) diff --git a/demos/zmq/config/zmq_ps_demo.yaml b/demos/zmq/config/zmq_ps_demo.yaml deleted file mode 100644 index 230282d1..00000000 --- a/demos/zmq/config/zmq_ps_demo.yaml +++ /dev/null @@ -1,11 +0,0 @@ -actors: - Generator: - package: actors.sample_generator - class: Generator - - Processor: - package: actors.sample_processor - class: Processor - -connections: - Generator.q_out: [Processor.q_in] \ No newline at end of file diff --git a/demos/zmq/config/zmq_rr_demo.yaml b/demos/zmq/config/zmq_rr_demo.yaml deleted file mode 100644 index 230282d1..00000000 --- a/demos/zmq/config/zmq_rr_demo.yaml +++ /dev/null @@ -1,11 +0,0 @@ -actors: - Generator: - package: actors.sample_generator - class: Generator - - Processor: - package: actors.sample_processor - class: Processor - -connections: - Generator.q_out: [Processor.q_in] \ No newline at end of file diff --git a/demos/zmq/zmq_ps_demo.yaml b/demos/zmq/zmq_ps_demo.yaml deleted file mode 100644 index 63fca91c..00000000 --- a/demos/zmq/zmq_ps_demo.yaml +++ /dev/null @@ -1,17 +0,0 @@ -actors: - Generator: - package: actors.zmq_ps_sample_generator - class: Generator - ip: 127.0.0.1 - port: 5555 - type: PUB - - Processor: - package: actors.zmq_ps_sample_processor - class: Processor - ip: 127.0.0.1 - port: 5555 - type: SUB - -connections: - Generator.q_out: [Processor.q_in] diff --git a/demos/zmq/zmq_rr_demo.yaml b/demos/zmq/zmq_rr_demo.yaml deleted file mode 100644 index 585cebd1..00000000 --- a/demos/zmq/zmq_rr_demo.yaml +++ /dev/null @@ -1,17 +0,0 @@ -actors: - Generator: - package: actors.zmq_rr_sample_generator - class: Generator - ip: 127.0.0.1 - port: 5555 - type: REQ - - Processor: - package: actors.zmq_rr_sample_processor - class: Processor - ip: 127.0.0.1 - port: 5555 - type: REP - -connections: - Generator.q_out: [Processor.q_in] diff --git a/docs/running.ipynb b/docs/running.ipynb index 47b7b2b0..2ed13a09 100644 --- a/docs/running.ipynb +++ b/docs/running.ipynb @@ -19,10 +19,10 @@ "languageId": "shellscript" } }, - "outputs": [], "source": [ "!improv --help" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -43,10 +43,10 @@ "languageId": "shellscript" } }, - "outputs": [], "source": [ "!improv run --help" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -110,40 +110,10 @@ "languageId": "shellscript" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "usage: improv server [-h] [-c CONTROL_PORT] [-o OUTPUT_PORT] [-l LOGGING_PORT]\n", - " [-f LOGFILE] [-a ACTOR_PATH]\n", - " configfile\n", - "\n", - "Start the improv server\n", - "\n", - "positional arguments:\n", - " configfile YAML file specifying improv pipeline\n", - "\n", - "options:\n", - " -h, --help show this help message and exit\n", - " -c CONTROL_PORT, --control-port CONTROL_PORT\n", - " local port on which control signals are received\n", - " -o OUTPUT_PORT, --output-port OUTPUT_PORT\n", - " local port on which output messages are broadcast\n", - " 
-l LOGGING_PORT, --logging-port LOGGING_PORT\n", - " local port on which logging messages are broadcast\n", - " -f LOGFILE, --logfile LOGFILE\n", - " name of log file\n", - " -a ACTOR_PATH, --actor-path ACTOR_PATH\n", - " search path to add to sys.path when looking for\n", - " actors; defaults to the directory containing\n", - " configfile\n" - ] - } - ], "source": [ "!improv server --help" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -160,30 +130,10 @@ "languageId": "shellscript" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "usage: improv client [-h] [-c CONTROL_PORT] [-s SERVER_PORT] [-l LOGGING_PORT]\n", - "\n", - "Start the improv client\n", - "\n", - "options:\n", - " -h, --help show this help message and exit\n", - " -c CONTROL_PORT, --control-port CONTROL_PORT\n", - " address on which control signals are sent to the\n", - " server\n", - " -s SERVER_PORT, --server-port SERVER_PORT\n", - " address on which messages from the server are received\n", - " -l LOGGING_PORT, --logging-port LOGGING_PORT\n", - " address on which logging messages are broadcast\n" - ] - } - ], "source": [ "!improv client --help" - ] + ], + "outputs": [] }, { "cell_type": "markdown", diff --git a/improv/actor.py b/improv/actor.py index 3d6a2a23..09dd3b91 100644 --- a/improv/actor.py +++ b/improv/actor.py @@ -1,10 +1,17 @@ +from __future__ import annotations + import time import signal import asyncio import traceback from queue import Empty -import improv.store +import zmq +from zmq import SocketOption + +from improv import log +from improv.link import ZmqLink +from improv.messaging import ActorStateMsg, ActorStateReplyMsg, ActorSignalReplyMsg from improv.store import StoreInterface import logging @@ -13,6 +20,17 @@ logger.setLevel(logging.INFO) +# TODO construct log handler within post-fork setup based on +# TODO arguments sent to actor by nexus +# TODO and add it below; fine to leave logger here as-is + + +class LinkInfo: + def __init__(self, link_name, link_topic): + self.link_name = link_name + self.link_topic = link_topic + + class AbstractActor: """Base class for an actor that Nexus controls and interacts with. @@ -20,9 +38,7 @@ class AbstractActor: Also needs to be responsive to sent Signals (e.g. run, setup, etc) """ - def __init__( - self, name, store_loc=None, method="fork", store_port_num=None, *args, **kwargs - ): + def __init__(self, name, method="fork", store_port_num=None, *args, **kwargs): """Require a name for multiple instances of the same actor/class Create initial empty dict of Links for easier referencing """ @@ -31,8 +47,8 @@ def __init__( self.links = {} self.method = method self.client = None - self.store_loc = store_loc self.lower_priority = False + self.improv_logger = None self.store_port_num = store_port_num # Start with no explicit data queues. @@ -49,7 +65,7 @@ def __repr__(self): """ return self.name + ": " + str(self.links.keys()) - def setStoreInterface(self, client): + def set_store_interface(self, client): """Sets the client interface to the store Args: @@ -57,17 +73,14 @@ def setStoreInterface(self, client): """ self.client = client - def _getStoreInterface(self): + def _get_store_interface(self): # TODO: Where do we require this be run? Add a Signal and include in RM? 
if not self.client: store = None - if StoreInterface == improv.store.RedisStoreInterface: - store = StoreInterface(self.name, self.store_port_num) - else: - store = StoreInterface(self.name, self.store_loc) - self.setStoreInterface(store) + store = StoreInterface(self.name, self.store_port_num) + self.set_store_interface(store) - def setLinks(self, links): + def set_links(self, links): """General full dict set for links Args: links (dict): The dict to store all the links """ self.links = links - def setCommLinks(self, q_comm, q_sig): + def set_comm_links(self, q_comm, q_sig): """Set explicit communication links to/from Nexus (q_comm, q_sig) Args: q_comm (Link): for messages from this actor to Nexus q_sig (Link): signals from Nexus and must be checked first """ self.q_comm = q_comm self.q_sig = q_sig self.links.update({"q_comm": self.q_comm, "q_sig": self.q_sig}) - def setLinkIn(self, q_in): + def set_link_in(self, q_in): """Set the dedicated input queue Args: q_in (Link): for input signals to this actor """ self.q_in = q_in self.links.update({"q_in": self.q_in}) - def setLinkOut(self, q_out): + def set_link_out(self, q_out): """Set the dedicated output queue Args: q_out (Link): for output signals from this actor """ self.q_out = q_out self.links.update({"q_out": self.q_out}) - def setLinkWatch(self, q_watch): + def set_link_watch(self, q_watch): """Set the dedicated watchout queue Args: q_watch (Link): watchout queue """ self.q_watchout = q_watch self.links.update({"q_watchout": self.q_watchout}) - def addLink(self, name, link): + def add_link(self, name, link): """Function provided to add additional data links by name using same form as q_in or q_out Must be done during registration and not during run Args: name (string): customized link name link (Link): link to be added """ self.links.update({name: link}) # User can then use: self.my_queue = self.links['my_queue'] in a setup fcn, # or continue to reference it using self.links['my_queue'] - def getLinks(self): + def get_links(self): """Returns dictionary of links for the current actor Returns: dict: dictionary of links """ return self.links - def put(self, idnames, q_out=None, save=None): - """TODO: This is deprecated? Prefer using Links explicitly""" - if save is None: - save = [False] * len(idnames) - - if len(save) < len(idnames): - save = save + [False] * (len(idnames) - len(save)) - - if q_out is None: - q_out = self.q_out - - q_out.put(idnames) - - for i in range(len(idnames)): - if save[i]: - if self.q_watchout: - self.q_watchout.put(idnames[i]) - def setup(self): """Essentially the registration process Can also be an initialization for the actor Can be used to get access to the store """ pass def run(self): """Must run in continuous mode Also must check q_sig either at top of a run-loop or as async with the primary function """ raise NotImplementedError def stop(self): """Specify method for momentarily stopping the run and saving data. Not used by default """ pass - def changePriority(self): + def change_priority(self): """Try to lower this process' priority Only changes priority if lower_priority is set TODO: Only works on unix machines.
Add Windows functionality @@ -188,6 +183,18 @@ def changePriority(self): logger.info("Lowered priority of this process: {}".format(self.name)) print("Lowered ", os.getpid(), " for ", self.name) + def register_with_nexus(self): + pass + + def register_with_broker(self): + pass + + def setup_links(self): + pass + + def setup_logging(self): + raise NotImplementedError + class ManagedActor(AbstractActor): def __init__(self, *args, **kwargs): @@ -196,17 +203,146 @@ def __init__(self, *args, **kwargs): # Define dictionary of actions for the RunManager self.actions = {} self.actions["setup"] = self.setup - self.actions["run"] = self.runStep + self.actions["run"] = self.run_step self.actions["stop"] = self.stop + self.nexus_sig_port: int = None def run(self): - with RunManager(self.name, self.actions, self.links): + self.register_with_nexus() + self.setup_logging() + self.register_with_broker() + self.setup_links() + with RunManager( + self.name, self.actions, self.links, self.nexus_sig_port, self.improv_logger + ): pass - def runStep(self): + def run_step(self): raise NotImplementedError +class ZmqActor(ManagedActor): + def __init__( + self, + nexus_comm_port, + broker_sub_port, + broker_pub_port, + log_pull_port, + outgoing_links, + incoming_links, + broker_host="localhost", + log_host="localhost", + nexus_host="localhost", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.broker_pub_socket: zmq.Socket | None = None + self.zmq_context: zmq.Context | None = None + self.nexus_host: str = nexus_host + self.nexus_comm_port: int = nexus_comm_port + self.nexus_sig_port: int | None = None + self.nexus_comm_socket: zmq.Socket | None = None + self.nexus_sig_socket: zmq.Socket | None = None + self.broker_sub_port: int = broker_sub_port + self.broker_pub_port: int = broker_pub_port + self.broker_host: str = broker_host + self.log_pull_port: int = log_pull_port + self.log_host: str = log_host + self.outgoing_links: list[LinkInfo] = outgoing_links + self.incoming_links: list[LinkInfo] = incoming_links + self.incoming_sockets: dict[str, zmq.Socket] = dict() + + # Redefine dictionary of actions for the RunManager + self.actions = {"setup": self.setup, "run": self.run_step, "stop": self.stop} + + def register_with_nexus(self): + logger.info(f"Actor {self.name} registering with nexus") + self.zmq_context = zmq.Context() + self.zmq_context.setsockopt(SocketOption.LINGER, 0) + # create a REQ socket pointed at nexus' global actor in port + self.nexus_comm_socket = self.zmq_context.socket(zmq.REQ) + self.nexus_comm_socket.connect( + f"tcp://{self.nexus_host}:{self.nexus_comm_port}" + ) + + # create a REP socket for comms from Nexus and save its state + self.nexus_sig_socket = self.zmq_context.socket(zmq.REP) + self.nexus_sig_socket.bind("tcp://*:0") # find any available port + sig_socket_addr = self.nexus_sig_socket.getsockopt_string( + SocketOption.LAST_ENDPOINT + ) + self.nexus_sig_port = int(sig_socket_addr.split(":")[-1]) + + # build and send a message to nexus + actor_state = ActorStateMsg( + self.name, + Signal.waiting(), + self.nexus_sig_port, + "comms opened and actor ready to initialize", + ) + + self.nexus_comm_socket.send_pyobj(actor_state) + + rep: ActorStateReplyMsg = self.nexus_comm_socket.recv_pyobj() + logger.info( + f"Got response from nexus:\n" + f"Status: {rep.status}\n" + f"Info: {rep.info}\n" + ) + + self.links["q_comm"] = ZmqLink(self.nexus_comm_socket, f"{self.name}.q_comm") + self.links["q_sig"] = ZmqLink(self.nexus_sig_socket, f"{self.name}.q_sig") + + logger.info(f"Actor 
{self.name} registered with Nexus") + + def register_with_broker(self): # really opening sockets here + self.improv_logger.info(f"Actor {self.name} registering with broker") + if len(self.outgoing_links) > 0: + self.broker_pub_socket = self.zmq_context.socket(zmq.PUB) + self.broker_pub_socket.connect( + f"tcp://{self.broker_host}:{self.broker_sub_port}" + ) + + for incoming_link in self.incoming_links: + new_socket: zmq.Socket = self.zmq_context.socket(zmq.SUB) + new_socket.connect(f"tcp://{self.broker_host}:{self.broker_pub_port}") + new_socket.subscribe(incoming_link.link_topic) + self.incoming_sockets[incoming_link.link_name] = new_socket + self.improv_logger.info(f"Actor {self.name} registered with broker") + + def setup_links(self): + self.improv_logger.info(f"Actor {self.name} setting up links") + for outgoing_link in self.outgoing_links: + self.links[outgoing_link.link_name] = ZmqLink( + self.broker_pub_socket, + outgoing_link.link_name, + outgoing_link.link_topic, + ) + + for incoming_link in self.incoming_links: + self.links[incoming_link.link_name] = ZmqLink( + self.incoming_sockets[incoming_link.link_name], + incoming_link.link_name, + incoming_link.link_topic, + ) + + if "q_out" in self.links.keys(): + self.q_out = self.links["q_out"] + if "q_in" in self.links.keys(): + self.q_in = self.links["q_in"] + self.improv_logger.info(f"Actor {self.name} finished setting up links") + + def setup_logging(self): + self.improv_logger = logging.getLogger(self.name) + self.improv_logger.setLevel(logging.INFO) + for handler in logger.handlers: + self.improv_logger.addHandler(handler) + self.improv_logger.addHandler( + log.ZmqLogHandler(self.log_host, self.log_pull_port, self.zmq_context) + ) + + class AsyncActor(AbstractActor): def __init__(self, *args, **kwargs): super().__init__(*args) @@ -214,7 +350,7 @@ def __init__(self, *args, **kwargs): # Define dictionary of actions for the RunManager self.actions = {} self.actions["setup"] = self.setup - self.actions["run"] = self.runStep + self.actions["run"] = self.run_step self.actions["stop"] = self.stop def run(self): @@ -231,7 +367,7 @@ async def setup(self): """ pass - async def runStep(self): + async def run_step(self): raise NotImplementedError async def stop(self): @@ -243,13 +379,24 @@ async def stop(self): class RunManager: - def __init__(self, name, actions, links, runStoreInterface=None, timeout=1e-6): + def __init__( + self, + name, + actions, + links, + nexus_sig_port, + improv_logger, + runStoreInterface=None, + timeout=1e-6, + ): self.run = False self.stop = False self.config = False + self.nexus_sig_port = nexus_sig_port + self.improv_logger = improv_logger self.actorName = name - logger.debug("RunManager for {} created".format(self.actorName)) + self.improv_logger.debug("RunManager for {} created".format(self.actorName)) self.actions = actions self.links = links @@ -269,58 +416,82 @@ def __enter__(self): try: self.actions["run"]() except Exception as e: - logger.error("Actor {} error in run: {}".format(an, e)) - logger.error(traceback.format_exc()) + self.improv_logger.error("Actor {} error in run: {}".format(an, e)) + self.improv_logger.error(traceback.format_exc()) elif self.stop: try: self.actions["stop"]() except Exception as e: - logger.error("Actor {} error in stop: {}".format(an, e)) - logger.error(traceback.format_exc()) + self.improv_logger.error("Actor {} error in stop: {}".format(an, e)) + self.improv_logger.error(traceback.format_exc()) self.stop = False # Run once elif self.config: try: if self.runStoreInterface: 
self.runStoreInterface() self.actions["setup"]() - self.q_comm.put([Signal.ready()]) + self.q_comm.put( + ActorStateMsg( + self.actorName, Signal.ready(), self.nexus_sig_port, "" + ) + ) + res = self.q_comm.get() + self.improv_logger.info( + f"Actor {res.actor_name} got state update reply:\n" + f"Status: {res.status}\n" + f"Info: {res.info}\n" + ) except Exception as e: - logger.error("Actor {} error in setup: {}".format(an, e)) - logger.error(traceback.format_exc()) + self.improv_logger.error( + "Actor {} error in setup: {}".format(an, e) + ) + self.improv_logger.error(traceback.format_exc()) self.config = False # Check for new Signals received from Nexus try: - signal = self.q_sig.get(timeout=self.timeout) - logger.debug("{} received Signal {}".format(self.actorName, signal)) + signal_msg = self.q_sig.get(timeout=self.timeout) + signal = signal_msg.signal + self.q_sig.put(ActorSignalReplyMsg(an, signal, "OK", "")) + self.improv_logger.warning( + "{} received Signal {}".format(self.actorName, signal) + ) if signal == Signal.run(): self.run = True - logger.warning("Received run signal, begin running") + self.improv_logger.warning("Received run signal, begin running") elif signal == Signal.setup(): self.config = True elif signal == Signal.stop(): self.run = False self.stop = True - logger.warning(f"actor {self.actorName} received stop signal") + self.improv_logger.warning( + f"actor {self.actorName} received stop signal" + ) elif signal == Signal.quit(): - logger.warning("Received quit signal, aborting") + self.improv_logger.warning("Received quit signal, aborting") break elif signal == Signal.pause(): - logger.warning("Received pause signal, pending...") + self.improv_logger.warning("Received pause signal, pending...") self.run = False elif signal == Signal.resume(): # currently treat as same as run - logger.warning("Received resume signal, resuming") + self.improv_logger.warning("Received resume signal, resuming") self.run = True + elif signal == Signal.status(): + self.improv_logger.info( + f"Actor {self.actorName} received status request" + ) except KeyboardInterrupt: break except Empty: pass # No signal from Nexus + except TimeoutError: + pass # No signal from Nexus over zmq return None def __exit__(self, type, value, traceback): - logger.info("Ran for " + str(time.time() - self.start) + " seconds") - logger.warning("Exiting RunManager") + self.improv_logger.info("Ran for " + str(time.time() - self.start) + " seconds") + self.improv_logger.warning("Exiting RunManager") return None @@ -473,3 +644,11 @@ def stop(): @staticmethod def stop_success(): return "stop success" + + @staticmethod + def status(): + return "status" + + @staticmethod + def waiting(): + return "waiting" diff --git a/improv/broker.py b/improv/broker.py new file mode 100644 index 00000000..e9055080 --- /dev/null +++ b/improv/broker.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import logging +import signal + +import zmq +from zmq import SocketOption + +from improv.messaging import BrokerInfoMsg + +DEBUG = True + +local_log = logging.getLogger(__name__) + + +def bootstrap_broker(nexus_hostname, nexus_port): + if DEBUG: + local_log.addHandler(logging.FileHandler("broker_server.log")) + try: + broker = PubSubBroker(nexus_hostname, nexus_port) + broker.register_with_nexus() + broker.serve(broker.read_and_pub_message) + except Exception as e: + local_log.error(e) + for handler in local_log.handlers: + handler.close() + + +class PubSubBroker: + def __init__(self, nexus_hostname, nexus_comm_port): + self.running 
= True self.nexus_hostname: str = nexus_hostname self.nexus_comm_port: int = nexus_comm_port self.zmq_context: zmq.Context | None = None self.nexus_socket: zmq.Socket | None = None self.pub_port: int self.sub_port: int self.pub_socket: zmq.Socket | None = None self.sub_socket: zmq.Socket | None = None signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) for s in signals: signal.signal(s, self.stop) def register_with_nexus(self): # connect to nexus self.zmq_context = zmq.Context() self.zmq_context.setsockopt(SocketOption.LINGER, 0) self.nexus_socket = self.zmq_context.socket(zmq.REQ) self.nexus_socket.connect(f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}") self.sub_socket = self.zmq_context.socket(zmq.SUB) self.sub_socket.bind("tcp://*:0") sub_port_string = self.sub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) self.sub_port = int(sub_port_string.split(":")[-1]) self.sub_socket.subscribe("") # receive all incoming messages self.pub_socket = self.zmq_context.socket(zmq.PUB) self.pub_socket.bind("tcp://*:0") pub_port_string = self.pub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) self.pub_port = int(pub_port_string.split(":")[-1]) port_info = BrokerInfoMsg( "broker", self.pub_port, self.sub_port, "Ports up and running, ready to serve messages", ) self.nexus_socket.send_pyobj(port_info) local_log.info("broker attempting to get message from nexus") msg_available = 0 retries = 3 while (retries > 0) and (msg_available == 0): msg_available = self.nexus_socket.poll(timeout=1000) if msg_available == 0: local_log.info( "broker didn't get a reply from nexus. cycling socket and resending" ) self.nexus_socket.close(linger=0) self.nexus_socket = self.zmq_context.socket(zmq.REQ) self.nexus_socket.connect( f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}" ) self.nexus_socket.send_pyobj(port_info) local_log.info("broker resent message") retries -= 1 self.nexus_socket.recv_pyobj() local_log.info("broker got message back from nexus") return def serve(self, message_process_func): local_log.info("broker serving") while self.running: # this is more testable but may have a performance overhead message_process_func() self.shutdown() def read_and_pub_message(self): try: msg_ready = self.sub_socket.poll(timeout=0) if msg_ready != 0: msg = self.sub_socket.recv_multipart() self.pub_socket.send_multipart(msg) except zmq.error.ZMQError: self.running = False def shutdown(self): for handler in local_log.handlers: handler.close() if self.sub_socket: self.sub_socket.close(linger=0) if self.pub_socket: self.pub_socket.close(linger=0) if self.nexus_socket: self.nexus_socket.close(linger=0) if self.zmq_context: self.zmq_context.destroy(linger=0) def stop(self, signum, frame): local_log.info(f"Broker shutting down due to signal {signum}") self.running = False diff --git a/improv/cli.py b/improv/cli.py index 720d4e2f..b3f6e586 100644 --- a/improv/cli.py +++ b/improv/cli.py @@ -8,15 +8,10 @@ import psutil import time import datetime -from zmq import SocketOption -from zmq.log.handlers import PUBHandler from improv.tui import TUI from improv.nexus import Nexus MAX_PORT = 2**16 - 1 -DEFAULT_CONTROL_PORT = "0" -DEFAULT_OUTPUT_PORT = "0" -DEFAULT_LOGGING_PORT = "0" def file_exists(fname): @@ -80,21 +75,18 @@ "-c", "--control-port", type=is_valid_port,
default=DEFAULT_CONTROL_PORT, help="local port on which control are sent to/from server", ) run_parser.add_argument( "-o", "--output-port", type=is_valid_port, - default=DEFAULT_OUTPUT_PORT, help="local port on which server output messages are broadcast", ) run_parser.add_argument( "-l", "--logging-port", type=is_valid_port, - default=DEFAULT_LOGGING_PORT, help="local port on which logging messages are broadcast", ) run_parser.add_argument( @@ -121,21 +113,18 @@ def parse_cli_args(args): "-c", "--control-port", type=is_valid_ip_addr, - default=DEFAULT_CONTROL_PORT, help="address on which control signals are sent to the server", ) client_parser.add_argument( "-s", "--server-port", type=is_valid_ip_addr, - default=DEFAULT_OUTPUT_PORT, help="address on which messages from the server are received", ) client_parser.add_argument( "-l", "--logging-port", type=is_valid_ip_addr, - default=DEFAULT_LOGGING_PORT, help="address on which logging messages are broadcast", ) client_parser.set_defaults(func=run_client) @@ -147,21 +136,18 @@ def parse_cli_args(args): "-c", "--control-port", type=is_valid_port, - default=DEFAULT_CONTROL_PORT, help="local port on which control signals are received", ) server_parser.add_argument( "-o", "--output-port", type=is_valid_port, - default=DEFAULT_OUTPUT_PORT, help="local port on which output messages are broadcast", ) server_parser.add_argument( "-l", "--logging-port", type=is_valid_port, - default=DEFAULT_LOGGING_PORT, help="local port on which logging messages are broadcast", ) server_parser.add_argument( @@ -212,18 +198,11 @@ def run_server(args): """ Runs the improv server in headless mode. """ - zmq_log_handler = PUBHandler("tcp://*:%s" % args.logging_port) - # in case we bound to a random port (default), get port number - logging_port = int( - zmq_log_handler.socket.getsockopt_string(SocketOption.LAST_ENDPOINT).split(":")[ - -1 - ] - ) logging.basicConfig( level=logging.DEBUG, format="%(name)s %(message)s", - handlers=[logging.FileHandler(args.logfile), zmq_log_handler], + handlers=[logging.FileHandler("improv-debug.log")], ) if not args.actor_path: @@ -232,18 +211,25 @@ def run_server(args): sys.path.extend(args.actor_path) server = Nexus() - control_port, output_port = server.createNexus( + control_port, output_port, log_port = server.create_nexus( file=args.configfile, control_port=args.control_port, output_port=args.output_port, + log_server_pub_port=args.logging_port, + logfile=args.logfile, ) curr_dt = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") print( f"{curr_dt} Server running on (control, output, log) ports " - f"({control_port}, {output_port}, {logging_port}).\n" + f"({control_port}, {output_port}, {log_port}).\n" f"Press Ctrl-C to quit." ) - server.startNexus() + try: + server.start_nexus(server.poll_queues, poll_function=server.poll_kernel) + except Exception as e: + print(f"CLI-started server run encountered uncaught error {e}") + logging.error(f"CLI-started server run encountered uncaught error {e}") + raise e if args.actor_path: for p in args.actor_path: @@ -254,7 +240,7 @@ def run_server(args): def run_list(args, printit=True): out_list = [] - pattern = re.compile(r"(improv (run|client|server)|plasma_store|redis-server)") + pattern = re.compile(r"(improv (run|client|server)|redis-server)") # mp_pattern = re.compile(r"-c from multiprocessing") # TODO is this right? 
    for proc in psutil.process_iter(["pid", "name", "cmdline"]):
        if proc.info["cmdline"]:
@@ -281,22 +267,38 @@ def run_cleanup(args, headless=False):
 
     if res.lower() == "y":
         for proc in proc_list:
-            if not proc.status() == psutil.STATUS_STOPPED:
-                logging.info(
-                    f"process {proc.pid} {proc.name()}"
-                    f" has status {proc.status()}. Interrupting."
-                )
-                try:
+            try:
+                if not proc.status() == psutil.STATUS_STOPPED:
+                    logging.info(
+                        f"process {proc.pid} {proc.name()}"
+                        f" has status {proc.status()}. Interrupting."
+                    )
                     proc.send_signal(signal.SIGINT)
-                except psutil.NoSuchProcess:
-                    pass
+            except psutil.NoSuchProcess:
+                pass
 
         gone, alive = psutil.wait_procs(proc_list, timeout=3)
         for p in alive:
-            p.send_signal(signal.SIGINT)
             try:
+                p.terminate()
                 p.wait(timeout=10)
             except psutil.TimeoutExpired as e:
                 logging.warning(f"{e}: Process did not exit on time.")
+                try:
+                    p.kill()
+                except psutil.NoSuchProcess as e:
+                    logging.warning(
+                        f"{e}: Process exited after wait timeout"
+                        f" but before kill signal attempted."
+                    )
+                    # this happens sometimes because Nexus uses generous
+                    # timeout periods.
+            except psutil.NoSuchProcess as e:
+                logging.warning(
+                    f"{e}: Process exited after wait timeout"
+                    f" but before kill signal attempted."
+                )
+                # this happens sometimes because Nexus uses generous
+                # timeout periods.
 
     else:
         if not headless:
@@ -313,19 +315,28 @@ def run(args, timeout=10):
     server_opts = [
         "improv",
         "server",
-        "-c",
-        str(args.control_port),
-        "-o",
-        str(args.output_port),
-        "-l",
-        str(args.logging_port),
         "-f",
         args.logfile,
     ]
+
+    if args.control_port:
+        server_opts.append("-c")
+        server_opts.append(str(args.control_port))
+
+    if args.output_port:
+        server_opts.append("-o")
+        server_opts.append(str(args.output_port))
+
+    if args.logging_port:
+        server_opts.append("-l")
+        server_opts.append(str(args.logging_port))
+
     server_opts.extend(apath_opts)
     server_opts.append(args.configfile)
 
-    with open(args.logfile, mode="a+") as logfile:
+    print(" ".join(server_opts))
+
+    with open("improv-debug.log", mode="a+") as logfile:
         server = subprocess.Popen(server_opts, stdout=logfile, stderr=logfile)
 
     # wait for server to start up
@@ -338,7 +349,9 @@ def run(args, timeout=10):
     run_client(args)
 
     try:
-        server.wait(timeout=2)
+        wait_timeout = 60
+        print(f"Waiting {wait_timeout} seconds for Nexus to complete shutdown.")
+        server.wait(timeout=wait_timeout)
     except subprocess.TimeoutExpired:
         print("Cleaning up the hard way. May have exited dirty.")
         server.terminate()
@@ -354,9 +367,9 @@ def get_server_ports(args, timeout):
     time_now = 0
     ports = None
     while time_now < timeout:
-        server_start_time = _server_start_logged(args.logfile)
+        server_start_time = _server_start_logged("improv-debug.log")
        if server_start_time and server_start_time >= curr_dt:
-            ports = _get_ports(args.logfile)
+            ports = _get_ports("improv-debug.log")
            if ports:
                break
 
@@ -365,12 +378,12 @@ def get_server_ports(args, timeout):
 
     if not server_start_time:
         print(
-            f"Unable to read server start time from {args.logfile}.\n"
+            "Unable to read server start time from improv-debug.log.\n"
            "This may be because the server could not be started or "
            "did not log its activity."
         )
     elif not ports:
-        print(f"Unable to read ports from {args.logfile}.")
+        print("Unable to read ports from improv-debug.log.")
 
     return ports
 
@@ -393,13 +406,16 @@ def _get_ports(logfile):
     # read logfile to get ports
     with open(logfile, mode="r") as logfile:
         contents = logfile.read()
+        return _read_log_contents_for_ports(contents)
 
-    pattern = re.compile(r"(?<=\(control, output, log\) ports \()\d*, \d*, \d*")
-    # get most recent match (log file may contain old runs)
-    port_str_list = pattern.findall(contents)
-    if port_str_list:
-        port_str = port_str_list[-1]
-        return (int(p) for p in port_str.split(", "))
-    else:
-        return None
+
+def _read_log_contents_for_ports(logfile_contents):
+    pattern = re.compile(r"(?<=\(control, output, log\) ports \()\d*, \d*, \d*")
+
+    # get most recent match (log file may contain old runs)
+    port_str_list = pattern.findall(logfile_contents)
+    if port_str_list:
+        port_str = port_str_list[-1]
+        return (int(p) for p in port_str.split(", "))
+    else:
+        return None
diff --git a/improv/config.py b/improv/config.py
index 5f73c18b..59446fa5 100644
--- a/improv/config.py
+++ b/improv/config.py
@@ -6,46 +6,48 @@
 logger = logging.getLogger(__name__)
 
 
+class CannotCreateConfigException(Exception):
+    def __init__(self, msg):
+        super().__init__("Cannot create config: {}".format(msg))
+
+
 class Config:
     """Handles configuration and logs of configs
     for the entire server/processing pipeline.
     """
 
-    def __init__(self, configFile):
-        if configFile is None:
-            logger.error("Need to specify a config file")
-            raise Exception
-        else:
-            # Reading config from other yaml file
-            self.configFile = configFile
-
-            with open(self.configFile, "r") as ymlfile:
-                cfg = yaml.safe_load(ymlfile)
-
-            try:
-                if "settings" in cfg:
-                    self.settings = cfg["settings"]
-                else:
-                    self.settings = {}
+    def __init__(self, config_file):
+        self.actors = {}
+        self.connections = {}
+        self.hasGUI = False
+        self.config_file = config_file
 
-                if "use_watcher" not in self.settings:
-                    self.settings["use_watcher"] = False
+        with open(self.config_file, "r") as ymlfile:
+            self.config = yaml.safe_load(ymlfile)
 
-            except TypeError:
-                if cfg is None:
-                    logger.error("Error: The config file is empty")
+        if self.config is None:
+            logger.error("The config file is empty")
+            raise CannotCreateConfigException("The config file is empty")
 
-            if type(cfg) is not dict:
+        if type(self.config) is not dict:
             logger.error("Error: The config file is not in dictionary format")
             raise TypeError
 
-        self.config = cfg
+    def parse_config(self):
+        self.populate_defaults()
+        self.validate_config()
 
-        self.actors = {}
-        self.connections = {}
-        self.hasGUI = False
+        self.settings = self.config["settings"]
+        self.redis_config = self.config["redis_config"]
+
+    def populate_defaults(self):
+        self.populate_settings_defaults()
+        self.populate_redis_defaults()
+
+    def validate_config(self):
+        self.validate_redis_config()
 
-    def createConfig(self):
+    def create_config(self):
         """Read yaml config file and create config for Nexus
 
         TODO: check for config file compliance, error handle it
         beyond what we have below.
@@ -53,9 +55,6 @@ def createConfig(self): cfg = self.config for name, actor in cfg["actors"].items(): - if name in self.actors.keys(): - raise RepeatedActorError(name) - packagename = actor.pop("package") classname = actor.pop("class") @@ -63,11 +62,15 @@ def createConfig(self): __import__(packagename, fromlist=[classname]) mod = import_module(packagename) - clss = getattr(mod, classname) - sig = signature(clss) - configModule = ConfigModule(name, packagename, classname, options=actor) - sig.bind(configModule.options) + actor_class = getattr(mod, classname) + sig = signature(actor_class) + config_module = ConfigModule( + name, packagename, classname, options=actor + ) + sig.bind(config_module.options) + # TODO: this is not trivial to test, since our code formatting + # tools won't allow a file with a syntax error to exist except SyntaxError as e: logger.error(f"Error: syntax error when initializing actor {name}: {e}") return -1 @@ -87,101 +90,123 @@ def createConfig(self): except TypeError: logger.error("Error: Invalid arguments passed") params = "" - for parameter in sig.parameters: - params = params + " " + parameter.name + for param_name, param in sig.parameters.items(): + params = params + ", " + param.name logger.warning("Expected Parameters:" + params) return -1 - except Exception as e: + except Exception as e: # TODO: figure out how to test this logger.error(f"Error: {e}") return -1 if "GUI" in name: logger.info(f"Config detected a GUI actor: {name}") self.hasGUI = True - self.gui = configModule + self.gui = config_module else: - self.actors.update({name: configModule}) + self.actors.update({name: config_module}) for name, conn in cfg["connections"].items(): - if name in self.connections.keys(): - raise RepeatedConnectionsError(name) - self.connections.update({name: conn}) - if "datastore" in cfg.keys(): - self.datastore = cfg["datastore"] - return 0 - def addParams(self, type, param): - """Function to add paramter param of type type - TODO: Future work - """ - pass - - def saveActors(self): + def save_actors(self): """Saves the actors config to a specific file.""" wflag = True - saveFile = self.configFile.split(".")[0] + saveFile = self.config_file.split(".")[0] pathName = saveFile + "_actors.yaml" for a in self.actors.values(): - wflag = a.saveConfigModules(pathName, wflag) - - def use_plasma(self): - return "plasma_config" in self.config.keys() - - def get_redis_port(self): - if self.redis_port_specified(): - return self.config["redis_config"]["port"] - else: - return Config.get_default_redis_port() - - def redis_port_specified(self): - if "redis_config" in self.config.keys(): - return "port" in self.config["redis_config"] - return False - - def redis_saving_enabled(self): - if "redis_config" in self.config.keys(): - return ( - self.config["redis_config"]["enable_saving"] - if "enable_saving" in self.config["redis_config"] - else None + wflag = a.save_config_modules(pathName, wflag) + + def populate_settings_defaults(self): + if "settings" not in self.config: + self.config["settings"] = {} + + if "store_size" not in self.config["settings"]: + self.config["settings"]["store_size"] = 250_000_000 + if "control_port" not in self.config["settings"]: + self.config["settings"]["control_port"] = 5555 + if "output_port" not in self.config["settings"]: + self.config["settings"]["output_port"] = 5556 + if "actor_in_port" not in self.config["settings"]: + self.config["settings"]["actor_in_port"] = 0 + if "harvest_data_from_memory" not in self.config["settings"]: + 
self.config["settings"]["harvest_data_from_memory"] = None + + def popoulate_redis_defaults(self): + if "redis_config" not in self.config: + self.config["redis_config"] = {} + + if "enable_saving" not in self.config["redis_config"]: + self.config["redis_config"]["enable_saving"] = None + if "aof_dirname" not in self.config["redis_config"]: + self.config["redis_config"]["aof_dirname"] = None + if "generate_ephemeral_aof_dirname" not in self.config["redis_config"]: + self.config["redis_config"]["generate_ephemeral_aof_dirname"] = False + if "fsync_frequency" not in self.config["redis_config"]: + self.config["redis_config"]["fsync_frequency"] = None + + # enable saving automatically if the user configured a saving option + if ( + self.config["redis_config"]["aof_dirname"] + or self.config["redis_config"]["generate_ephemeral_aof_dirname"] + or self.config["redis_config"]["fsync_frequency"] + ) and self.config["redis_config"]["enable_saving"] is None: + self.config["redis_config"]["enable_saving"] = True + + if "port" not in self.config["redis_config"]: + self.config["redis_config"]["port"] = 6379 + + def validate_redis_config(self): + fsync_name_dict = { + "every_write": "always", + "every_second": "everysec", + "no_schedule": "no", + } + if ( + self.config["redis_config"]["aof_dirname"] + and self.config["redis_config"]["generate_ephemeral_aof_dirname"] + ): + logger.error( + "Cannot both generate a unique dirname and use the one provided." ) - - def generate_ephemeral_aof_dirname(self): - if "redis_config" in self.config.keys(): - return ( - self.config["redis_config"]["generate_ephemeral_aof_dirname"] - if "generate_ephemeral_aof_dirname" in self.config["redis_config"] - else None - ) - return False - - def get_redis_aof_dirname(self): - if "redis_config" in self.config.keys(): - return ( - self.config["redis_config"]["aof_dirname"] - if "aof_dirname" in self.config["redis_config"] - else None + raise Exception("Cannot use unique dirname and use the one provided.") + + if ( + self.config["redis_config"]["aof_dirname"] + or self.config["redis_config"]["generate_ephemeral_aof_dirname"] + or self.config["redis_config"]["fsync_frequency"] + ): + if not self.config["redis_config"]["enable_saving"]: + logger.error( + "Invalid configuration. Cannot save to disk with saving disabled." 
+ ) + raise Exception("Cannot persist to disk with saving disabled.") + + if self.config["redis_config"]["fsync_frequency"] and self.config[ + "redis_config" + ]["fsync_frequency"] not in [ + "every_write", + "every_second", + "no_schedule", + ]: + logger.error( + f"Cannot use unknown fsync frequency " + f'{self.config["redis_config"]["fsync_frequency"]}' ) - return None - - def get_redis_fsync_frequency(self): - if "redis_config" in self.config.keys(): - frequency = ( - self.config["redis_config"]["fsync_frequency"] - if "fsync_frequency" in self.config["redis_config"] - else None + raise Exception( + f"Cannot use unknown fsync frequency " + f'{self.config["redis_config"]["fsync_frequency"]}' ) - return frequency + if self.config["redis_config"]["fsync_frequency"] is None: + self.config["redis_config"]["fsync_frequency"] = "no_schedule" - @staticmethod - def get_default_redis_port(): - return "6379" + self.config["redis_config"]["fsync_frequency"] = (fsync_name_dict)[ + self.config["redis_config"]["fsync_frequency"] + ] class ConfigModule: @@ -191,11 +216,11 @@ def __init__(self, name, packagename, classname, options=None): self.classname = classname self.options = options - def saveConfigModules(self, pathName, wflag): + def save_config_modules(self, path_name, wflag): """Loops through each actor to save the modules to the config file. Args: - pathName: + path_name: wflag (bool): Returns: @@ -213,32 +238,7 @@ def saveConfigModules(self, pathName, wflag): for key, value in self.options.items(): cfg[self.name].update({key: value}) - with open(pathName, writeOption) as file: + with open(path_name, writeOption) as file: yaml.dump(cfg, file) return wflag - - -class RepeatedActorError(Exception): - def __init__(self, repeat): - super().__init__() - - self.name = "RepeatedActorError" - self.repeat = repeat - - self.message = 'Actor name has already been used: "{}"'.format(repeat) - - def __str__(self): - return self.message - - -class RepeatedConnectionsError(Exception): - def __init__(self, repeat): - super().__init__() - self.name = "RepeatedConnectionsError" - self.repeat = repeat - - self.message = 'Connection name has already been used: "{}"'.format(repeat) - - def __str__(self): - return self.message diff --git a/improv/harvester.py b/improv/harvester.py new file mode 100644 index 00000000..ba66bdf6 --- /dev/null +++ b/improv/harvester.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import logging +import signal +import time + +import zmq + +from improv.link import ZmqLink +from improv.log import ZmqLogHandler +from improv.store import RedisStoreInterface +from zmq import SocketOption + +from improv.messaging import HarvesterInfoMsg + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def bootstrap_harvester( + nexus_hostname, + nexus_port, + redis_hostname, + redis_port, + broker_hostname, + broker_port, + logger_hostname, + logger_port, +): + harvester = RedisHarvester( + nexus_hostname, + nexus_port, + redis_hostname, + redis_port, + broker_hostname, + broker_port, + logger_hostname, + logger_port, + ) + harvester.establish_connections() + harvester.register_with_nexus() + harvester.serve(harvester.collect) + + +class RedisHarvester: + def __init__( + self, + nexus_hostname, + nexus_comm_port, + redis_hostname, + redis_port, + broker_hostname, + broker_port, + logger_hostname, + logger_port, + ): + self.link: ZmqLink | None = None + self.running = True + self.nexus_hostname: str = nexus_hostname + self.nexus_comm_port: int = nexus_comm_port + 
self.redis_hostname: str = redis_hostname + self.redis_port: int = redis_port + self.broker_hostname: str = broker_hostname + self.broker_port: int = broker_port + self.zmq_context: zmq.Context | None = None + self.nexus_socket: zmq.Socket | None = None + self.sub_port: int | None = None + self.sub_socket: zmq.Socket | None = None + self.store_client: RedisStoreInterface | None = None + self.logger_hostname: str = logger_hostname + self.logger_port: int = logger_port + + logger.addHandler(ZmqLogHandler(logger_hostname, logger_port)) + + signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) + for s in signals: + signal.signal(s, self.stop) + + def establish_connections(self): + logger.info("Registering with Nexus") + # connect to nexus + self.zmq_context = zmq.Context() + self.zmq_context.setsockopt(SocketOption.LINGER, 0) + self.nexus_socket = self.zmq_context.socket(zmq.REQ) + self.nexus_socket.connect(f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}") + + self.sub_socket = self.zmq_context.socket(zmq.SUB) + self.sub_socket.connect(f"tcp://{self.broker_hostname}:{self.broker_port}") + sub_port_string = self.sub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) + self.sub_port = int(sub_port_string.split(":")[-1]) + self.sub_socket.subscribe("") # receive all incoming messages + + self.store_client = RedisStoreInterface( + "harvester", self.redis_port, self.redis_hostname + ) + + self.link = ZmqLink(self.sub_socket, "harvester", "") + + # TODO: this is a small function that is currently easy to verify + # it could be unit-tested with a multiprocess that spins up a socket + # which this can communicate with, but isn't currenlty worth the time + def register_with_nexus(self): + port_info = HarvesterInfoMsg( + "harvester", + "Ports up and running, ready to serve messages", + ) + + self.nexus_socket.send_pyobj(port_info) + self.nexus_socket.recv_pyobj() + + return + + def serve(self, message_process_func, *args, **kwargs): + logger.info("Harvester beginning harvest") + while self.running: + message_process_func(*args, **kwargs) + self.shutdown() + + def collect(self, *args, **kwargs): + db_info = self.store_client.client.info() + max_memory = db_info["maxmemory"] + used_memory = db_info["used_memory"] + used_max_ratio = used_memory / max_memory + if used_max_ratio > 0.75: + while self.running and (used_max_ratio > 0.50): + try: + key = self.link.get(timeout=100) # 100ms + self.store_client.client.delete(key) + except TimeoutError: + pass + db_info = self.store_client.client.info() + max_memory = db_info["maxmemory"] + used_memory = db_info["used_memory"] + used_max_ratio = used_memory / max_memory + time.sleep(0.1) + return + + def shutdown(self): + for handler in logger.handlers: + if isinstance(handler, ZmqLogHandler): + handler.close() + logger.removeHandler(handler) + + if self.sub_socket: + self.sub_socket.close(linger=0) + + if self.nexus_socket: + self.nexus_socket.close(linger=0) + + if self.zmq_context: + self.zmq_context.destroy(linger=0) + + def stop(self, signum, frame): + self.running = False + logger.info(f"Harvester shutting down due to signal {signum}") diff --git a/improv/link.py b/improv/link.py index 5014204b..81fee4e5 100644 --- a/improv/link.py +++ b/improv/link.py @@ -1,91 +1,27 @@ import asyncio +import json import logging -from multiprocessing import Manager, cpu_count +from multiprocessing import cpu_count from concurrent.futures import ThreadPoolExecutor -from concurrent.futures._base import CancelledError + +import zmq +from zmq import SocketOption logger = 
logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -def Link(name, start, end): - """Function to construct a queue that Nexus uses for - inter-process (actor) signaling and information passing. - - A Link has an internal queue that can be synchronous (put, get) - as inherited from multiprocessing.Manager.Queue - or asynchronous (put_async, get_async) using async executors. - - Args: - See AsyncQueue constructor - - Returns: - AsyncQueue: queue for communicating between actors and with Nexus - """ - - m = Manager() - q = AsyncQueue(m.Queue(maxsize=0), name, start, end) - return q - - -class AsyncQueue(object): - """Single-output and asynchronous queue class. - - Attributes: - queue: - real_executor: - cancelled_join: boolean - name: - start: - end: - status: - result: - num: - dict: - """ - - def __init__(self, q, name, start, end): - """Constructor for the queue class. - - Args: - q (Queue): A queue from the Manager class - name (str): String description of this queue - start (str): The producer (input) actor name for the queue - end (str): The consumer (output) actor name for the queue - - """ - self.queue = q - self.real_executor = None - self.cancelled_join = False - +class ZmqLink: + def __init__(self, socket: zmq.Socket, name, topic=None): self.name = name - self.start = start - self.end = end - + self.real_executor = None + self.socket = socket # this must already be set up + self.socket_type = self.socket.getsockopt(SocketOption.TYPE) + if self.socket_type == zmq.PUB and topic is None: + raise Exception("Cannot open PUB link without topic") self.status = "pending" self.result = None - - def getStart(self): - """Gets the starting actor. - - The starting actor is the actor that is at the tail of the link. - This actor is the one that gives output. - - Returns: - start (str): The starting actor name - """ - return self.start - - def getEnd(self): - """Gets the ending actor. - - The ending actor is the actor that is at the head of the link. - This actor is the one that takes input. - - Returns: - end (str): The ending actor name - """ - return self.end + self.topic = topic @property def _executor(self): @@ -93,195 +29,57 @@ def _executor(self): self.real_executor = ThreadPoolExecutor(max_workers=cpu_count()) return self.real_executor - def __getstate__(self): - """Gets a dictionary of attributes. - - This function gets a dictionary, with keys being the names of - the attributes, and values being the values of the attributes. + def get(self, timeout=None): + if timeout is None: + if self.socket_type == zmq.SUB: + res = self.socket.recv_multipart() + return json.loads(res[1].decode("utf-8")) - Returns: - self_dict (dict): A dictionary containing attributes. - """ - self_dict = self.__dict__ - self_dict["_real_executor"] = None - return self_dict + return self.socket.recv_pyobj() - def __getattr__(self, name): - """Gets the attribute specified by "name". + msg_ready = self.socket.poll(timeout=timeout) + if msg_ready == 0: + raise TimeoutError - Args: - name (str): Name of the attribute to be returned. + if self.socket_type == zmq.SUB: + res = self.socket.recv_multipart() + return json.loads(res[1]) - Raises: - AttributeError: Restricts the available attributes to a - specific list. This error is raised if a different attribute - of the queue is requested. + return self.socket.recv_pyobj() - TODO: - Don't raise this? + def get_nowait(self): + return self.get(timeout=0) - Returns: - (object): Value of the attribute specified by "name". 
- """ - if name in ["qsize", "empty", "full", "get", "get_nowait", "close"]: - return getattr(self.queue, name) + def put( + self, item + ): # TODO: is it a problem we're implicitly handling inputs differently here? + if self.socket_type == zmq.PUB: + self.socket.send_multipart( + [self.topic.encode("utf-8"), json.dumps(item).encode("utf-8")] + ) else: - cn = self.__class__.__name__ - raise AttributeError("{} object has no attribute {}".format(cn, name)) - - def __repr__(self): - """String representation for Link. - - Returns: - (str): "Link" followed by the name given in the constructor. - """ - return "Link " + self.name - - def put(self, item): - """Function wrapper for put. - - Args: - item (object): Any item that can be sent through a queue - """ - self.queue.put(item) - - def put_nowait(self, item): - """Function wrapper for put without waiting - - Args: - item (object): Any item that can be sent through a queue - """ - self.queue.put_nowait(item) + self.socket.send_pyobj(item) async def put_async(self, item): - """Coroutine for an asynchronous put - - It adds the put request to the event loop and awaits. - - Args: - item (object): Any item that can be sent through a queue - - Returns: - Awaitable or result of the put - """ loop = asyncio.get_event_loop() - try: - res = await loop.run_in_executor(self._executor, self.put, item) - return res - except EOFError: - logger.warn("Link probably killed (EOF)") - except FileNotFoundError: - logger.warn("probably killed (file not found)") + res = await loop.run_in_executor(self._executor, self.put, item) + return res async def get_async(self): - """Coroutine for an asynchronous get - - It adds the get request to the event loop and awaits, setting - the status to pending. Once the get has returned, it returns the - result of the get and sets its status as done. - - Explicitly passes any exceptions to not hinder execution. - Errors are logged with the get_async tag. - - Returns: - Awaitable or result of the get. - - Raises: - CancelledError: task is cancelled - EOFError: - FileNotFoundError: - Exception: - """ loop = asyncio.get_event_loop() self.status = "pending" - try: - self.result = await loop.run_in_executor(self._executor, self.get) - self.status = "done" - return self.result - except CancelledError: - logger.info("Task {} Canceled".format(self.name)) - except EOFError: - logger.info("probably killed") - except FileNotFoundError: - logger.info("probably killed") - except Exception as e: - logger.exception("Error in get_async: {}".format(e)) - - def cancel_join_thread(self): - """Function wrapper for cancel_join_thread.""" - self._cancelled_join = True - self._queue.cancel_join_thread() - - def join_thread(self): - """Function wrapper for join_thread.""" - self._queue.join_thread() - if self._real_executor and not self._cancelled_join: - self._real_executor.shutdown() - - -def MultiLink(name, start, end): - """Function to generate links for the multi-output queue case. - - Args: - See constructor for AsyncQueue or MultiAsyncQueue - - Returns: - MultiAsyncQueue: Producer end of the queue - List: AsyncQueues for consumers - """ - m = Manager() - - q_out = [] - for endpoint in end: - q = AsyncQueue(m.Queue(maxsize=0), name, start, endpoint) - q_out.append(q) - - q = MultiAsyncQueue(m.Queue(maxsize=0), q_out, name, start, end) - - return q, q_out - - -class MultiAsyncQueue(AsyncQueue): - """Extension of AsyncQueue to have multiple endpoints. - - Inherits from AsyncQueue. 
-    A single producer queue's 'put' is copied to multiple consumer's
-    queues, q_in is the producer queue, q_out are the consumer queues.
-
-    TODO:
-        Test the async nature of this group of queues
-    """
-
-    def __init__(self, q_in, q_out, name, start, end):
-        self.queue = q_in
-        self.output = q_out
-
-        self.real_executor = None
-        self.cancelled_join = False
-
-        self.name = name
-        self.start = start
-        self.end = end[0]
-        self.status = "pending"
-        self.result = None
-
-    def __repr__(self):
-        return "MultiLink " + self.name
-
-    def __getattr__(self, name):
-        # Remove put and put_nowait and define behavior specifically
-        # TODO: remove get capability?
-        if name in ["qsize", "empty", "full", "get", "get_nowait", "close"]:
-            return getattr(self.queue, name)
-        else:
-            raise AttributeError(
-                "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)
-            )
-
-    def put(self, item):
-        for q in self.output:
-            q.put(item)
-
-    def put_nowait(self, item):
-        for q in self.output:
-            q.put_nowait(item)
+        # try:
+        self.result = await loop.run_in_executor(self._executor, self.get)
+        self.status = "done"
+        return self.result
+        # TODO: explicitly commenting these out because testing them is hard.
+        # It's better to let them bubble up so that we don't miss them
+        # due to being caught without a clear reproducible mechanism
+        # except CancelledError:
+        #     logger.info("Task {} Canceled".format(self.name))
+        # except EOFError:
+        #     logger.info("probably killed")
+        # except FileNotFoundError:
+        #     logger.info("probably killed")
+        # except Exception as e:
+        #     logger.exception("Error in get_async: {}".format(e))
diff --git a/improv/log.py b/improv/log.py
new file mode 100644
index 00000000..70373d57
--- /dev/null
+++ b/improv/log.py
@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+import logging
+import signal
+import time
+from logging import handlers
+from logging.handlers import QueueHandler
+
+import zmq
+from zmq import SocketOption
+from zmq.log.handlers import PUBHandler
+
+from improv.messaging import LogInfoMsg
+
+local_log = logging.getLogger(__name__)
+
+DEBUG = True
+
+# TODO: ideally there should be some kind of drain at shutdown time
+# so we don't miss any log messages, but that would make shutdown
+# also take longer. TBD?
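+
+# Overview of the flow implemented below: each improv process attaches a
+# ZmqLogHandler, which PUSHes serialized LogRecords to the LogServer's
+# ZmqPullListener (a PULL socket); the listener thread writes each record
+# to the log file and republishes it on a PUB socket for live subscribers
+# such as the TUI.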
+
+
+def bootstrap_log_server(
+    nexus_hostname, nexus_port, log_filename="global.log", logger_pull_port=None
+):
+    if DEBUG:
+        local_log.addHandler(logging.FileHandler("log_server.log"))
+    try:
+        log_server = LogServer(
+            nexus_hostname, nexus_port, log_filename, logger_pull_port
+        )
+        log_server.register_with_nexus()
+        log_server.serve(log_server.read_and_log_message)
+    except Exception as e:
+        local_log.error(e)
+        for handler in local_log.handlers:
+            handler.close()
+
+
+class ZmqPullListener(handlers.QueueListener):
+    def __init__(self, ctx, /, *handlers, **kwargs):
+        self.sentinel = False
+        self.ctx = ctx
+        self.pull_socket = self.ctx.socket(zmq.PULL)
+        self.pull_socket.bind("tcp://*:0")
+        pull_port_string = self.pull_socket.getsockopt_string(
+            SocketOption.LAST_ENDPOINT
+        )
+        self.pull_port = int(pull_port_string.split(":")[-1])
+        super().__init__(self.pull_socket, *handlers, **kwargs)
+
+    def dequeue(self, block=True):
+        msg = None
+        while msg is None:
+            if self.sentinel:
+                return handlers.QueueListener._sentinel
+            msg_ready = self.queue.poll(timeout=1000)
+            if msg_ready != 0:
+                msg = self.queue.recv_json()
+        return logging.makeLogRecord(msg)
+
+    def enqueue_sentinel(self):
+        self.sentinel = True
+
+
+class ZmqLogHandler(QueueHandler):
+    def __init__(self, hostname, port, ctx=None):
+        self.ctx = ctx if ctx else zmq.Context()
+        self.ctx.setsockopt(SocketOption.LINGER, 0)
+        self.socket = self.ctx.socket(zmq.PUSH)
+        self.socket.connect(f"tcp://{hostname}:{port}")
+        super().__init__(self.socket)
+
+    def enqueue(self, record):
+        self.queue.send_json(record.__dict__)
+
+    def close(self):
+        self.queue.close(linger=0)
+
+
+class LogServer:
+    def __init__(self, nexus_hostname, nexus_comm_port, log_filename, pub_port):
+        self.running = True
+        self.pub_port: int | None = pub_port if pub_port else 0
+        self.pub_socket: zmq.Socket | None = None
+        self.log_filename = log_filename
+        self.nexus_hostname: str = nexus_hostname
+        self.nexus_comm_port: int = nexus_comm_port
+        self.zmq_context: zmq.Context | None = None
+        self.nexus_socket: zmq.Socket | None = None
+        self.pull_socket: zmq.Socket | None = None
+        self.listener: ZmqPullListener | None = None
+
+        signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+        for s in signals:
+            signal.signal(s, self.stop)
+
+    def register_with_nexus(self):
+        # connect to nexus
+        self.zmq_context = zmq.Context()
+        self.zmq_context.setsockopt(SocketOption.LINGER, 0)
+        self.nexus_socket = self.zmq_context.socket(zmq.REQ)
+        self.nexus_socket.connect(f"tcp://{self.nexus_hostname}:{self.nexus_comm_port}")
+
+        self.pub_socket = self.zmq_context.socket(zmq.PUB)
+        self.pub_socket.bind(f"tcp://*:{self.pub_port}")
+        pub_port_string = self.pub_socket.getsockopt_string(SocketOption.LAST_ENDPOINT)
+        self.pub_port = int(pub_port_string.split(":")[-1])
+
+        self.listener = ZmqPullListener(
+            self.zmq_context,
+            # logging.StreamHandler(sys.stdout),
+            logging.FileHandler(self.log_filename),
+            PUBHandler(self.pub_socket, self.zmq_context, "nexus_logging"),
+        )
+
+        self.listener.start()
+
+        local_log.info("logger started listening")
+
+        port_info = LogInfoMsg(
+            "log_server",
+            self.listener.pull_port,
+            self.pub_port,
+            "Port up and running, ready to log messages",
+        )
+
+        self.nexus_socket.send_pyobj(port_info)
+        local_log.info("logger sent message to nexus")
+        self.nexus_socket.recv_pyobj()
+
+        local_log.info("logger got message from nexus")
+
+        return
+
+    def serve(self, log_func):
+        local_log.info("logger serving")
+        while self.running:
+            log_func()  # this is more testable but may have a performance overhead
+        self.shutdown()
+
+    def read_and_log_message(self):
+        # the ZmqPullListener thread does the actual receiving, filing, and
+        # republishing; sleep briefly so this supervision loop does not
+        # busy-spin while waiting for a shutdown signal
+        time.sleep(0.1)
+
+    def shutdown(self):
+        if self.listener:
+            self.listener.stop()
+
+            for handler in self.listener.handlers:
+                try:
+                    handler.close()
+                except Exception as e:
+                    local_log.error(e)
+
+        if self.pull_socket:
+            self.pull_socket.close(linger=0)
+
+        if self.nexus_socket:
+            self.nexus_socket.close(linger=0)
+
+        if self.pub_socket:
+            self.pub_socket.close(linger=0)
+
+        for handler in local_log.handlers:
+            handler.close()
+
+        if self.zmq_context:
+            self.zmq_context.destroy(linger=0)
+
+    def stop(self, signum, frame):
+        local_log.info(f"Log server shutting down due to signal {signum}")
+
+        self.running = False
diff --git a/improv/messaging.py b/improv/messaging.py
new file mode 100644
index 00000000..29b0461f
--- /dev/null
+++ b/improv/messaging.py
@@ -0,0 +1,71 @@
+class ActorStateMsg:
+    def __init__(self, actor_name, status, nexus_in_port, info):
+        self.actor_name = actor_name
+        self.status = status
+        self.nexus_in_port = nexus_in_port
+        self.info = info
+
+
+class ActorStateReplyMsg:
+    def __init__(self, actor_name, status, info):
+        self.actor_name = actor_name
+        self.status = status
+        self.info = info
+
+
+class ActorSignalMsg:
+    def __init__(self, actor_name, signal, info):
+        self.actor_name = actor_name
+        self.signal = signal
+        self.info = info
+
+
+class ActorSignalReplyMsg:
+    def __init__(self, actor_name, signal, status, info):
+        self.actor_name = actor_name
+        self.signal = signal
+        self.status = status
+        self.info = info
+
+
+class BrokerInfoMsg:
+    def __init__(self, name, pub_port, sub_port, info):
+        self.name = name
+        self.pub_port = pub_port
+        self.sub_port = sub_port
+        self.info = info
+
+
+class BrokerInfoReplyMsg:
+    def __init__(self, name, status, info):
+        self.name = name
+        self.status = status
+        self.info = info
+
+
+class LogInfoMsg:
+    def __init__(self, name, pull_port, pub_port, info):
+        self.name = name
+        self.pull_port = pull_port
+        self.pub_port = pub_port
+        self.info = info
+
+
+class LogInfoReplyMsg:
+    def __init__(self, name, status, info):
+        self.name = name
+        self.status = status
+        self.info = info
+
+
+class HarvesterInfoMsg:
+    def __init__(self, name, info):
+        self.name = name
+        self.info = info
+
+
+class HarvesterInfoReplyMsg:
+    def __init__(self, name, status, info):
+        self.name = name
+        self.status = status
+        self.info = info
diff --git a/improv/nexus.py b/improv/nexus.py
index b5fe19c4..8089f666 100644
--- a/improv/nexus.py
+++ b/improv/nexus.py
@@ -1,4 +1,6 @@
-import os
+from __future__ import annotations
+
+import multiprocessing
 import time
 import uuid
 import signal
@@ -7,46 +9,118 @@
 import concurrent
 import subprocess
 
-from queue import Full
 from datetime import datetime
-from multiprocessing import Process, get_context
+from multiprocessing import get_context
 from importlib import import_module
 
+import zmq as zmq_sync
 import zmq.asyncio as zmq
-from zmq import PUB, REP, SocketOption
-
-from improv.store import StoreInterface, RedisStoreInterface, PlasmaStoreInterface
-from improv.actor import Signal
+from zmq import PUB, REP, REQ, SocketOption
+
+from improv import log
+from improv.broker import bootstrap_broker
+from improv.harvester import bootstrap_harvester
+from improv.log import bootstrap_log_server
+from improv.messaging import (
+    ActorStateMsg,
+    ActorStateReplyMsg,
+    ActorSignalMsg,
+    BrokerInfoReplyMsg,
+    BrokerInfoMsg,
+    LogInfoMsg,
+    LogInfoReplyMsg,
+    HarvesterInfoMsg,
+    HarvesterInfoReplyMsg,
+)
+from improv.store import 
StoreInterface, RedisStoreInterface
+from improv.actor import Signal, Actor, LinkInfo
 from improv.config import Config
-from improv.link import Link, MultiLink
+
+ASYNC_DEBUG = False
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)
+if logger.level == logging.DEBUG:
+    logger.addHandler(logging.StreamHandler())
+
+# TODO: redo docstrings since things are pretty different now
+
+# TODO: socket setup can fail - need to check it
+
+# TODO: redo how actors register with nexus so that we get actor states
+# earlier
+
+
+class ConfigFileNotProvidedException(Exception):
+    def __init__(self):
+        super().__init__("Config file not provided")
 
-# TODO: Set up store.notify in async function (?)
+
+class ConfigFileNotValidException(Exception):
+    def __init__(self):
+        super().__init__("Config file not valid")
+
+
+class ActorState:
+    def __init__(
+        self, actor_name, status, nexus_in_port, hostname="localhost", sig_socket=None
+    ):
+        self.actor_name = actor_name
+        self.status = status
+        self.nexus_in_port = nexus_in_port
+        self.hostname = hostname
+        self.sig_socket = sig_socket
 
 
 class Nexus:
     """Main server class for handling objects in improv"""
 
     def __init__(self, name="Server"):
+        self.logger_in_port: int | None = None
+        self.zmq_sync_context: zmq_sync.Context | None = None
+        self.logfile: str | None = None
+        self.p_harvester: multiprocessing.Process | None = None
+        self.p_broker: multiprocessing.Process | None = None
+        self.actor_in_socket_port: int | None = None
+        self.actor_in_socket: zmq.Socket | None = None
+        self.in_socket: zmq.Socket | None = None
+        self.out_socket: zmq.Socket | None = None
+        self.zmq_context: zmq.Context | None = None
+        self.logger_pub_port: int | None = None
+        self.logger_pull_port: int | None = None
+        self.logger_in_socket: zmq.Socket | None = None
+        self.p_logger: multiprocessing.Process | None = None
+        self.broker_pub_port = None
+        self.broker_sub_port = None
+        self.broker_in_port: int | None = None
+        self.broker_in_socket: zmq_sync.Socket | None = None
+        self.actor_states: dict[str, ActorState | None] = dict()
         self.redis_fsync_frequency = None
         self.store = None
         self.config = None
         self.name = name
         self.aof_dir = None
         self.redis_saving_enabled = False
+        self.allow_setup = False
+        self.outgoing_topics = dict()
+        self.incoming_topics = dict()
+        self.data_queues = {}
+        self.actors = {}
+        self.flags = {}
+        self.processes: list[multiprocessing.Process] = []
 
     def __str__(self):
         return self.name
 
-    def createNexus(
+    def create_nexus(
         self,
         file=None,
-        use_watcher=None,
-        store_size=10_000_000,
-        control_port=0,
-        output_port=0,
+        store_size=None,
+        control_port=None,
+        output_port=None,
+        log_server_pub_port=None,
+        actor_in_port=None,
+        logfile="global.log",
     ):
         """Function to initialize class variables based on config file.
 
@@ -56,101 +130,68 @@ def createNexus(
 
         Args:
             file (string): Name of the config file.
-            use_watcher (bool): Whether to use watcher for the store.
             store_size (int): initial store size
             control_port (int): port number for input socket
             output_port (int): port number for output socket
+            actor_in_port (int): port number for the socket which receives
+                actor communications
 
         Returns:
-            string: "Shutting down", to notify start() that pollQueues has
-            completed.
+            tuple: the (control_port, output_port, logging_port) that the
+            server actually bound.
""" + self.logfile = logfile + curr_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") logger.info(f"************ new improv server session {curr_dt} ************") if file is None: logger.exception("Need a config file!") - raise Exception # TODO - else: - logger.info(f"Loading configuration file {file}:") - self.loadConfig(file=file) - with open(file, "r") as f: # write config file to log - logger.info(f.read()) - - # set config options loaded from file - # in Python 3.9, can just merge dictionaries using precedence - cfg = self.config.settings - if "use_watcher" not in cfg: - cfg["use_watcher"] = use_watcher - if "store_size" not in cfg: - cfg["store_size"] = store_size - if "control_port" not in cfg or control_port != 0: - cfg["control_port"] = control_port - if "output_port" not in cfg or output_port != 0: - cfg["output_port"] = output_port - - # set up socket in lieu of printing to stdout - self.zmq_context = zmq.Context() - self.zmq_context.setsockopt(SocketOption.LINGER, 1) - self.out_socket = self.zmq_context.socket(PUB) - self.out_socket.bind("tcp://*:%s" % cfg["output_port"]) - out_port_string = self.out_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) - cfg["output_port"] = int(out_port_string.split(":")[-1]) - - self.in_socket = self.zmq_context.socket(REP) - self.in_socket.bind("tcp://*:%s" % cfg["control_port"]) - in_port_string = self.in_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) - cfg["control_port"] = int(in_port_string.split(":")[-1]) - - self.configure_redis_persistence() + raise ConfigFileNotProvidedException - # default size should be system-dependent - if self.config and self.config.use_plasma(): - self._startStoreInterface(store_size) - else: - self._startStoreInterface(store_size) - logger.info("Redis server started") - - self.out_socket.send_string("StoreInterface started") + logger.info(f"Loading configuration file {file}:") + self.config = Config(config_file=file) + self.config.parse_config() - # connect to store and subscribe to notifications - logger.info("Create new store object") - if self.config and self.config.use_plasma(): - self.store = PlasmaStoreInterface(store_loc=self.store_loc) - else: - self.store = StoreInterface(server_port_num=self.store_port) - logger.info(f"Redis server connected on port {self.store_port}") + with open(file, "r") as f: # write config file to log + logger.info(f.read()) - self.store.subscribe() + logger.debug("Applying CLI parameter configuration overrides") + self.apply_cli_config_overrides( + store_size=store_size, + control_port=control_port, + output_port=output_port, + actor_in_port=actor_in_port, + ) - # TODO: Better logic/flow for using watcher as an option - self.p_watch = None - if cfg["use_watcher"]: - self.startWatcher() + logger.debug("Setting up sockets") + self.set_up_sockets(actor_in_port=self.config.settings["actor_in_port"]) - # Create dicts for reading config and creating actors - self.comm_queues = {} - self.sig_queues = {} - self.data_queues = {} - self.actors = {} - self.flags = {} - self.processes = [] + logger.debug("Setting up services") + self.start_improv_services( + log_server_pub_port=log_server_pub_port, + store_size=self.config.settings["store_size"], + ) - self.initConfig() + logger.debug("initializing config") + self.init_config() self.flags.update({"quit": False, "run": False, "load": False}) self.allowStart = False self.stopped = False - return (cfg["control_port"], cfg["output_port"]) - - def loadConfig(self, file): - """Load configuration file. 
- file: a YAML configuration file name - """ - self.config = Config(configFile=file) + logger.info( + f"control: {self.config.settings['control_port']}," + f" output: {self.config.settings['output_port']}," + f" logging: {self.logger_pub_port}" + ) + return ( + self.config.settings["control_port"], + self.config.settings["output_port"], + self.logger_pub_port, + ) - def initConfig(self): + def init_config(self): """For each connection: create a Link with a name (purpose), start, and end Start links to one actor's name, end to the other. @@ -168,121 +209,74 @@ def initConfig(self): """ # TODO load from file or user input, as in dialogue through FrontEnd? - flag = self.config.createConfig() + logger.info("initializing config") + flag = self.config.create_config() if flag == -1: logger.error( "An error occurred when loading the configuration file. " "Please see the log file for more details." ) + self.destroy_nexus() + raise ConfigFileNotValidException # create all data links requested from Config config - self.createConnections() - - if self.config.hasGUI: - # Have to load GUI first (at least with Caiman) - name = self.config.gui.name - m = self.config.gui # m is ConfigModule - # treat GUI uniquely since user communication comes from here - try: - visualClass = m.options["visual"] - # need to instantiate this actor - visualActor = self.config.actors[visualClass] - self.createActor(visualClass, visualActor) - # then add links for visual - for k, l in { - key: self.data_queues[key] - for key in self.data_queues.keys() - if visualClass in key - }.items(): - self.assignLink(k, l) - - # then give it to our GUI - self.createActor(name, m) - self.actors[name].setup(visual=self.actors[visualClass]) - - self.p_GUI = Process(target=self.actors[name].run, name=name) - self.p_GUI.daemon = True - self.p_GUI.start() - - except Exception as e: - logger.error(f"Exception in setting up GUI {name}: {e}") - - else: - # have fake GUI for communications - q_comm = Link("GUI_comm", "GUI", self.name) - self.comm_queues.update({q_comm.name: q_comm}) + self.create_connections() + + # if self.config.hasGUI: + # # Have to load GUI first (at least with Caiman) + # name = self.config.gui.name + # m = self.config.gui # m is ConfigModule + # # treat GUI uniquely since user communication comes from here + # try: + # visualClass = m.options["visual"] + # # need to instantiate this actor + # visualActor = self.config.actors[visualClass] + # self.create_actor(visualClass, visualActor) + # # then add links for visual + # for k, l in { + # key: self.data_queues[key] + # for key in self.data_queues.keys() + # if visualClass in key + # }.items(): + # self.assign_link(k, l) + # + # # then give it to our GUI + # self.create_actor(name, m) + # self.actors[name].setup(visual=self.actors[visualClass]) + # + # self.p_GUI = Process(target=self.actors[name].run, name=name) + # self.p_GUI.daemon = True + # self.p_GUI.start() + # + # except Exception as e: + # logger.error(f"Exception in setting up GUI {name}: {e}") # First set up each class/actor for name, actor in self.config.actors.items(): if name not in self.actors.keys(): # Check for actors being instantiated twice try: - self.createActor(name, actor) + self.create_actor(name, actor) logger.info(f"Setting up actor {name}") except Exception as e: logger.error(f"Exception in setting up actor {name}: {e}.") self.quit() + raise e # Second set up each connection b/t actors # TODO: error handling for if a user tries to use q_in without defining it - for name, link in 
self.data_queues.items(): - self.assignLink(name, link) - - if self.config.settings["use_watcher"]: - watchin = [] - for name in self.config.settings["use_watcher"]: - watch_link = Link(name + "_watch", name, "Watcher") - self.assignLink(name + ".watchout", watch_link) - watchin.append(watch_link) - self.createWatcher(watchin) + # for name, link in self.data_queues.items(): + # self.assign_link(name, link) def configure_redis_persistence(self): # invalid configs: specifying filename and using an ephemeral filename, # specifying that saving is off but providing either filename option - aof_dirname = self.config.get_redis_aof_dirname() - generate_unique_dirname = self.config.generate_ephemeral_aof_dirname() - redis_saving_enabled = self.config.redis_saving_enabled() - redis_fsync_frequency = self.config.get_redis_fsync_frequency() - - if aof_dirname and generate_unique_dirname: - logger.error( - "Cannot both generate a unique dirname and use the one provided." - ) - raise Exception("Cannot use unique dirname and use the one provided.") - - if aof_dirname or generate_unique_dirname or redis_fsync_frequency: - if redis_saving_enabled is None: - redis_saving_enabled = True - elif not redis_saving_enabled: - logger.error( - "Invalid configuration. Cannot save to disk with saving disabled." - ) - raise Exception("Cannot persist to disk with saving disabled.") - - self.redis_saving_enabled = redis_saving_enabled - - if redis_fsync_frequency and redis_fsync_frequency not in [ - "every_write", - "every_second", - "no_schedule", - ]: - logger.error("Cannot use unknown fsync frequency ", redis_fsync_frequency) - raise Exception( - "Cannot use unknown fsync frequency ", redis_fsync_frequency - ) - - if redis_fsync_frequency is None: - redis_fsync_frequency = "no_schedule" - - if redis_fsync_frequency == "every_write": - self.redis_fsync_frequency = "always" - elif redis_fsync_frequency == "every_second": - self.redis_fsync_frequency = "everysec" - elif redis_fsync_frequency == "no_schedule": - self.redis_fsync_frequency = "no" - else: - logger.error("Unknown fsync frequency ", redis_fsync_frequency) - raise Exception("Unknown fsync frequency ", redis_fsync_frequency) + aof_dirname = self.config.redis_config["aof_dirname"] + generate_unique_dirname = self.config.redis_config[ + "generate_ephemeral_aof_dirname" + ] + self.redis_saving_enabled = self.config.redis_config["enable_saving"] + self.redis_fsync_frequency = self.config.redis_config["fsync_frequency"] if aof_dirname: self.aof_dir = aof_dirname @@ -307,7 +301,7 @@ def configure_redis_persistence(self): return - def startNexus(self): + def start_nexus(self, serve_function, *args, **kwargs): """ Puts all actors in separate processes and begins polling to listen to comm queues @@ -321,21 +315,21 @@ def startNexus(self): p = ctx.Process(target=m.run, name=name) else: ctx = get_context("fork") - p = ctx.Process(target=self.runActor, name=name, args=(m,)) - if "Watcher" not in name: - if "daemon" in self.config.actors[name].options: - p.daemon = self.config.actors[name].options["daemon"] - logger.info("Setting daemon for {}".format(name)) - else: - p.daemon = True # default behavior + p = ctx.Process(target=self.run_actor, name=name, args=(m,)) + if "daemon" in self.config.actors[name].options: + p.daemon = self.config.actors[name].options["daemon"] + logger.info("Setting daemon for {}".format(name)) + else: + p.daemon = True # default behavior self.processes.append(p) self.start() loop = asyncio.get_event_loop() + res = "" try: 
self.out_socket.send_string("Awaiting input:") - res = loop.run_until_complete(self.pollQueues()) + res = loop.run_until_complete(self.serve(serve_function, *args, **kwargs)) except asyncio.CancelledError: logger.info("Loop is cancelled") @@ -346,10 +340,9 @@ def startNexus(self): logger.info(f"Current loop: {asyncio.get_event_loop()}") - loop.stop() - loop.close() + # loop.stop() + # loop.close() logger.info("Shutdown loop") - self.zmq_context.destroy() def start(self): """ @@ -363,30 +356,41 @@ def start(self): logger.info("All processes started") - def destroyNexus(self): + def destroy_nexus(self): """Method that calls the internal method - to kill the process running the store (plasma server) + to kill the processes running the store + and the message broker """ logger.warning("Destroying Nexus") - self._closeStoreInterface() + self._shutdown_harvester() + self._close_store_interface() - if hasattr(self, "store_loc"): - try: - os.remove(self.store_loc) - except FileNotFoundError: - logger.warning( - "StoreInterface file {} is already deleted".format(self.store_loc) - ) - logger.warning("Delete the store at location {0}".format(self.store_loc)) - - if hasattr(self, "out_socket"): + if self.out_socket: self.out_socket.close(linger=0) - if hasattr(self, "in_socket"): + if self.in_socket: self.in_socket.close(linger=0) - if hasattr(self, "zmq_context"): + if self.actor_in_socket: + self.actor_in_socket.close(linger=0) + if self.broker_in_socket: + self.broker_in_socket.close(linger=0) + if self.logger_in_socket: + self.logger_in_socket.close(linger=0) + + self._shutdown_broker() + self._shutdown_logger() + + for handler in logger.handlers: + # need to close this one since we're about to destroy the zmq context + if isinstance(handler, log.ZmqLogHandler): + handler.close() + logger.removeHandler(handler) + + if self.zmq_context: self.zmq_context.destroy(linger=0) + if self.zmq_sync_context: + self.zmq_sync_context.destroy(linger=0) - async def pollQueues(self): + async def poll_queues(self, poll_function, *args, **kwargs): """ Listens to links and processes their signals. @@ -400,6 +404,7 @@ async def pollQueues(self): string: "Shutting down", Notifies start() that pollQueues has completed. """ self.actorStates = dict.fromkeys(self.actors.keys()) + self.actor_states = dict.fromkeys(self.actors.keys(), None) if not self.config.hasGUI: # Since Visual is not started, it cannot send a ready signal. 
try: @@ -407,13 +412,11 @@ async def pollQueues(self): except Exception as e: logger.info("Visual is not started: {0}".format(e)) pass - polling = list(self.comm_queues.values()) - pollingNames = list(self.comm_queues.keys()) - self.tasks = [] - for q in polling: - self.tasks.append(asyncio.create_task(q.get_async())) + self.tasks = [] + self.tasks.append(asyncio.create_task(self.process_actor_message())) self.tasks.append(asyncio.create_task(self.remote_input())) + self.early_exit = False # add signal handlers @@ -421,40 +424,17 @@ async def pollQueues(self): signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) for s in signals: loop.add_signal_handler( - s, lambda s=s: self.stop_polling_and_quit(s, polling) + s, lambda s=s: asyncio.create_task(self.stop_polling_and_quit(s)) ) + logger.info("Nexus signal handler added") + while not self.flags["quit"]: - try: - done, pending = await asyncio.wait( - self.tasks, return_when=concurrent.futures.FIRST_COMPLETED - ) - except asyncio.CancelledError: - pass + await poll_function(*args, **kwargs) - # sort through tasks to see where we got input from - # (so we can choose a handler) - for i, t in enumerate(self.tasks): - if i < len(polling): - if t in done or polling[i].status == "done": - # catch tasks that complete await wait/gather - r = polling[i].result - if r: - if "GUI" in pollingNames[i]: - self.processGuiSignal(r, pollingNames[i]) - else: - self.processActorSignal(r, pollingNames[i]) - self.tasks[i] = asyncio.create_task(polling[i].get_async()) - elif t in done: - logger.debug("t.result = " + str(t.result())) - self.tasks[i] = asyncio.create_task(self.remote_input()) - - if not self.early_exit: # don't run this again if we already have - self.stop_polling(Signal.quit(), polling) - logger.warning("Shutting down polling") return "Shutting Down" - def stop_polling_and_quit(self, signal, queues): + async def stop_polling_and_quit(self, signal): """ quit the process and stop polling signals from queues @@ -463,26 +443,108 @@ def stop_polling_and_quit(self, signal, queues): One of: signal.SIGHUP, signal.SIGTERM, signal.SIGINT queues (improv.link.AsyncQueue): Comm queues for links. """ - logger.warn( + logger.warning( "Shutting down via signal handler due to {}. 
\ Steps may be out of order or dirty.".format( signal ) ) - self.stop_polling(signal, queues) + await self.stop_polling(signal) + logger.info("Nexus waiting for async tasks to have a chance to send") + await asyncio.sleep(0) self.flags["quit"] = True self.early_exit = True self.quit() + def process_actor_state_update(self, msg: ActorStateMsg): + actor_state = None + if msg.actor_name in self.actor_states.keys(): + actor_state = self.actor_states[msg.actor_name] + if not actor_state: + logger.info( + f"Received state message from new actor {msg.actor_name}" + f" with info: {msg.info}\n" + ) + self.actor_states[msg.actor_name] = ActorState( + msg.actor_name, msg.status, msg.nexus_in_port + ) + actor_state = self.actor_states[msg.actor_name] + else: + logger.info( + f"Received state message from actor {msg.actor_name}" + f" with info: {msg.info}\n" + f"Current state:\n" + f"Name: {actor_state.actor_name}\n" + f"Status: {actor_state.status}\n" + f"Nexus control port: {actor_state.nexus_in_port}\n" + ) + if msg.nexus_in_port != actor_state.nexus_in_port: + pass + # TODO: this actor's signal socket changed + + actor_state.actor_name = msg.actor_name + actor_state.status = msg.status + actor_state.nexus_in_port = msg.nexus_in_port + + logger.info( + "Updated actor state:\n" + f"Name: {actor_state.actor_name}\n" + f"Status: {actor_state.status}\n" + f"Nexus control port: {actor_state.nexus_in_port}\n" + ) + + if all( + [ + actor_state is not None and actor_state.status == Signal.ready() + for actor_state in self.actor_states.values() + ] + ): + self.allowStart = True + + return True + + async def process_actor_message(self): + msg = await self.actor_in_socket.recv_pyobj() + if isinstance(msg, ActorStateMsg): + if self.process_actor_state_update(msg): + await self.actor_in_socket.send_pyobj( + ActorStateReplyMsg( + msg.actor_name, "OK", "actor state updated successfully" + ) + ) + else: + await self.actor_in_socket.send_pyobj( + ActorStateReplyMsg( + msg.actor_name, "ERROR", "actor state update failed" + ) + ) + if (not self.allow_setup) and all( + [ + actor_state is not None and actor_state.status == Signal.waiting() + for actor_state in self.actor_states.values() + ] + ): + logger.info("All actors connected to Nexus. Allowing setup.") + self.allow_setup = True + + if (not self.allowStart) and all( + [ + actor_state is not None and actor_state.status == Signal.ready() + for actor_state in self.actor_states.values() + ] + ): + logger.info("All actors ready. 
Allowing run.")
+            self.allowStart = True
+
     async def remote_input(self):
         msg = await self.in_socket.recv_multipart()
         command = msg[0].decode("utf-8")
         await self.in_socket.send_string("Awaiting input:")
         if command == Signal.quit():
             await self.out_socket.send_string("QUIT")
-            self.processGuiSignal([command], "TUI_Nexus")
+            await self.process_gui_signal([command], "TUI_Nexus")
 
-    def processGuiSignal(self, flag, name):
+    async def process_gui_signal(self, flag, name):
         """Receive flags from the Front End as user input"""
         name = name.split("_")[0]
         if flag:
@@ -490,20 +552,24 @@ def processGuiSignal(self, flag, name):
             if flag[0] == Signal.run():
                 logger.info("Begin run!")
                 # self.flags['run'] = True
-                self.run()
+                await self.run()
             elif flag[0] == Signal.setup():
                 logger.info("Running setup")
-                self.setup()
+                await self.setup()
             elif flag[0] == Signal.ready():
                 logger.info("GUI ready")
                 self.actorStates[name] = flag[0]
             elif flag[0] == Signal.quit():
                 logger.warning("Quitting the program!")
+                # stop_polling_and_quit notifies every actor and cancels the
+                # outstanding polling tasks; await it directly rather than
+                # re-waiting an already-finished task with asyncio.wait,
+                # which expects an iterable of tasks, not a single task
+                await self.stop_polling_and_quit(Signal.quit())
                 self.flags["quit"] = True
-                self.quit()
             elif flag[0] == Signal.load():
                 logger.info("Loading Config config from file " + flag[1])
-                self.loadConfig(flag[1])
+                self.config = Config(flag[1])
+                self.config.parse_config()
             elif flag[0] == Signal.pause():
                 logger.info("Pausing processes")
                 # TODO. Also resume, reset
@@ -526,20 +592,19 @@ def processGuiSignal(self, flag, name):
                     p = ctx.Process(target=m.run, name=name)
                 else:
                     ctx = get_context("fork")
-                    p = ctx.Process(target=self.runActor, name=name, args=(m,))
-                    if "Watcher" not in name:
-                        if "daemon" in actor.options:
-                            p.daemon = actor.options["daemon"]
-                            logger.info("Setting daemon for {}".format(name))
-                        else:
-                            p.daemon = True
+                    p = ctx.Process(target=self.run_actor, name=name, args=(m,))
+                    if "daemon" in actor.options:
+                        p.daemon = actor.options["daemon"]
+                        logger.info("Setting daemon for {}".format(name))
+                    else:
+                        p.daemon = True
 
                 # Setting the stores for each actor to be the same
                 # TODO: test if this works for fork -- don't think it does?
                 al = [act for act in self.actors.values() if act.name != pro.name]
-                m.setStoreInterface(al[0].client)
+                m.set_store_interface(al[0].client)
                 m.client = None
-                m._getStoreInterface()
+                m._get_store_interface()
 
                 self.processes.append(p)
                 p.start()
@@ -550,45 +615,33 @@ def processGuiSignal(self, flag, name):
                 self.processes = [p for p in list(self.processes) if p.exitcode is None]
             elif flag[0] == Signal.stop():
                 logger.info("Nexus received stop signal")
-                self.stop()
+                await self.stop()
             elif flag:
                 logger.error("Unknown signal received from Nexus: {}".format(flag))
 
-    def processActorSignal(self, sig, name):
-        if sig is not None:
-            logger.info("Received signal " + str(sig[0]) + " from " + name)
-            state_val = self.actorStates.values()
-            if not self.stopped and sig[0] == Signal.ready():
-                self.actorStates[name.split("_")[0]] = sig[0]
-                if all(val == Signal.ready() for val in state_val):
-                    self.allowStart = True
-                    # TODO: replace with q_sig to FE/Visual
-                    logger.info("Allowing start")
-
-            elif self.stopped and sig[0] == Signal.stop_success():
-                self.actorStates[name.split("_")[0]] = sig[0]
-                if all(val == Signal.stop_success() for val in state_val):
-                    self.allowStart = True  # TODO: replace with q_sig to FE/Visual
-                    self.stoppped = False
-                    logger.info("All stops were successful. 
Allowing start.") - - def setup(self): - for q in self.sig_queues.values(): - try: - logger.info("Starting setup: " + str(q)) - q.put_nowait(Signal.setup()) - except Full: - logger.warning("Signal queue" + q.name + "is full") + async def setup(self): + if not self.allow_setup: + logger.error( + "Not all actors connected to Nexus. Please wait, then try again." + ) + return + + for actor in self.actor_states.values(): + logger.info("Starting setup: " + str(actor.actor_name)) + actor.sig_socket = self.zmq_context.socket(REQ) + actor.sig_socket.connect(f"tcp://{actor.hostname}:{actor.nexus_in_port}") + await actor.sig_socket.send_pyobj( + ActorSignalMsg(actor.actor_name, Signal.setup(), "") + ) + await actor.sig_socket.recv_pyobj() - def run(self): + async def run(self): if self.allowStart: - for q in self.sig_queues.values(): - try: - q.put_nowait(Signal.run()) - except Full: - logger.warning("Signal queue" + q.name + "is full") - # queue full, keep going anyway - # TODO: add repeat trying as async task + for actor in self.actor_states.values(): + await actor.sig_socket.send_pyobj( + ActorSignalMsg(actor.actor_name, Signal.run(), "") + ) + await actor.sig_socket.recv_pyobj() else: logger.error("Not all actors ready yet, please wait and then try again.") @@ -596,43 +649,54 @@ def quit(self): logger.warning("Killing child processes") self.out_socket.send_string("QUIT") - for q in self.sig_queues.values(): - try: - q.put_nowait(Signal.quit()) - except Full: - logger.warning("Signal queue {} full, cannot quit".format(q.name)) - except FileNotFoundError: - logger.warning("Queue {} corrupted.".format(q.name)) - if self.config.hasGUI: self.processes.append(self.p_GUI) - if self.p_watch: - self.processes.append(self.p_watch) - for p in self.processes: p.terminate() - p.join() + p.join(timeout=5) + if p.exitcode is None: + p.kill() + logger.error("Process did not exit in time. Kill signal sent.") logger.warning("Actors terminated") - self.destroyNexus() + self.destroy_nexus() - def stop(self): + async def stop(self): logger.warning("Starting stop procedure") self.allowStart = False - for q in self.sig_queues.values(): + for actor in self.actor_states.values(): try: - q.put_nowait(Signal.stop()) - except Full: - logger.warning("Signal queue" + q.name + "is full") + await actor.sig_socket.send_pyobj( + ActorSignalMsg( + actor.actor_name, Signal.stop(), "Nexus sending stop signal" + ) + ) + msg_ready = await actor.sig_socket.poll(timeout=1000) + if msg_ready == 0: + raise TimeoutError + await actor.sig_socket.recv_pyobj() + except TimeoutError: + logger.info( + f"Timed out waiting for reply to stop message " + f"from actor {actor.actor_name}. " + f"Closing connection." + ) + actor.sig_socket.close(linger=0) + except Exception as e: + logger.info( + f"Unable to send stop message " + f"to actor {actor.actor_name}: " + f"{e}" + ) self.allowStart = True def revive(self): logger.warning("Starting revive") - def stop_polling(self, stop_signal, queues): + async def stop_polling(self, stop_signal): """Cancels outstanding tasks and fills their last request. 
Puts a string into all active queues, then cancels their @@ -647,11 +711,30 @@ def stop_polling(self, stop_signal, queues): logger.info(f"Stop signal: {stop_signal}") shutdown_message = Signal.quit() - for q in queues: + for actor in self.actor_states.values(): try: - q.put(shutdown_message) - except Exception: - logger.info("Unable to send shutdown message to {}.".format(q.name)) + await actor.sig_socket.send_pyobj( + ActorSignalMsg( + actor.actor_name, shutdown_message, "Nexus sending quit signal" + ) + ) + msg_ready = await actor.sig_socket.poll(timeout=1000) + if msg_ready == 0: + raise TimeoutError + await actor.sig_socket.recv_pyobj() + except TimeoutError: + logger.info( + f"Timed out waiting for reply to quit message " + f"from actor {actor.actor_name}. " + f"Closing connection." + ) + actor.sig_socket.close(linger=0) + except Exception as e: + logger.info( + f"Unable to send shutdown message " + f"to actor {actor.actor_name}: " + f"{e}" + ) logger.info("Canceling outstanding tasks") @@ -659,19 +742,14 @@ def stop_polling(self, stop_signal, queues): logger.info("Polling has stopped.") - def createStoreInterface(self, name): + def create_store_interface(self): """Creates StoreInterface""" - if self.config.use_plasma(): - return PlasmaStoreInterface(name, self.store_loc) - else: - return RedisStoreInterface(server_port_num=self.store_port) + return RedisStoreInterface(server_port_num=self.store_port) - def _startStoreInterface(self, size, attempts=20): - """Start a subprocess that runs the plasma store - Raises a RuntimeError exception size is undefined - Raises an Exception if the plasma store doesn't start - - #TODO: Generalize this to non-plasma stores + def _start_store_interface(self, size, attempts=20): + """Start a subprocess that runs the redis store + Raises a RuntimeError exception if size is undefined + Raises an Exception if the redis store doesn't start Args: size: in bytes @@ -683,64 +761,27 @@ def _startStoreInterface(self, size, attempts=20): """ if size is None: raise RuntimeError("Server size needs to be specified") - self.use_plasma = False - if self.config and self.config.use_plasma(): - self.use_plasma = True - self.store_loc = str(os.path.join("/tmp/", str(uuid.uuid4()))) - self.p_StoreInterface = subprocess.Popen( - [ - "plasma_store", - "-s", - self.store_loc, - "-m", - str(size), - "-e", - "hashtable://test", - ], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - logger.info("StoreInterface start successful: {}".format(self.store_loc)) - else: - logger.info("Setting up Redis store.") - self.store_port = ( - self.config.get_redis_port() - if self.config and self.config.redis_port_specified() - else Config.get_default_redis_port() + + logger.info("Setting up Redis store.") + self.store_port = self.config.redis_config["port"] if self.config else 6379 + logger.info("Searching for open port starting at specified port.") + for attempt in range(attempts): + logger.info( + "Attempting to connect to Redis on port {}".format(self.store_port) ) - if self.config and self.config.redis_port_specified(): - logger.info( - "Attempting to connect to Redis on port {}".format(self.store_port) - ) - # try with failure, incrementing port number - self.p_StoreInterface = self.start_redis(size) - time.sleep(3) - if self.p_StoreInterface.poll(): - logger.error("Could not start Redis on specified port number.") - raise Exception("Could not start Redis on specified port.") + # try with failure, incrementing port number + self.p_StoreInterface = self.start_redis(size) + 
time.sleep(1) + if self.p_StoreInterface.poll(): # Redis could not start + logger.info("Could not connect to port {}".format(self.store_port)) + self.store_port = str(int(self.store_port) + 1) else: - logger.info("Redis port not specified. Searching for open port.") - for attempt in range(attempts): - logger.info( - "Attempting to connect to Redis on port {}".format( - self.store_port - ) - ) - # try with failure, incrementing port number - self.p_StoreInterface = self.start_redis(size) - time.sleep(3) - if self.p_StoreInterface.poll(): # Redis could not start - logger.info( - "Could not connect to port {}".format(self.store_port) - ) - self.store_port = str(int(self.store_port) + 1) - else: - break - else: - logger.error("Could not start Redis on any tried port.") - raise Exception("Could not start Redis on any tried ports.") + break + else: + logger.error("Could not start Redis on any tried port.") + raise Exception("Could not start Redis on any tried ports.") - logger.info(f"StoreInterface start successful on port {self.store_port}") + logger.info(f"StoreInterface start successful on port {self.store_port}") def start_redis(self, size): subprocess_command = [ @@ -788,26 +829,27 @@ def start_redis(self, size): stderr=subprocess.DEVNULL, ) - def _closeStoreInterface(self): + def _close_store_interface(self): """Internal method to kill the subprocess - running the store (plasma sever) + running the store """ if hasattr(self, "p_StoreInterface"): try: self.p_StoreInterface.send_signal(signal.SIGINT) - self.p_StoreInterface.wait() - logger.info( - "StoreInterface close successful: {}".format( - self.store_loc - if self.config and self.config.use_plasma() - else self.store_port + try: + self.p_StoreInterface.wait(timeout=30) + logger.info( + "StoreInterface close successful: {}".format(self.store_port) ) - ) + except subprocess.TimeoutExpired as e: + logger.error(e) + self.p_StoreInterface.send_signal(signal.SIGKILL) + logger.info("Killed datastore process") except Exception as e: logger.exception("Cannot close store {}".format(e)) - def createActor(self, name, actor): + def create_actor(self, name, actor): """Function to instantiate actor, add signal and comm Links, and update self.actors dictionary @@ -818,35 +860,46 @@ def createActor(self, name, actor): # Instantiate selected class mod = import_module(actor.packagename) clss = getattr(mod, actor.classname) - if self.config.use_plasma(): - instance = clss(actor.name, store_loc=self.store_loc, **actor.options) - else: - instance = clss(actor.name, store_port_num=self.store_port, **actor.options) + outgoing_links = ( + self.outgoing_topics[actor.name] + if actor.name in self.outgoing_topics + else [] + ) + incoming_links = ( + self.incoming_topics[actor.name] + if actor.name in self.incoming_topics + else [] + ) + instance = clss( + name=actor.name, + nexus_comm_port=self.actor_in_socket_port, + broker_sub_port=self.broker_sub_port, + broker_pub_port=self.broker_pub_port, + log_pull_port=self.logger_pull_port, + outgoing_links=outgoing_links, + incoming_links=incoming_links, + store_port_num=self.store_port, + **actor.options, + ) if "method" in actor.options.keys(): # check for spawn if "fork" == actor.options["method"]: # Add link to StoreInterface store - store = self.createStoreInterface(actor.name) - instance.setStoreInterface(store) + store = self.create_store_interface() + instance.set_store_interface(store) else: # spawn or forkserver; can't pickle plasma store logger.info("No store for this actor yet {}".format(name)) else: # Add 
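The port scan above launches `redis-server`, sleeps briefly, and checks `Popen.poll()`: a non-None return code means the server already exited (typically because the port was taken), so the loop bumps the port and retries. A hedged sketch of that scan, assuming a `redis-server` binary on PATH and simplified flags relative to improv's own `start_redis`:

```python
import subprocess
import time

def start_redis_on_free_port(start_port=6379, attempts=20):
    # Try successive ports until redis-server stays up; mirrors the loop above.
    port = start_port
    for _ in range(attempts):
        proc = subprocess.Popen(
            ["redis-server", "--port", str(port), "--save", ""],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        time.sleep(1)  # give the server a moment to bind or die
        if proc.poll() is None:  # still running: the port was free
            return proc, port
        port += 1  # server exited: assume the port was busy and move on
    raise RuntimeError("Could not start Redis on any tried port.")

if __name__ == "__main__":
    proc, port = start_redis_on_free_port()
    print(f"redis-server up on port {port}")
    proc.terminate()
```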
link to StoreInterface store - store = self.createStoreInterface(actor.name) - instance.setStoreInterface(store) - - q_comm = Link(actor.name + "_comm", actor.name, self.name) - q_sig = Link(actor.name + "_sig", self.name, actor.name) - self.comm_queues.update({q_comm.name: q_comm}) - self.sig_queues.update({q_sig.name: q_sig}) - instance.setCommLinks(q_comm, q_sig) + store = self.create_store_interface() + instance.set_store_interface(store) # Update information self.actors.update({name: instance}) - def runActor(self, actor): + def run_actor(self, actor: Actor): """Run the actor continually; used for separate processes #TODO: hook into monitoring here? @@ -855,26 +908,45 @@ def runActor(self, actor): """ actor.run() - def createConnections(self): - """Assemble links (multi or other) - for later assignment - """ - for source, drain in self.config.connections.items(): - name = source.split(".")[0] - # current assumption is connection goes from q_out to something(s) else - if len(drain) > 1: # we need multiasyncqueue - link, endLinks = MultiLink(name + "_multi", source, drain) - self.data_queues.update({source: link}) - for i, e in enumerate(endLinks): - self.data_queues.update({drain[i]: e}) - else: # single input, single output - d = drain[0] - d_name = d.split(".") # TODO: check if .anything, if not assume q_in - link = Link(name + "_" + d_name[0], source, d) - self.data_queues.update({source: link}) - self.data_queues.update({d: link}) - - def assignLink(self, name, link): + def create_connections(self): + for name, connection in self.config.connections.items(): + sources = connection["sources"] + if not isinstance(sources, list): + sources = [sources] + sinks = connection["sinks"] + if not isinstance(sinks, list): + sinks = [sinks] + + name_input = tuple( + ["inputs:"] + + [name for name in sources] + + ["outputs:"] + + [name for name in sinks] + ) + name = str( + hash(name_input) + ) # space-efficient key for uniquely referring to this connection + logger.info(f"Created link name {name} for link {name_input}") + + for source in sources: + source_actor = source.split(".")[0] + source_link = source.split(".")[1] + if source_actor not in self.outgoing_topics.keys(): + self.outgoing_topics[source_actor] = [LinkInfo(source_link, name)] + else: + self.outgoing_topics[source_actor].append( + LinkInfo(source_link, name) + ) + + for sink in sinks: + sink_actor = sink.split(".")[0] + sink_link = sink.split(".")[1] + if sink_actor not in self.incoming_topics.keys(): + self.incoming_topics[sink_actor] = [LinkInfo(sink_link, name)] + else: + self.incoming_topics[sink_actor].append(LinkInfo(sink_link, name)) + + def assign_link(self, name, link): """Function to set up Links between actors for data location passing Actor must already be instantiated @@ -886,24 +958,283 @@ def assignLink(self, name, link): classname = name.split(".")[0] linktype = name.split(".")[1] if linktype == "q_out": - self.actors[classname].setLinkOut(link) + self.actors[classname].set_link_out(link) elif linktype == "q_in": - self.actors[classname].setLinkIn(link) + self.actors[classname].set_link_in(link) elif linktype == "watchout": - self.actors[classname].setLinkWatch(link) + self.actors[classname].set_link_watch(link) else: - self.actors[classname].addLink(linktype, link) + self.actors[classname].add_link(linktype, link) + + def start_logger(self, log_server_pub_port): + spawn_context = get_context("spawn") + self.p_logger = spawn_context.Process( + target=bootstrap_log_server, + args=( + "localhost", + 
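The rewritten `create_connections` collapses each named connection into a single broker topic shared by all of its sources and sinks, keyed by a hash of the participant list. A simplified sketch of that bookkeeping (`LinkInfo` here is a namedtuple stand-in for improv's record):

```python
from collections import namedtuple

# Hypothetical stand-in for improv's LinkInfo record.
LinkInfo = namedtuple("LinkInfo", ["link_name", "topic"])

def build_topics(connections):
    outgoing, incoming = {}, {}
    for conn in connections.values():
        sources = conn["sources"] if isinstance(conn["sources"], list) else [conn["sources"]]
        sinks = conn["sinks"] if isinstance(conn["sinks"], list) else [conn["sinks"]]
        # One topic per connection, derived from its endpoints.
        topic = str(hash(tuple(["inputs:"] + sources + ["outputs:"] + sinks)))
        for src in sources:
            actor, link = src.split(".")
            outgoing.setdefault(actor, []).append(LinkInfo(link, topic))
        for snk in sinks:
            actor, link = snk.split(".")
            incoming.setdefault(actor, []).append(LinkInfo(link, topic))
    return outgoing, incoming

conns = {"Generator-Processor": {"sources": ["Generator.q_out"],
                                 "sinks": ["Processor.q_in"]}}
out, inc = build_topics(conns)
print(out["Generator"], inc["Processor"])
```

Note that `hash()` is salted per interpreter run (PYTHONHASHSEED), so these keys are only stable within a single Nexus process; that suffices here because the topic names are generated once and handed to every actor at construction time.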
self.logger_in_port, + self.logfile, + log_server_pub_port, + ), + ) + logger.debug("logger created") + self.p_logger.start() + time.sleep(1) + logger.debug("logger started") + if not self.p_logger.is_alive(): + logger.error( + "Logger process failed to start. " + "Please see the log server log file for more information. " + "The improv server will now exit." + ) + self.quit() + raise Exception("Could not start log server.") + logger.debug("logger is alive") + poll_res = self.logger_in_socket.poll(timeout=5000) + if poll_res == 0: + logger.error( + "Never got reply from logger. Cannot proceed setting up Nexus." + ) + try: + with open("log_server.log", "r") as file: + logger.debug(file.read()) + except Exception as e: + logger.error(e) + self.destroy_nexus() + logger.error("exiting after destroy") + exit(1) + logger_info: LogInfoMsg = self.logger_in_socket.recv_pyobj() + self.logger_pull_port = logger_info.pull_port + self.logger_pub_port = logger_info.pub_port + self.logger_in_socket.send_pyobj( + LogInfoReplyMsg(logger_info.name, "OK", "registered logger information") + ) + logger.debug("logger replied with setup message") - # TODO: StoreInterface access here seems wrong, need to test - def startWatcher(self): - from improv.watcher import Watcher + def start_message_broker(self): + spawn_context = get_context("spawn") + self.p_broker = spawn_context.Process( + target=bootstrap_broker, args=("localhost", self.broker_in_port) + ) + logger.debug("broker created") + self.p_broker.start() + time.sleep(1) + logger.debug("broker started") + if not self.p_broker.is_alive(): + logger.error( + "Broker process failed to start. " + "Please see the log file for more information. " + "The improv server will now exit." + ) + self.quit() + raise Exception("Could not start message broker server.") + logger.debug("broker is alive") + poll_res = self.broker_in_socket.poll(timeout=5000) + if poll_res == 0: + logger.error("Never got reply from broker. 
Cannot proceed.") + try: + with open("broker_server.log", "r") as file: + logger.debug(file.read()) + except Exception as e: + logger.error(e) + self.destroy_nexus() + logger.debug("exiting after destroy") + exit(1) + broker_info: BrokerInfoMsg = self.broker_in_socket.recv_pyobj() + self.broker_sub_port = broker_info.sub_port + self.broker_pub_port = broker_info.pub_port + self.broker_in_socket.send_pyobj( + BrokerInfoReplyMsg(broker_info.name, "OK", "registered broker information") + ) + logger.debug("broker replied with setup message") - self.watcher = Watcher("watcher", self.createStoreInterface("watcher")) - q_sig = Link("watcher_sig", self.name, "watcher") - self.watcher.setLinks(q_sig) - self.sig_queues.update({q_sig.name: q_sig}) + def _shutdown_broker(self): + """Internal method to kill the subprocess + running the message broker + """ + if self.p_broker: + try: + self.p_broker.terminate() + self.p_broker.join(timeout=5) + if self.p_broker.exitcode is None: + self.p_broker.kill() + logger.error("Killed broker process") + else: + logger.info( + "Broker shutdown successful with exit code {}".format( + self.p_broker.exitcode + ) + ) + except Exception as e: + logger.exception(f"Unable to close broker {e}") + + def _shutdown_logger(self): + """Internal method to kill the subprocess + running the logger + """ + if self.p_logger: + try: + self.p_logger.terminate() + self.p_logger.join(timeout=5) + if self.p_logger.exitcode is None: + self.p_logger.kill() + logger.error("Killed logger process") + else: + logger.info("Logger shutdown successful") + except Exception as e: + logger.exception(f"Unable to close logger: {e}") - self.p_watch = Process(target=self.watcher.run, name="watcher_process") - self.p_watch.daemon = True - self.p_watch.start() - self.processes.append(self.p_watch) + def _shutdown_harvester(self): + """Internal method to kill the subprocess + running the logger + """ + if self.p_harvester: + try: + self.p_harvester.terminate() + self.p_harvester.join(timeout=5) + if self.p_harvester.exitcode is None: + self.p_harvester.kill() + logger.error("Killed harvester process") + else: + logger.info("Harvester shutdown successful") + except Exception as e: + logger.exception(f"Unable to close harvester: {e}") + + def set_up_sockets(self, actor_in_port): + + logger.debug("Connecting to output") + cfg = self.config.settings # this could be self.settings instead + self.zmq_context = zmq.Context() + self.zmq_context.setsockopt(SocketOption.LINGER, 0) + self.out_socket = self.zmq_context.socket(PUB) + self.out_socket.bind("tcp://*:%s" % cfg["output_port"]) + out_port_string = self.out_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) + cfg["output_port"] = int(out_port_string.split(":")[-1]) + + logger.debug("Connecting to control") + self.in_socket = self.zmq_context.socket(REP) + self.in_socket.bind("tcp://*:%s" % cfg["control_port"]) + in_port_string = self.in_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) + cfg["control_port"] = int(in_port_string.split(":")[-1]) + + logger.debug("Connecting to actor comm socket") + self.actor_in_socket = self.zmq_context.socket(REP) + self.actor_in_socket.bind(f"tcp://*:{actor_in_port}") + in_port_string = self.actor_in_socket.getsockopt_string( + SocketOption.LAST_ENDPOINT + ) + self.actor_in_socket_port = int(in_port_string.split(":")[-1]) + + logger.debug("Setting up sync server startup socket") + self.zmq_sync_context = zmq_sync.Context() + self.zmq_sync_context.setsockopt(SocketOption.LINGER, 0) + + self.logger_in_socket = 
self.zmq_sync_context.socket(REP) + self.logger_in_socket.bind("tcp://*:0") + logger_in_port_string = self.logger_in_socket.getsockopt_string( + SocketOption.LAST_ENDPOINT + ) + self.logger_in_port = int(logger_in_port_string.split(":")[-1]) + + self.broker_in_socket = self.zmq_sync_context.socket(REP) + self.broker_in_socket.bind("tcp://*:0") + broker_in_port_string = self.broker_in_socket.getsockopt_string( + SocketOption.LAST_ENDPOINT + ) + self.broker_in_port = int(broker_in_port_string.split(":")[-1]) + + def start_improv_services(self, log_server_pub_port, store_size): + logger.debug("Starting logger") + self.start_logger(log_server_pub_port) + logger.addHandler( + log.ZmqLogHandler("localhost", self.logger_pull_port, self.zmq_sync_context) + ) + + logger.debug("starting broker") + self.start_message_broker() + + logger.debug("Parsing redis persistence") + self.configure_redis_persistence() + + logger.debug("Starting redis server") + # default size should be system-dependent + self._start_store_interface(store_size) + logger.info("Redis server started") + + self.out_socket.send_string("StoreInterface started") + + if self.config.settings["harvest_data_from_memory"]: + logger.debug("starting harvester") + self.start_harvester() + + # connect to store and subscribe to notifications + logger.info("Create new store object") + self.store = StoreInterface(server_port_num=self.store_port) + logger.info(f"Redis server connected on port {self.store_port}") + logger.info("all services started") + + def apply_cli_config_overrides( + self, store_size, control_port, output_port, actor_in_port + ): + if store_size is not None: + self.config.settings["store_size"] = store_size + if control_port is not None: + self.config.settings["control_port"] = control_port + if output_port is not None: + self.config.settings["output_port"] = output_port + if actor_in_port is not None: + self.config.settings["actor_in_port"] = actor_in_port + + def start_harvester(self): + spawn_context = get_context("spawn") + self.p_harvester = spawn_context.Process( + target=bootstrap_harvester, + args=( + "localhost", + self.broker_in_port, + "localhost", + self.store_port, + "localhost", + self.broker_pub_port, + "localhost", + self.logger_pull_port, + ), + ) + self.p_harvester.start() + time.sleep(1) + if not self.p_harvester.is_alive(): + logger.error( + "Harvester process failed to start. " + "Please see the log file for more information. " + "The improv server will now exit." 
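`set_up_sockets` binds the logger and broker handshake sockets to port 0, which asks the OS for any free port, then reads the chosen port back from `LAST_ENDPOINT`. The idiom in isolation:

```python
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.REP)
sock.bind("tcp://*:0")  # port 0: let the OS pick any free port
endpoint = sock.getsockopt_string(zmq.LAST_ENDPOINT)  # e.g. "tcp://0.0.0.0:49321"
port = int(endpoint.split(":")[-1])
print(f"handshake socket bound on port {port}")
sock.close(linger=0)
ctx.term()
```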
+            )
+            self.quit()
+            raise Exception("Could not start harvester server.")
+
+        harvester_info: HarvesterInfoMsg = self.broker_in_socket.recv_pyobj()
+        self.broker_in_socket.send_pyobj(
+            HarvesterInfoReplyMsg(
+                harvester_info.name, "OK", "registered harvester information"
+            )
+        )
+        logger.info("Harvester server started")
+
+    async def serve(self, serve_function, *args, **kwargs):
+        await serve_function(*args, **kwargs)
+
+    async def poll_kernel(self):
+        try:
+            done, pending = await asyncio.wait(
+                self.tasks, return_when=concurrent.futures.FIRST_COMPLETED
+            )
+        except asyncio.CancelledError:
+            return
+
+        # sort through tasks to see where we got input from
+        # (so we can choose a handler)
+        for i, t in enumerate(self.tasks):
+            if i == 0:
+                if t in done:
+                    self.tasks[i] = asyncio.create_task(self.process_actor_message())
+            elif t in done:
+                self.tasks[i] = asyncio.create_task(self.remote_input())
diff --git a/improv/store.py b/improv/store.py
index 98f54a63..86e9cf30 100644
--- a/improv/store.py
+++ b/improv/store.py
@@ -3,21 +3,14 @@
 import pickle
 import logging
 
-import traceback
-
-import numpy as np
-import pyarrow.plasma as plasma
+import zlib
 
 from redis import Redis
 from redis.retry import Retry
 from redis.backoff import ConstantBackoff
 from redis.exceptions import BusyLoadingError, ConnectionError, TimeoutError
 
-from scipy.sparse import csc_matrix
-from pyarrow.lib import ArrowIOError
-from pyarrow._plasma import PlasmaObjectExists, ObjectNotAvailable
-
-REDIS_GLOBAL_TOPIC = "global_topic"
+ZLIB_COMPRESSION_LEVEL = -1
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -43,17 +36,21 @@ def subscribe(self):
 
 
 class RedisStoreInterface(StoreInterface):
-    def __init__(self, name="default", server_port_num=6379, hostname="localhost"):
+    def __init__(
+        self,
+        name="default",
+        server_port_num=6379,
+        hostname="localhost",
+        compression_level=ZLIB_COMPRESSION_LEVEL,
+    ):
         self.name = name
         self.server_port_num = server_port_num
         self.hostname = hostname
         self.client = self.connect_to_server()
+        self.compression_level = compression_level
 
     def connect_to_server(self):
-        # TODO this should scan for available ports, but only if configured to do so.
-        # This happens when the config doesn't have Redis settings,
-        # so we need to communicate this somehow to the StoreInterface here.
-        """Connect to the store at store_loc, max 20 retries to connect
+        """Connect to the store, max 20 retries to connect
         Raises exception if can't connect
         Returns the Redis client if successful
 
@@ -107,18 +104,11 @@ def put(self, object):
             object: the object that was a
         """
         object_key = str(os.getpid()) + str(uuid.uuid4())
-        try:
-            # buffers would theoretically go here if we need to force out-of-band
-            # serialization for single objects
-            # TODO this will actually just silently fail if we use an existing
-            # TODO key; not sure it's worth the network overhead to check every
-            # TODO key twice every time. we still need a better solution for
-            # TODO this, but it will work now singlethreaded most of the time.
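`poll_kernel` multiplexes two input sources, the actor-message socket and remote input: wait for whichever task finishes first, handle it, then immediately re-arm a fresh task for that source. The re-arming loop on its own, with dummy sources standing in for `process_actor_message` and `remote_input`:

```python
import asyncio
import concurrent.futures
import random

async def source(name):
    # Stand-in for process_actor_message / remote_input.
    await asyncio.sleep(random.random() / 10)
    return name

async def poll_loop(rounds=5):
    factories = [lambda: source("actor"), lambda: source("remote")]
    tasks = [asyncio.create_task(f()) for f in factories]
    for _ in range(rounds):
        done, _ = await asyncio.wait(
            tasks, return_when=concurrent.futures.FIRST_COMPLETED
        )
        for i, t in enumerate(tasks):
            if t in done:
                print("handled input from", t.result())
                tasks[i] = asyncio.create_task(factories[i]())  # re-arm this source
    for t in tasks:
        t.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)

asyncio.run(poll_loop())
```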
- - self.client.set(object_key, pickle.dumps(object, protocol=5), nx=True) - except Exception: - logger.error("Could not store object {}".format(object_key)) - logger.error(traceback.format_exc()) + data = zlib.compress( + pickle.dumps(object, protocol=5), level=self.compression_level + ) + + self.client.set(object_key, data, nx=True) return object_key @@ -138,14 +128,10 @@ def get(self, object_key): object_value = self.client.get(object_key) if object_value: # buffers would also go here to force out-of-band deserialization - return pickle.loads(object_value) + return pickle.loads(zlib.decompress(object_value)) logger.warning("Object {} cannot be found.".format(object_key)) - raise ObjectNotFoundError - - def subscribe(self, topic=REDIS_GLOBAL_TOPIC): - p = self.client.pubsub() - p.subscribe(topic) + raise ObjectNotFoundError(object_key) def get_list(self, ids): """Get multiple objects from the store @@ -156,7 +142,10 @@ def get_list(self, ids): Returns: list of the objects """ - return self.client.mget(ids) + return [ + pickle.loads(zlib.decompress(object_value)) + for object_value in self.client.mget(ids) + ] def get_all(self): """Get a listing of all objects in the store. @@ -166,7 +155,7 @@ def get_all(self): list of all the objects in the store """ all_keys = self.client.keys() # defaults to "*" pattern, so will fetch all - return self.client.mget(all_keys) + return self.get_list(all_keys) def reset(self): """Reset client connection""" @@ -175,226 +164,6 @@ def reset(self): "Reset local connection to store on port: {0}".format(self.server_port_num) ) - def notify(self): - pass # I don't see any call sites for this, so leaving it blank at the moment - - -class PlasmaStoreInterface(StoreInterface): - """Basic interface for our specific data store implemented with apache arrow plasma - Objects are stored with object_ids - References to objects are contained in a dict where key is shortname, - value is object_id - """ - - def __init__(self, name="default", store_loc="/tmp/store"): - """ - Constructor for the StoreInterface - - :param name: - :param store_loc: Apache Arrow Plasma client location - """ - - self.name = name - self.store_loc = store_loc - self.client = self.connect_store(store_loc) - self.stored = {} - - def connect_store(self, store_loc): - """Connect to the store at store_loc, max 20 retries to connect - Raises exception if can't connect - Returns the plasmaclient if successful - Updates the client internal - - Args: - store_loc: store location - """ - try: - self.client = plasma.connect(store_loc, 20) - logger.info("Successfully connected to store: {} ".format(store_loc)) - except Exception: - logger.exception("Cannot connect to store: {}".format(store_loc)) - raise CannotConnectToStoreInterfaceError(store_loc) - return self.client - - def put(self, object, object_name): - """ - Put a single object referenced by its string name - into the store - Raises PlasmaObjectExists if we are overwriting - Unknown error - - Args: - object: - object_name (str): - flush_this_immediately (bool): - - Returns: - class 'plasma.ObjectID': Plasma object ID - - Raises: - PlasmaObjectExists: if we are overwriting \ - unknown error - """ - object_id = None - try: - # Need to pickle if object is csc_matrix - if isinstance(object, csc_matrix): - prot = pickle.HIGHEST_PROTOCOL - object_id = self.client.put(pickle.dumps(object, protocol=prot)) - else: - object_id = self.client.put(object) - - except PlasmaObjectExists: - logger.error("Object already exists. 
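The Redis interface now compresses pickled payloads with zlib before `SET` and reverses the pair on `GET`; level -1 is zlib's default speed/size trade-off. The essential round trip, runnable against plain bytes with no server:

```python
import pickle
import zlib
import numpy as np

ZLIB_COMPRESSION_LEVEL = -1  # zlib default: balanced speed vs. size

def encode(obj):
    # pickle protocol 5 supports out-of-band buffers for large arrays
    return zlib.compress(pickle.dumps(obj, protocol=5), level=ZLIB_COMPRESSION_LEVEL)

def decode(data):
    return pickle.loads(zlib.decompress(data))

payload = np.random.randint(100, size=(100, 5))
blob = encode(payload)
assert (decode(blob) == payload).all()
print(f"{payload.nbytes} bytes raw -> {len(blob)} bytes stored")
```

Against a live server the blob simply replaces the raw pickle in `client.set(key, blob, nx=True)`; `nx=True` makes the SET a no-op when the key already exists, so an existing key is never silently overwritten.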
Meant to call replace?") - except ArrowIOError: - logger.error("Could not store object {}".format(object_name)) - logger.info("Refreshing connection and continuing") - self.reset() - except Exception: - logger.error("Could not store object {}".format(object_name)) - logger.error(traceback.format_exc()) - - return object_id - - def get(self, object_name): - """Get a single object from the store by object name - Checks to see if it knows the object first - Otherwise throw CannotGetObject to request dict update - TODO: update for lists of objects - TODO: replace with getID - - Returns: - Stored object - """ - # print('trying to get ', object_name) - # if self.stored.get(object_name) is None: - # logger.error('Never recorded storing this object: '+object_name) - # # Don't know anything about this object, treat as problematic - # raise CannotGetObjectError(query = object_name) - # else: - return self.getID(object_name) - - def getID(self, obj_id): - """ - Get object by object ID - - Args: - obj_id (class 'plasma.ObjectID'): the id of the object - - Returns: - Stored object - - Raises: - ObjectNotFoundError: If the id is not found - """ - res = self.client.get(obj_id, 0) # Timeout = 0 ms - if res is not plasma.ObjectNotAvailable: - return res if not isinstance(res, bytes) else pickle.loads(res) - - logger.warning("Object {} cannot be found.".format(obj_id)) - raise ObjectNotFoundError - - def getList(self, ids): - """Get multiple objects from the store - - Args: - ids (list): of type plasma.ObjectID - - Returns: - list of the objects - """ - # self._get() - return self.client.get(ids) - - def get_all(self): - """Get a listing of all objects in the store - - Returns: - list of all the objects in the store - """ - return self.client.list() - - def reset(self): - """Reset client connection""" - self.client = self.connect_store(self.store_loc) - logger.debug("Reset local connection to store: {0}".format(self.store_loc)) - - def release(self): - self.client.disconnect() - - # Subscribe to notifications about sealed objects? - def subscribe(self): - """Subscribe to a section? of the ds for singals - - Raises: - Exception: Unknown error - """ - try: - self.client.subscribe() - except Exception as e: - logger.error("Unknown error: {}".format(e)) - raise Exception - - def notify(self): - try: - notification_info = self.client.get_next_notification() - # recv_objid, recv_dsize, recv_msize = notification_info - except ArrowIOError: - notification_info = None - except Exception as e: - logger.exception("Notification error: {}".format(e)) - raise Exception - - return notification_info - - # Necessary? 
plasma.ObjectID.from_random() - def random_ObjectID(self, number=1): - ids = [] - for i in range(number): - ids.append(plasma.ObjectID(np.random.bytes(20))) - return ids - - def updateStoreInterfaced(self, object_name, object_id): - """Update local dict with info we need locally - Report to Nexus that we updated the store - (did a put or delete/replace) - - Args: - object_name (str): the name of the object to update - object_id (): the id of the object to update - """ - self.stored.update({object_name: object_id}) - - def getStored(self): - """ - Returns: - its info about what it has stored - """ - return self.stored - - def _put(self, obj, id): - """Internal put""" - return self.client.put(obj, id) - - def _get(self, object_name): - """Get an object from the store using its name - Assumes we know the id for the object_name - - Raises: - ObjectNotFound: if object_id returns no object from the store - """ - # Most errors not shown to user. - # Maintain separation between external and internal function calls. - res = self.getID(self.stored.get(object_name)) - # Can also use contains() to check - - logger.warning("{}".format(object_name)) - if isinstance(res, ObjectNotAvailable): - logger.warning("Object {} cannot be found.".format(object_name)) - raise ObjectNotFoundError(obj_id_or_name=object_name) # TODO: Don't raise? - else: - return res - StoreInterface = RedisStoreInterface @@ -428,12 +197,12 @@ def __str__(self): class CannotConnectToStoreInterfaceError(Exception): """Raised when failing to connect to store.""" - def __init__(self, store_loc): + def __init__(self, store_port): super().__init__() self.name = "CannotConnectToStoreInterfaceError" - self.message = "Cannot connect to store at {}".format(str(store_loc)) + self.message = "Cannot connect to store at {}".format(str(store_port)) def __str__(self): return self.message diff --git a/improv/utils/checks.py b/improv/utils/checks.py index a72f4657..97c06f24 100644 --- a/improv/utils/checks.py +++ b/improv/utils/checks.py @@ -32,10 +32,17 @@ def check_if_connections_acyclic(path_to_yaml): # Need to keep only module names connections = {} - for key, values in raw.items(): - new_key = key.split(".")[0] - new_values = [value.split(".")[0] for value in values] - connections[new_key] = new_values + for connection, values in raw.items(): + for source in values["sources"]: + source_name = source.split(".")[0] + if source_name not in connections.keys(): + connections[source_name] = [ + sink_name.split(".")[0] for sink_name in values["sinks"] + ] + else: + connections[source_name] = connections[source_name] + [ + sink_name.split(".")[0] for sink_name in values["sinks"] + ] g = nx.DiGraph(connections) dag = nx.is_directed_acyclic_graph(g) diff --git a/improv/watcher.py b/improv/watcher.py deleted file mode 100644 index ad2dbf1f..00000000 --- a/improv/watcher.py +++ /dev/null @@ -1,175 +0,0 @@ -import numpy as np -import asyncio - -# import pyarrow.plasma as plasma -# from multiprocessing import Process, Queue, Manager, cpu_count, set_start_method -# import subprocess -# import signal -# import time -from queue import Empty -import logging - -import concurrent - -# from pyarrow.plasma import ObjectNotAvailable -from improv.actor import Actor, Signal, RunManager -from improv.store import ObjectNotFoundError -import pickle - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class BasicWatcher(Actor): - """ - Actor that monitors stored objects from the other actors - and saves objects that have been flagged by those actors - 
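The updated `check_if_connections_acyclic` flattens the `sources`/`sinks` endpoint lists to bare actor names before building the digraph. A condensed sketch of the same check:

```python
import networkx as nx

def connections_acyclic(raw):
    # Collapse "Actor.queue" endpoints to actor names, then test the digraph.
    edges = {}
    for values in raw.values():
        sinks = [s.split(".")[0] for s in values["sinks"]]
        for source in values["sources"]:
            src = source.split(".")[0]
            edges[src] = edges.get(src, []) + sinks
    return nx.is_directed_acyclic_graph(nx.DiGraph(edges))

good = {"A-B": {"sources": ["Acq.q_out"], "sinks": ["Ana.q_in"]}}
cyclic = {**good, "B-A": {"sources": ["Ana.q_out"], "sinks": ["Acq.q_in"]}}
print(connections_acyclic(good), connections_acyclic(cyclic))  # True False
```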
""" - - def __init__(self, *args, inputs=None): - super().__init__(*args) - - self.watchin = inputs - - def setup(self): - """ - set up tasks and polling based on inputs which will - be used for asynchronous polling of input queues - """ - self.numSaved = 0 - self.tasks = [] - self.polling = self.watchin - self.setUp = False - - def run(self): - """ - continually run the watcher to check all of the - input queues for objects to save - """ - - with RunManager( - self.name, self.watchrun, self.setup, self.q_sig, self.q_comm - ) as rm: - logger.info(rm) - - print("watcher saved " + str(self.numSaved) + " objects") - - def watchrun(self): - """ - set up async loop for polling - """ - loop = asyncio.get_event_loop() - loop.run_until_complete(self.watch()) - - async def watch(self): - """ - function for asynchronous polling of input queues - loops through each of the queues in watchin and checks - if an object is present and then saves the object if found - """ - - if self.setUp is False: - for q in self.polling: - self.tasks.append(asyncio.create_task(q.get_async())) - self.setUp = True - - done, pending = await asyncio.wait( - self.tasks, return_when=concurrent.futures.FIRST_COMPLETED - ) - - for i, t in enumerate(self.tasks): - if t in done or self.polling[i].status == "done": - r = self.polling[i].result # r is array with id and name of object - actorID = self.polling[ - i - ].getStart() # name of actor asking watcher to save the object - try: - obj = self.client.getID(r[0]) - np.save("output/saved/" + actorID + r[1], obj) - except ObjectNotFoundError as e: - logger.info(e.message) - pass - self.tasks[i] = asyncio.create_task(self.polling[i].get_async()) - - -class Watcher: - """Monitors the store as separate process - TODO: Facilitate Watcher being used in multiple processes (shared list) - """ - - # Related to subscribe - could be private, i.e., _subscribe - def __init__(self, name, client): - self.name = name - self.client = client - self.flag = False - self.saved_ids = [] - - self.client.subscribe() - self.n = 0 - - def setLinks(self, links): - self.q_sig = links - - def run(self): - while True: - if self.flag: - try: - self.checkStoreInterface2() - except Exception as e: - logger.error("Watcher exception during run: {}".format(e)) - # break - try: - signal = self.q_sig.get(timeout=0.005) - if signal == Signal.run(): - self.flag = True - logger.warning("Received run signal, begin running") - elif signal == Signal.quit(): - logger.warning("Received quit signal, aborting") - break - elif signal == Signal.pause(): - logger.warning("Received pause signal, pending...") - self.flag = False - elif signal == Signal.resume(): # currently treat as same as run - logger.warning("Received resume signal, resuming") - self.flag = True - except Empty: - pass # no signal from Nexus - - # def checkStoreInterface(self): - # notification_info = self.client.notify() - # recv_objid, recv_dsize, recv_msize = notification_info - # obj = self.client.getID(recv_objid) - # try: - # self.saveObj(obj) - # self.n += 1 - # except Exception as e: - # logger.error('Watcher error: {}'.format(e)) - - def saveObj(self, obj, name): - with open( - "/media/hawkwings/Ext Hard Drive/dump/dump" + name + ".pkl", "wb" - ) as output: - pickle.dump(obj, output) - - def checkStoreInterface2(self): - objs = list(self.client.get_all().keys()) - ids_to_save = list(set(objs) - set(self.saved_ids)) - - # with Pool() as pool: - # saved_ids = pool.map(saveObjbyID, ids_to_save) - # print('Saved :', len(saved_ids)) - # 
self.saved_ids.extend(saved_ids) - - for id in ids_to_save: - self.saveObj(self.client.getID(id), str(id)) - self.saved_ids.append(id) - - -# def saveObjbyID(id): -# client = plasma.connect('/tmp/store') -# obj = client.get(id) -# with open( -# '/media/hawkwings/Ext\ Hard\ Drive/dump/dump'+str(id)+'.pkl', 'wb' -# ) as output: -# pickle.dump(obj, output) -# return id diff --git a/pyproject.toml b/pyproject.toml index bf5a42b7..fd185aad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ "numpy<=1.26", "scipy", "matplotlib", - "pyarrow==9.0.0", "PyQt5", "pyyaml", "textual==0.15.0", @@ -56,6 +55,7 @@ exclude = ["test", "pytest", "env", "demos", "figures"] [tool.pytest.ini_options] asyncio_mode = "auto" filterwarnings = [ ] +asyncio_default_fixture_loop_scope = "function" log_cli = true log_cli_level = "INFO" diff --git a/test/actors/actor_for_bad_args.py b/test/actors/actor_for_bad_args.py new file mode 100644 index 00000000..bac95922 --- /dev/null +++ b/test/actors/actor_for_bad_args.py @@ -0,0 +1,3 @@ +class BadActorArgs: + def __init__(self, existent_parameter, another_required_param): + pass diff --git a/test/actors/sample_generator.py b/test/actors/sample_generator.py index 9d1c3891..d57a5b28 100644 --- a/test/actors/sample_generator.py +++ b/test/actors/sample_generator.py @@ -1,4 +1,4 @@ -from improv.actor import Actor +from improv.actor import ZmqActor from datetime import date # used for saving import numpy as np import logging @@ -7,7 +7,7 @@ logger.setLevel(logging.INFO) -class Generator(Actor): +class Generator(ZmqActor): """Sample actor to generate data to pass into a sample processor. Intended for use along with sample_processor.py. diff --git a/test/actors/sample_generator_wrong_init.py b/test/actors/sample_generator_wrong_init.py index 9d1c3891..405ab92f 100644 --- a/test/actors/sample_generator_wrong_init.py +++ b/test/actors/sample_generator_wrong_init.py @@ -1,76 +1,2 @@ -from improv.actor import Actor -from datetime import date # used for saving -import numpy as np -import logging - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class Generator(Actor): - """Sample actor to generate data to pass into a sample processor. - - Intended for use along with sample_processor.py. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.data = None - self.name = "Generator" - self.frame_num = 0 - - def __str__(self): - return f"Name: {self.name}, Data: {self.data}" - - def setup(self): - """Generates an array that serves as an initial source of data. - - Initial array is a 100 row, 5 column numpy matrix that contains - integers from 1-99, inclusive. - """ - - self.data = np.asmatrix(np.random.randint(100, size=(100, 5))) - logger.info("Completed setup for Generator") - - # def run(self): - # """ Send array into the store. - # """ - # self.fcns = {} - # self.fcns['setup'] = self.setup - # self.fcns['run'] = self.runStep - # self.fcns['stop'] = self.stop - - # with RunManager(self.name, self.fcns, self.links) as rm: - # logger.info(rm) - - def stop(self): - """Save current randint vector to a file.""" - - print("Generator stopping") - np.save(f"sample_generator_data_{date.today()}", self.data) - # This is not the best example of a save function, - # will overwrite previous files with the same name. - return 0 - - def runStep(self): - """Generates additional data after initial setup data is exhausted. 
- - Data is of a different form as the setup data in that although it is - the same size (5x1 vector), it is uniformly distributed in [1, 10] - instead of in [1, 100]. Therefore, the average over time should - converge to 5.5. - """ - - if self.frame_num < np.shape(self.data)[0]: - data_id = self.client.put( - self.data[self.frame_num], str(f"Gen_raw: {self.frame_num}") - ) - try: - self.q_out.put([[data_id, str(self.frame_num)]]) - self.frame_num += 1 - except Exception as e: - logger.error(f"Generator Exception: {e}") - else: - self.data = np.concatenate( - (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0 - ) +class GeneratorActor: + pass diff --git a/demos/minimal/actors/sample_generator.py b/test/actors/sample_generator_zmq.py similarity index 62% rename from demos/minimal/actors/sample_generator.py rename to test/actors/sample_generator_zmq.py index 31335833..0bc5b4f5 100644 --- a/demos/minimal/actors/sample_generator.py +++ b/test/actors/sample_generator_zmq.py @@ -1,4 +1,5 @@ -from improv.actor import Actor +from improv.actor import ZmqActor +from datetime import date # used for saving import numpy as np import logging @@ -6,7 +7,7 @@ logger.setLevel(logging.INFO) -class Generator(Actor): +class Generator(ZmqActor): """Sample actor to generate data to pass into a sample processor. Intended for use along with sample_processor.py. @@ -28,18 +29,30 @@ def setup(self): integers from 1-99, inclusive. """ - logger.info("Beginning setup for Generator") self.data = np.asmatrix(np.random.randint(100, size=(100, 5))) logger.info("Completed setup for Generator") + # def run(self): + # """ Send array into the store. + # """ + # self.fcns = {} + # self.fcns['setup'] = self.setup + # self.fcns['run'] = self.runStep + # self.fcns['stop'] = self.stop + + # with RunManager(self.name, self.fcns, self.links) as rm: + # logger.info(rm) + def stop(self): """Save current randint vector to a file.""" - logger.info("Generator stopping") - np.save("sample_generator_data.npy", self.data) + print("Generator stopping") + np.save(f"sample_generator_data_{date.today()}", self.data) + # This is not the best example of a save function, + # will overwrite previous files with the same name. return 0 - def runStep(self): + def run_step(self): """Generates additional data after initial setup data is exhausted. 
Data is of a different form as the setup data in that although it is @@ -49,25 +62,14 @@ def runStep(self): """ if self.frame_num < np.shape(self.data)[0]: - if self.store_loc: - data_id = self.client.put( - self.data[self.frame_num], str(f"Gen_raw: {self.frame_num}") - ) - else: - data_id = self.client.put(self.data[self.frame_num]) - # logger.info('Put data in store') + data_id = self.client.put(self.data[self.frame_num]) try: - if self.store_loc: - self.q_out.put([[data_id, str(self.frame_num)]]) - else: - self.q_out.put(data_id) - # logger.info("Sent message on") - + self.q_out.put(data_id) + # logger.info(f"Sent {self.data[self.frame_num]} with key {data_id}") self.frame_num += 1 + except Exception as e: - logger.error( - f"--------------------------------Generator Exception: {e}" - ) + logger.error(f"Generator Exception: {e}") else: self.data = np.concatenate( (self.data, np.asmatrix(np.random.randint(10, size=(1, 5)))), axis=0 diff --git a/demos/zmq/actors/zmq_rr_sample_processor.py b/test/actors/sample_processor_zmq.py similarity index 58% rename from demos/zmq/actors/zmq_rr_sample_processor.py rename to test/actors/sample_processor_zmq.py index e4ed9cd3..1948e778 100644 --- a/demos/zmq/actors/zmq_rr_sample_processor.py +++ b/test/actors/sample_processor_zmq.py @@ -1,34 +1,33 @@ -from improv.actor import Actor, AsyncActor +from improv.actor import ZmqActor import numpy as np import logging -from demos.sample_actors.zmqActor import ZmqActor - logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class Processor(ZmqActor): - """Sample processor used to calculate the average of an array of integers - using async ZMQ to communicate. + """Sample processor used to calculate the average of an array of integers. Intended for use with sample_generator.py. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + if "name" in kwargs: + self.name = kwargs["name"] def setup(self): """Initializes all class variables. - Sets up a ZmqRRActor to receive data from the generator. self.name (string): name of the actor. - self.frame (ObjectID): Store object id referencing data from the store. + self.frame (ObjectID): StoreInterface object id referencing data from the store. self.avg_list (list): list that contains averages of individual vectors. self.frame_num (int): index of current frame. """ - self.name = "Processor" + if not hasattr(self, "name"): + self.name = "Processor" self.frame = None self.avg_list = [] self.frame_num = 1 @@ -38,10 +37,21 @@ def stop(self): """Trivial stop function for testing purposes.""" logger.info("Processor stopping") + return 0 + + # def run(self): + # """ Send array into the store. + # """ + # self.fcns = {} + # self.fcns['setup'] = self.setup + # self.fcns['run'] = self.runStep + # self.fcns['stop'] = self.stop + + # with RunManager(self.name, self.fcns, self.links) as rm: + # logger.info(rm) - def runStep(self): + def run_step(self): """Gets from the input queue and calculates the average. - Receives data from the generator using a ZmqRRActor. Receives an ObjectID, references data in the store using that ObjectID, calculates the average of that data, and finally prints @@ -49,21 +59,19 @@ def runStep(self): """ frame = None - try: - frame = self.get(reply='received') - - except: - logger.error("Could not get frame!") + frame = self.q_in.get(timeout=0.05) + except Exception as e: + logger.error(f"{self.name} could not get frame! 
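The migrated ZMQ actors keep the same data path as before: the generator `put`s a row into the store and ships only the key over `q_out`; the processor `get`s the key, fetches the row, and averages it. That contract in isolation, with hypothetical in-memory stand-ins for the store client and queue:

```python
import numpy as np

store, queue = {}, []  # hypothetical stand-ins for StoreInterface and q_out/q_in

def generator_step(data, frame_num):
    key = f"frame-{frame_num}"
    store[key] = data[frame_num]      # client.put(...) in the real actor
    queue.append(key)                 # q_out.put(key)

def processor_step(avg_list):
    key = queue.pop(0)                # q_in.get(timeout=...)
    frame = store[key]                # client.get(key)
    avg_list.append(np.mean(frame))

data = np.random.randint(100, size=(100, 5))
avgs = []
for i in range(3):
    generator_step(data, i)
    processor_step(avgs)
print(avgs)
```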
At {self.frame_num}: {e}") pass if frame is not None and self.frame_num is not None: self.done = False - self.frame = self.client.getID(frame) + self.frame = self.client.get(frame) avg = np.mean(self.frame[0]) - + # logger.info(f"{self.name} got frame {frame} with value {self.frame}") # logger.info(f"Average: {avg}") self.avg_list.append(avg) - logger.info(f"Overall Average: {np.mean(self.avg_list)}") + # logger.info(f"Overall Average: {np.mean(self.avg_list)}") # logger.info(f"Frame number: {self.frame_num}") self.frame_num += 1 diff --git a/test/configs/bad_args.yaml b/test/configs/bad_args.yaml index be94212c..ce96b6b6 100644 --- a/test/configs/bad_args.yaml +++ b/test/configs/bad_args.yaml @@ -1,20 +1,7 @@ actors: Acquirer: - package: demos.sample_actors.acquire - class: FileAcquirer - fiasdfe: data/Tolias_mesoscope_2.hdf5 - fraefawe: 30 - - Processor: - package: demos.sample_actors.process - class: CaimanProcessor - init_filename: data/tbif_ex_crop.h5 - config_file: eva_caiman_params.txt - - Analysis: - package: demos.sample_actors.analysis - class: MeanAnalysis + package: actors.actor_for_bad_args + class: BadActorArgs + existent_parameter: 42 connections: - Acquirer.q_out: [Processor.q_in] - Processor.q_out: [Analysis.q_in] diff --git a/test/configs/complex_graph.yaml b/test/configs/complex_graph.yaml index 3c82029a..16961825 100644 --- a/test/configs/complex_graph.yaml +++ b/test/configs/complex_graph.yaml @@ -14,8 +14,14 @@ actors: class: BehaviorAcquirer connections: - Acquirer.q_out: [Analysis.q_in, InputStim.q_in] - Analysis.q_out: [InputStim.q_in] - -# settings: -# use_watcher: [Acquirer, Processor, Visual, Analysis] + Acquirer-out: + sources: + - Acquirer.q_out + sinks: + - Analysis.q_in + - InputStim.q_in + Analysis-InputStim: + sources: + - Analysis.q_out + sinks: + - InputStim.q_in diff --git a/test/configs/cyclic_config.yaml b/test/configs/cyclic_config.yaml index 61c91303..8e87b880 100644 --- a/test/configs/cyclic_config.yaml +++ b/test/configs/cyclic_config.yaml @@ -14,5 +14,14 @@ actors: class: BehaviorAcquirer connections: - Acquirer.q_out: [Analysis.q_in, InputStim.q_in] - Analysis.q_out: [Acquirer.q_in] + Acquirer-out: + sources: + - Acquirer.q_out + sinks: + - Analysis.q_in + - InputStim.q_in + Analysis-Acquirer: + sources: + - Analysis.q_out + sinks: + - Acquirer.q_in diff --git a/test/configs/good_config.yaml b/test/configs/good_config.yaml index b0c70b70..6f72da5c 100644 --- a/test/configs/good_config.yaml +++ b/test/configs/good_config.yaml @@ -10,4 +10,8 @@ actors: class: SimpleAnalysis connections: - Acquirer.q_out: [Analysis.q_in] + Acquirer-Analysis: + sources: + - Acquirer.q_out + sinks: + - Analysis.q_in diff --git a/test/configs/good_config_actors.yaml b/test/configs/good_config_actors.yaml deleted file mode 100644 index e84fd262..00000000 --- a/test/configs/good_config_actors.yaml +++ /dev/null @@ -1,8 +0,0 @@ -Acquirer: - class: FileAcquirer - filename: data/Tolias_mesoscope_2.hdf5 - framerate: 30 - package: demos.sample_actors.acquire -Analysis: - class: SimpleAnalysis - package: demos.sample_actors.simple_analysis diff --git a/test/configs/good_config_plasma.yaml b/test/configs/good_config_plasma.yaml deleted file mode 100644 index 8323017b..00000000 --- a/test/configs/good_config_plasma.yaml +++ /dev/null @@ -1,15 +0,0 @@ -actors: - Acquirer: - package: demos.sample_actors.acquire - class: FileAcquirer - filename: data/Tolias_mesoscope_2.hdf5 - framerate: 30 - - Analysis: - package: demos.sample_actors.simple_analysis - class: SimpleAnalysis - 
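Every test config migrates from the flat `Source.q_out: [Sink.q_in]` mapping to named connections with explicit `sources` and `sinks` lists. Parsing both shapes side by side (inline YAML for illustration only):

```python
import yaml

OLD = """
connections:
  Generator.q_out: [Processor.q_in]
"""

NEW = """
connections:
  Generator-Processor:
    sources:
      - Generator.q_out
    sinks:
      - Processor.q_in
"""

old = yaml.safe_load(OLD)["connections"]
new = yaml.safe_load(NEW)["connections"]
# The new schema names each connection and allows many-to-many fan-out.
assert new["Generator-Processor"]["sources"] == list(old.keys())
assert new["Generator-Processor"]["sinks"] == old["Generator.q_out"]
print(new)
```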
-connections: - Acquirer.q_out: [Analysis.q_in] - -plasma_config: \ No newline at end of file diff --git a/test/configs/minimal.yaml b/test/configs/minimal.yaml index 230282d1..d752b415 100644 --- a/test/configs/minimal.yaml +++ b/test/configs/minimal.yaml @@ -1,11 +1,15 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: [Processor.q_in] \ No newline at end of file + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in diff --git a/test/configs/minimal_gui.yaml b/test/configs/minimal_gui.yaml new file mode 100644 index 00000000..8dd8bec5 --- /dev/null +++ b/test/configs/minimal_gui.yaml @@ -0,0 +1,15 @@ +actors: + GUI: + package: actors.sample_generator_zmq + class: Generator + + Processor: + package: actors.sample_processor_zmq + class: Processor + +connections: + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in diff --git a/test/configs/minimal_harvester.yaml b/test/configs/minimal_harvester.yaml new file mode 100644 index 00000000..74847433 --- /dev/null +++ b/test/configs/minimal_harvester.yaml @@ -0,0 +1,18 @@ +actors: + Generator: + package: actors.sample_generator_zmq + class: Generator + + Processor: + package: actors.sample_processor_zmq + class: Processor + +connections: + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in + +settings: + harvest_data_from_memory: True \ No newline at end of file diff --git a/test/configs/minimal_with_custom_aof_dirname.yaml b/test/configs/minimal_with_custom_aof_dirname.yaml index 1bac5863..f07c0c44 100644 --- a/test/configs/minimal_with_custom_aof_dirname.yaml +++ b/test/configs/minimal_with_custom_aof_dirname.yaml @@ -1,14 +1,18 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: [Processor.q_in] + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in redis_config: aof_dirname: custom_aof_dirname \ No newline at end of file diff --git a/test/configs/minimal_with_ephemeral_aof_dirname.yaml b/test/configs/minimal_with_ephemeral_aof_dirname.yaml index 9933896e..d4f5a635 100644 --- a/test/configs/minimal_with_ephemeral_aof_dirname.yaml +++ b/test/configs/minimal_with_ephemeral_aof_dirname.yaml @@ -1,14 +1,18 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: [Processor.q_in] + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in redis_config: generate_ephemeral_aof_dirname: True \ No newline at end of file diff --git a/test/configs/minimal_with_every_second_saving.yaml b/test/configs/minimal_with_every_second_saving.yaml index 74d4b314..6d381f72 100644 --- a/test/configs/minimal_with_every_second_saving.yaml +++ b/test/configs/minimal_with_every_second_saving.yaml @@ -1,14 +1,18 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: 
[Processor.q_in] + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in redis_config: enable_saving: True diff --git a/test/configs/minimal_with_every_write_saving.yaml b/test/configs/minimal_with_every_write_saving.yaml index add0c74f..4f1e3af4 100644 --- a/test/configs/minimal_with_every_write_saving.yaml +++ b/test/configs/minimal_with_every_write_saving.yaml @@ -1,14 +1,18 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: [Processor.q_in] + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in redis_config: enable_saving: True diff --git a/test/configs/minimal_with_fixed_default_redis_port.yaml b/test/configs/minimal_with_fixed_default_redis_port.yaml index c8e0b24e..8d33ded9 100644 --- a/test/configs/minimal_with_fixed_default_redis_port.yaml +++ b/test/configs/minimal_with_fixed_default_redis_port.yaml @@ -1,14 +1,18 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: [Processor.q_in] + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in redis_config: port: 6379 \ No newline at end of file diff --git a/test/configs/minimal_with_fixed_redis_port.yaml b/test/configs/minimal_with_fixed_redis_port.yaml index f6f98253..b63471f4 100644 --- a/test/configs/minimal_with_fixed_redis_port.yaml +++ b/test/configs/minimal_with_fixed_redis_port.yaml @@ -1,14 +1,18 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: [Processor.q_in] + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in redis_config: port: 6378 \ No newline at end of file diff --git a/test/configs/minimal_with_no_schedule_saving.yaml b/test/configs/minimal_with_no_schedule_saving.yaml index 2e5f62f3..20731f3d 100644 --- a/test/configs/minimal_with_no_schedule_saving.yaml +++ b/test/configs/minimal_with_no_schedule_saving.yaml @@ -1,14 +1,18 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: [Processor.q_in] + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in redis_config: enable_saving: True diff --git a/test/configs/minimal_with_redis_saving.yaml b/test/configs/minimal_with_redis_saving.yaml index 15e95b6c..087f685d 100644 --- a/test/configs/minimal_with_redis_saving.yaml +++ b/test/configs/minimal_with_redis_saving.yaml @@ -1,14 +1,18 @@ actors: Generator: - package: actors.sample_generator + package: actors.sample_generator_zmq class: Generator Processor: - package: actors.sample_processor + package: actors.sample_processor_zmq class: Processor connections: - Generator.q_out: [Processor.q_in] + Generator-Processor: + sources: + - Generator.q_out + sinks: + - Processor.q_in redis_config: enable_saving: True \ No newline at end of file diff --git a/test/configs/minimal_with_settings.yaml 
diff --git a/test/configs/minimal_with_settings.yaml b/test/configs/minimal_with_settings.yaml
index a5aec2d8..daff183d 100644
--- a/test/configs/minimal_with_settings.yaml
+++ b/test/configs/minimal_with_settings.yaml
@@ -4,16 +4,19 @@ settings:
   control_port: 6000
   output_port: 6001
   logging_port: 6002
-  use_watcher: false
 
 actors:
   Generator:
-    package: actors.sample_generator
+    package: actors.sample_generator_zmq
     class: Generator
 
   Processor:
-    package: actors.sample_processor
+    package: actors.sample_processor_zmq
     class: Processor
 
 connections:
-  Generator.q_out: [Processor.q_in]
\ No newline at end of file
+  Generator-Processor:
+    sources:
+      - Generator.q_out
+    sinks:
+      - Processor.q_in
diff --git a/test/configs/minimal_wrong_import.yaml b/test/configs/minimal_wrong_import.yaml
index 6c5e51c5..3fbc7822 100644
--- a/test/configs/minimal_wrong_import.yaml
+++ b/test/configs/minimal_wrong_import.yaml
@@ -4,8 +4,12 @@ actors:
     class: Generator
 
   Processor:
-    package: actors.sample_processor
+    package: actors.sample_processor_zmq
     class: Processor
 
 connections:
-  Generator.q_out: [Processor.q_in]
\ No newline at end of file
+  Generator-Processor:
+    sources:
+      - Generator.q_out
+    sinks:
+      - Processor.q_in
\ No newline at end of file
diff --git a/test/configs/minimal_wrong_init.yaml b/test/configs/minimal_wrong_init.yaml
index a59c954a..edb45b51 100644
--- a/test/configs/minimal_wrong_init.yaml
+++ b/test/configs/minimal_wrong_init.yaml
@@ -4,8 +4,12 @@ actors:
     class: Generator
 
   Processor:
-    package: actors.sample_processor
+    package: actors.sample_processor_zmq
     class: Processor
 
 connections:
-  Generator.q_out: [Processor.q_in]
\ No newline at end of file
+  Generator-Processor:
+    sources:
+      - Generator.q_out
+    sinks:
+      - Processor.q_in
\ No newline at end of file
diff --git a/test/configs/simple_graph.yaml b/test/configs/simple_graph.yaml
index 67086def..9238739c 100644
--- a/test/configs/simple_graph.yaml
+++ b/test/configs/simple_graph.yaml
@@ -10,7 +10,8 @@ actors:
     class: SimpleAnalysis
 
 connections:
-  Acquirer.q_out: [Analysis.q_in]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
+  Acquirer-Analysis:
+    sources:
+      - Acquirer.q_out
+    sinks:
+      - Analysis.q_in
diff --git a/test/configs/single_actor.yaml b/test/configs/single_actor.yaml
index 812480b1..c9c6b0e2 100644
--- a/test/configs/single_actor.yaml
+++ b/test/configs/single_actor.yaml
@@ -6,5 +6,3 @@ actors:
     framerate: 15
 
 connections:
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/configs/single_actor_plasma.yaml b/test/configs/single_actor_plasma.yaml
deleted file mode 100644
index 812480b1..00000000
--- a/test/configs/single_actor_plasma.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-actors:
-  Acquirer:
-    package: demos.sample_actors.acquire
-    class: FileAcquirer
-    filename: data/Tolias_mesoscope_2.hdf5
-    framerate: 15
-
-connections:
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/conftest.py b/test/conftest.py
index f9b80fa5..d123fcdc 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,17 +1,79 @@
+import logging
+import multiprocessing
 import os
 import signal
-import uuid
+import time
+
 import pytest
 import subprocess
 
-store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
+import zmq
+
+from improv.actor import ZmqActor
+from improv.harvester import bootstrap_harvester
+from improv.nexus import Nexus
+
 redis_port_num = 6379
 
-WAIT_TIMEOUT = 120
+WAIT_TIMEOUT = 10
+
+SERVER_COUNTER = 0
+
+signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
+
+
+@pytest.fixture
+def ports():
+    global SERVER_COUNTER
+    CONTROL_PORT = 30000
+    OUTPUT_PORT = 30001
+    LOGGING_PORT = 30002
+    ACTOR_IN_PORT = 30003
+    yield (
+        CONTROL_PORT + SERVER_COUNTER,
+        OUTPUT_PORT + SERVER_COUNTER,
+        LOGGING_PORT + SERVER_COUNTER,
+        ACTOR_IN_PORT + SERVER_COUNTER,
+    )
+    SERVER_COUNTER += 4
+
+
+@pytest.fixture
+def setdir():
+    prev = os.getcwd()
+    os.chdir(os.path.dirname(__file__) + "/configs")
+    yield None
+    os.chdir(prev)
+
+
+@pytest.fixture
+def set_dir_config_parent():
+    prev = os.getcwd()
+    os.chdir(os.path.dirname(__file__))
+    yield None
+    os.chdir(prev)
 
 
 @pytest.fixture
-def set_store_loc():
-    return store_loc
+def sample_nex(setdir, ports):
+    nex = Nexus("test")
+    try:
+        nex.create_nexus(
+            file="good_config.yaml",
+            store_size=40000000,
+            control_port=ports[0],
+            output_port=ports[1],
+        )
+    except Exception as e:
+        print(f"error caught in test harness during create_nexus step: {e}")
+        logging.error(f"error caught in test harness during create_nexus step: {e}")
+        raise e
+    yield nex
+    try:
+        nex.destroy_nexus()
+    except Exception as e:
+        print(f"error caught in test harness during destroy_nexus step: {e}")
+        logging.error(f"error caught in test harness during destroy_nexus step: {e}")
+        raise e
 
 
 @pytest.fixture
@@ -37,21 +99,96 @@ def setup_store(server_port_num):
         stderr=subprocess.DEVNULL,
     )
 
+    time.sleep(3)
+    if p.poll() is not None:
+        raise Exception("redis-server failed to start")
+
     yield p
 
     # kill the subprocess when the caller is done with it
     p.send_signal(signal.SIGINT)
-    p.wait(WAIT_TIMEOUT)
+    p.wait(timeout=WAIT_TIMEOUT)
 
 
-@pytest.fixture
-def setup_plasma_store(set_store_loc, scope="module"):
-    """Fixture to set up the store subprocess with 10 mb."""
-    p = subprocess.Popen(
-        ["plasma_store", "-s", set_store_loc, "-m", str(10000000)],
-        stdout=subprocess.DEVNULL,
-        stderr=subprocess.DEVNULL,
+def nex_startup(ports, filename):
+    nex = Nexus("test")
+    nex.create_nexus(
+        file=filename,
+        store_size=100000000,
+        control_port=ports[0],
+        output_port=ports[1],
+        actor_in_port=ports[3],
     )
+    nex.start_nexus()
+
+
+@pytest.fixture
+def start_nexus_minimal_zmq(ports):
+    filename = "minimal.yaml"
+    p = multiprocessing.Process(target=nex_startup, args=(ports, filename))
+    p.start()
+    time.sleep(1)
+
+    yield p
 
-    p.send_signal(signal.SIGINT)
-    p.wait(WAIT_TIMEOUT)
+    p.terminate()
+    p.join(WAIT_TIMEOUT)
+    if p.exitcode is None:
+        logging.exception("Timed out waiting for nexus to stop")
+        p.kill()
+
+
+@pytest.fixture
+def zmq_actor(ports):
+    actor = ZmqActor(ports[3], None, None, None, None, None, name="test")
+
+    p = multiprocessing.Process(target=actor_startup, args=(actor,))
+
+    yield p
+
+    p.terminate()
+    p.join(WAIT_TIMEOUT)
+    if p.exitcode is None:
+        p.kill()
+
+
+def actor_startup(actor):
+    actor.register_with_nexus()
+
+
+@pytest.fixture
+def harvester(ports):
+    ctx = zmq.Context()
+    socket = ctx.socket(zmq.PUB)
+    socket.bind("tcp://*:1234")
+    p = multiprocessing.Process(
+        target=bootstrap_harvester,
+        args=(
+            "localhost",
+            ports[3],
+            "localhost",
+            6379,
+            "localhost",
+            1234,
+            "localhost",
+            12345,
+        ),
+    )
+    p.start()
+    time.sleep(1)
+    yield ports, socket, p
+    socket.close(linger=0)
+    ctx.destroy(linger=0)
+
+
+class SignalManager:
+    def __init__(self):
+        self.signal_handlers = dict()
+
+    def __enter__(self):
+        for sig in signals:
+            self.signal_handlers[sig] = signal.getsignal(sig)
+
+    def __exit__(self, type, value, traceback):
+        for sig, handler in self.signal_handlers.items():
+            signal.signal(sig, handler)
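The SignalManager context manager added above only snapshots the SIGHUP/SIGTERM/SIGINT handlers on entry and restores them on exit, so tests that install their own handlers (the harvester tests later in this patch do) leave pytest's handlers intact. A minimal usage sketch, assuming it is imported from this conftest.py:

    import signal
    from conftest import SignalManager

    with SignalManager():
        # Any handler installed inside the block is temporary.
        signal.signal(signal.SIGTERM, lambda signum, frame: print("caught", signum))
        ...  # run code that overrides signal handling
    # On exit, the previously saved handlers are reinstated.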
diff --git a/test/conftest_with_errors.py b/test/conftest_with_errors.py
deleted file mode 100644
index c299959d..00000000
--- a/test/conftest_with_errors.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# import pytest
-import subprocess
-import asyncio
-from improv.actor import RunManager, AsyncRunManager
-import os
-import uuid
-
-
-class StoreInterfaceDependentTestCase:
-    def set_up(self):
-        """Start the server"""
-        print("Setting up Plasma store.")
-        store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
-        self.p = subprocess.Popen(
-            ["plasma_store", "-s", store_loc, "-m", str(10000000)],
-            stdout=subprocess.DEVNULL,
-            stderr=subprocess.DEVNULL,
-        )
-
-    def tear_down(self):
-        """Kill the server"""
-        print("Tearing down Plasma store.")
-        self.p.kill()
-        self.p.wait()
-
-
-class ActorDependentTestCase:
-    def set_up(self):
-        """Start the server"""
-        print("Setting up Plasma store.")
-        store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
-        self.p = subprocess.Popen(
-            ["plasma_store", "-s", store_loc, "-m", str(10000000)],
-            stdout=subprocess.DEVNULL,
-            stderr=subprocess.DEVNULL,
-        )
-
-    def tear_down(self):
-        """Kill the server"""
-        print("Tearing down Plasma store.")
-        self.p.kill()
-        self.p.wait()
-
-    def run_setup(self):
-        print("Set up = True.")
-        self.is_set_up = True
-
-    def run_method(self):
-        # Accurate print statement?
-        print("Running method.")
-        self.run_num += 1
-
-    def process_setup(self):
-        print("Processing setup.")
-        pass
-
-    def process_run(self):
-        print("Processing run - ran.")
-        self.q_comm.put("ran")
-
-    def create_process(self, q_sig, q_comm):
-        print("Creating process.")
-        with RunManager(
-            "test", self.process_run, self.process_setup, q_sig, q_comm
-        ) as rm:
-            print(rm)
-
-    async def createAsyncProcess(self, q_sig, q_comm):
-        print("Creating asyn process.")
-        async with AsyncRunManager(
-            "test", self.process_run, self.process_setup, q_sig, q_comm
-        ) as rm:
-            print(rm)
-
-    async def a_put(self, signal, time):
-        print("Async put.")
-        await asyncio.sleep(time)
-        self.q_sig.put_async(signal)
diff --git a/test/nexus_analog.py b/test/nexus_analog.py
deleted file mode 100644
index 0fc698ec..00000000
--- a/test/nexus_analog.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import concurrent
-import time
-import asyncio
-import math
-import uuid
-import os
-from improv.link import Link
-from improv.store import StoreInterface
-import subprocess
-
-
-def clean_list_print(lst):
-    print("\n=======================\n")
-    for el in lst:
-        print(el)
-        print("\n")
-    print("\n=======================\n")
-
-
-def setup_store():
-    """Fixture to set up the store subprocess with 10 mb.
-
-    This fixture runs a subprocess that instantiates the store with a
-    memory of 10 megabytes. It specifies that "/tmp/store/" is the
-    location of the store socket.
-
-    Yields:
-        StoreInterface: An instance of the store.
-
-    TODO:
-        Figure out the scope.
-    """
-    store_loc = str(os.path.join("/tmp/", str(uuid.uuid4())))
-    subprocess.Popen(
-        ["plasma_store", "-s", store_loc, "-m", str(10000000)],
-        stdout=subprocess.DEVNULL,
-        stderr=subprocess.DEVNULL,
-    )
-    store = StoreInterface(store_loc=store_loc)
-    return store
-
-
-async def pollQueues(links):
-    tasks = []
-    for link in links:
-        tasks.append(asyncio.create_task(link.get_async()))
-
-    links_cpy = links
-    t_0 = time.perf_counter()
-    t_1 = time.perf_counter()
-    print("time get")
-    cur_time = 0
-    while t_1 - t_0 < 5:
-        # need to add something to the queue such that asyncio.wait returns
-
-        links[0].put("Message")
-        done, pending = await asyncio.wait(
-            tasks, return_when=concurrent.futures.FIRST_COMPLETED
-        )
-        for i, t in enumerate(tasks):
-            if t in done:
-                pass
-                tasks[i] = asyncio.create_task(links_cpy[i].get_async())
-
-        t_1 = time.perf_counter()
-
-        if math.floor(t_1 - t_0) != cur_time:
-            print(math.floor(t_1 - t_0))
-            cur_time = math.floor(t_1 - t_0)
-
-    print("All tasks prior to stop polling: \n")
-    clean_list_print([task for task in tasks])
-
-    asyncio.get_running_loop()
-    return stop_polling(tasks, links)
-
-
-def start():
-    links = [
-        Link(f"Link {i}", f"start {i}", f"end {i}", setup_store()) for i in range(4)
-    ]
-    loop = asyncio.get_event_loop()
-    print("RUC loop")
-    res = loop.run_until_complete(pollQueues(links))
-    print(f"RES: {res}")
-
-    print("**********************\nAll tasks at the end of execution:")
-    clean_list_print(res)
-    print("**********************")
-    print(f"Loop: {loop}")
-    loop.close()
-    print(f"Loop: {loop}")
-
-
-def stop_polling(tasks, links):
-    # asyncio.gather(*tasks)
-    print("Cancelling")
-
-    [lnk.put("msg") for lnk in links]
-
-    # [task.cancel() for task in tasks]
-    # [task.cancel() for task in tasks]
-
-    # [lnk.put("msg") for lnk in links]
-
-    print("All tasks: \n")
-    clean_list_print([task for task in tasks])
-    print("Pending:\n")
-    clean_list_print([task for task in tasks if not task.done()])
-    print("Cancelled: \n")
-    clean_list_print([task for task in tasks if task.cancelled()])
-    print("Pending and cancelled: \n")
-    clean_list_print([task for task in tasks if not task.done() and task.cancelled()])
-
-    return [task for task in tasks]
-
-
-if __name__ == "__main__":
-    start()
diff --git a/test/special_configs/basic_demo.yaml b/test/special_configs/basic_demo.yaml
index 1b84f7d8..6f113765 100644
--- a/test/special_configs/basic_demo.yaml
+++ b/test/special_configs/basic_demo.yaml
@@ -16,6 +16,3 @@ actors:
 connections:
   Acquirer.q_out: [Analysis.q_in]
   InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/special_configs/basic_demo_with_GUI.yaml b/test/special_configs/basic_demo_with_GUI.yaml
index edfb09e1..0e5dfc27 100644
--- a/test/special_configs/basic_demo_with_GUI.yaml
+++ b/test/special_configs/basic_demo_with_GUI.yaml
@@ -21,6 +21,3 @@ actors:
 connections:
   Acquirer.q_out: [Analysis.q_in]
   InputStim.q_out: [Analysis.input_stim_queue]
-
-# settings:
-# use_watcher: [Acquirer, Processor, Visual, Analysis]
diff --git a/test/test_actor.py b/test/test_actor.py
index 448fd34e..82251418 100644
--- a/test/test_actor.py
+++ b/test/test_actor.py
@@ -1,9 +1,12 @@
 import os
 import psutil
 import pytest
-from improv.link import Link  # , AsyncQueue
+import zmq
+
 from improv.actor import AbstractActor as Actor
-from improv.store import StoreInterface, PlasmaStoreInterface
+from improv.messaging import ActorStateMsg, ActorStateReplyMsg
+from improv.store import StoreInterface
+from improv.link import ZmqLink
 
 # set global_variables
 
@@ -12,10 +15,10 @@
 
 
 @pytest.fixture
-def init_actor(set_store_loc):
+def init_actor():
     """Fixture to initialize and teardown an instance of actor."""
 
-    act = Actor("Test", set_store_loc)
+    act = Actor("Test")
     yield act
     act = None
 
@@ -33,33 +36,15 @@ def example_links(setup_store, server_port_num):
     """Fixture to provide link objects as test input and setup store."""
 
     StoreInterface(server_port_num=server_port_num)
 
-    acts = [
-        Actor("act" + str(i), server_port_num) for i in range(1, 5)
-    ]  # range must be even
+    ctx = zmq.Context()
+    s = ctx.socket(zmq.PUSH)
 
-    links = [
-        Link("L" + str(i + 1), acts[i], acts[i + 1]) for i in range(len(acts) // 2)
-    ]
+    links = [ZmqLink(s, "test") for i in range(2)]
     link_dict = {links[i].name: links[i] for i, l in enumerate(links)}
     pytest.example_links = link_dict
-    return pytest.example_links
-
-
-@pytest.fixture
-def example_links_plasma(setup_store, set_store_loc):
-    """Fixture to provide link objects as test input and setup store."""
-
-    PlasmaStoreInterface(store_loc=set_store_loc)
-
-    acts = [
-        Actor("act" + str(i), set_store_loc) for i in range(1, 5)
-    ]  # range must be even
-
-    links = [
-        Link("L" + str(i + 1), acts[i], acts[i + 1]) for i in range(len(acts) // 2)
-    ]
-    link_dict = {links[i].name: links[i] for i, l in enumerate(links)}
-    pytest.example_links = link_dict
-    return pytest.example_links
+    yield pytest.example_links
+    s.close(linger=0)
+    ctx.destroy(linger=0)
 
 
 @pytest.mark.parametrize(
@@ -89,66 +74,34 @@ def test_repr_default_initialization(init_actor):
     assert rep == "Test: dict_keys([])"
 
 
-def test_repr(example_string_links, set_store_loc):
+def test_repr(example_string_links):
     """Test if the actor representation has the right, nonempty, dict."""
 
-    act = Actor("Test", set_store_loc)
-    act.setLinks(example_string_links)
+    act = Actor("Test")
+    act.set_links(example_string_links)
     assert act.__repr__() == "Test: dict_keys(['1', '2', '3'])"
 
 
-def test_setStoreInterface(setup_store, server_port_num):
+def test_set_store_interface(setup_store, server_port_num):
     """Tests if the store is started and linked with the actor."""
 
     act = Actor("Acquirer", server_port_num)
     store = StoreInterface(server_port_num=server_port_num)
-    act.setStoreInterface(store.client)
-    assert act.client is store.client
-
-
-def test_plasma_setStoreInterface(setup_plasma_store, set_store_loc):
-    """Tests if the store is started and linked with the actor."""
-
-    act = Actor("Acquirer", set_store_loc)
-    store = PlasmaStoreInterface(store_loc=set_store_loc)
-    act.setStoreInterface(store.client)
+    act.set_store_interface(store.client)
     assert act.client is store.client
 
 
 @pytest.mark.parametrize(
-    "links", [(pytest.example_string_links), ({}), (pytest.example_links), (None)]
+    "links", [pytest.example_string_links, ({}), pytest.example_links, None]
 )
-def test_setLinks(links, set_store_loc):
+def test_set_links(links):
    """Tests if the actors links can be set to certain values."""

-    act = Actor("test", set_store_loc)
-    act.setLinks(links)
+    act = Actor("test")
+    act.set_links(links)
     assert act.links == links
 
 
-@pytest.mark.parametrize(
-    ("qc", "qs"),
-    [
-        ("comm", "sig"),
-        (None, None),
-        ("", ""),
-        ("LINK", "LINK"),  # these are placeholder names (store is not setup)
-    ],
-)
-def test_setCommLinks(example_links, qc, qs, init_actor, setup_store, set_store_loc):
-    """Tests if commLinks can be added to the actor"s links."""
-
-    if qc == "LINK" and qs == "LINK":
-        qc = Link("L1", Actor("1", set_store_loc), Actor("2", set_store_loc))
-        qs = Link("L2", Actor("3", set_store_loc), Actor("4", set_store_loc))
-    act = init_actor
-    act.setLinks(example_links)
-    act.setCommLinks(qc, qs)
-
-    example_links.update({"q_comm": qc, "q_sig": qs})
-    assert act.links == example_links
-
-
 @pytest.mark.parametrize(
     ("links", "expected"),
     [
@@ -158,18 +111,18 @@
         (None, TypeError),
     ],
 )
-def test_setLinkIn(init_actor, example_string_links, example_links, links, expected):
+def test_set_link_in(init_actor, example_string_links, example_links, links, expected):
     """Tests if we can set the input queue."""
 
     act = init_actor
-    act.setLinks(links)
+    act.set_links(links)
     if links is not None:
-        act.setLinkIn("input_q")
+        act.set_link_in("input_q")
         expected.update({"q_in": "input_q"})
         assert act.links == expected
     else:
         with pytest.raises(AttributeError):
-            act.setLinkIn("input_queue")
+            act.set_link_in("input_queue")
 
 
 @pytest.mark.parametrize(
@@ -181,18 +134,18 @@
         (None, TypeError),
     ],
 )
-def test_setLinkOut(init_actor, example_string_links, example_links, links, expected):
+def test_set_link_out(init_actor, example_string_links, example_links, links, expected):
     """Tests if we can set the output queue."""
 
     act = init_actor
-    act.setLinks(links)
+    act.set_links(links)
     if links is not None:
-        act.setLinkOut("output_q")
+        act.set_link_out("output_q")
         expected.update({"q_out": "output_q"})
         assert act.links == expected
     else:
         with pytest.raises(AttributeError):
-            act.setLinkIn("output_queue")
+            act.set_link_in("output_queue")
 
 
 @pytest.mark.parametrize(
@@ -204,29 +157,31 @@
         (None, TypeError),
     ],
 )
-def test_setLinkWatch(init_actor, example_string_links, example_links, links, expected):
+def test_set_link_watch(
+    init_actor, example_string_links, example_links, links, expected
+):
     """Tests if we can set the watch queue."""
 
     act = init_actor
-    act.setLinks(links)
+    act.set_links(links)
     if links is not None:
-        act.setLinkWatch("watch_q")
+        act.set_link_watch("watch_q")
         expected.update({"q_watchout": "watch_q"})
         assert act.links == expected
     else:
         with pytest.raises(AttributeError):
-            act.setLinkIn("input_queue")
+            act.set_link_in("input_queue")
 
 
-def test_addLink(setup_store, set_store_loc):
+def test_add_link(setup_store):
     """Tests if a link can be added to the dictionary of links."""
 
-    act = Actor("test", set_store_loc)
+    act = Actor("test")
     links = {"1": "one", "2": "two"}
-    act.setLinks(links)
+    act.set_links(links)
     newName = "3"
     newLink = "three"
-    act.addLink(newName, newLink)
+    act.add_link(newName, newLink)
     links.update({"3": "three"})
 
     # trying to check for two separate conditions while being able to
@@ -234,7 +189,7 @@
     passes = []
     err_messages = []
 
-    if act.getLinks()["3"] == "three":
+    if act.get_links()["3"] == "three":
         passes.append(True)
     else:
         passes.append(False)
@@ -243,7 +198,7 @@
             actor.getLinks()['3'] is not equal to \"three\""
         )
 
-    if act.getLinks() == links:
+    if act.get_links() == links:
         passes.append(True)
     else:
         passes.append("False")
@@ -256,7 +211,7 @@
     assert all(passes), f"The following errors occurred: {err_out}"
 
 
-def test_getLinks(init_actor, example_string_links):
+def test_get_links(init_actor, example_string_links):
     """Tests if we can access the dictionary of links.
 
     TODO:
@@ -265,25 +220,9 @@
     act = init_actor
     links = example_string_links
-    act.setLinks(links)
-
-    assert act.getLinks() == {"1": "one", "2": "two", "3": "three"}
-
+    act.set_links(links)
 
-@pytest.mark.skip(
-    reason="this is something we'll do later because\
-        we will subclass actor w/ watcher later"
-)
-def test_put(init_actor):
-    """Tests if data keys can be put to output links.
-
-    TODO:
-        Ask Anne to take a look.
-    """
-
-    act = init_actor
-    act.put()
-    assert True
+    assert act.get_links() == {"1": "one", "2": "two", "3": "three"}
 
 
 def test_run(init_actor):
@@ -293,12 +232,12 @@ def test_run(init_actor):
     act.run()
 
 
-def test_changePriority(init_actor):
+def test_change_priority(init_actor):
     """Tests if we are able to change the priority of an actor."""
 
     act = init_actor
     act.lower_priority = True
-    act.changePriority()
+    act.change_priority()
 
     assert psutil.Process(os.getpid()).nice() == 19
 
@@ -311,39 +250,23 @@
     one actor. Then, in the other actor, it is removed from the queue,
     and checked to verify it matches the original message.
     """
-    act1 = Actor("a1", server_port_num)
-    act2 = Actor("a2", server_port_num)
-
-    StoreInterface(server_port_num=server_port_num)
-    link = Link("L12", act1, act2)
-    act1.setLinkIn(link)
-    act2.setLinkOut(link)
-
-    msg = "message"
-
-    act1.q_in.put(msg)
+    assert True
 
-    assert act2.q_out.get() == msg
 
+def test_actor_registration_with_nexus(ports, zmq_actor):
+    context = zmq.Context()
+    nex_socket = context.socket(zmq.REP)
+    nex_socket.bind(f"tcp://*:{ports[3]}")  # actor in port
 
-def test_plasma_actor_connection(setup_plasma_store, set_store_loc):
-    """Test if the links between actors are established correctly.
+    zmq_actor.start()
 
-    This test instantiates two actors with different names, then instantiates
-    a Link object linking the two actors. A string is put to the input queue of
-    one actor. Then, in the other actor, it is removed from the queue, and
-    checked to verify it matches the original message.
-    """
-    act1 = Actor("a1", set_store_loc)
-    act2 = Actor("a2", set_store_loc)
+    res = nex_socket.recv_pyobj()
+    assert isinstance(res, ActorStateMsg)
 
-    PlasmaStoreInterface(store_loc=set_store_loc)
-    link = Link("L12", act1, act2)
-    act1.setLinkIn(link)
-    act2.setLinkOut(link)
+    nex_socket.send_pyobj(ActorStateReplyMsg("test", "OK", ""))
 
-    msg = "message"
+    zmq_actor.terminate()
+    zmq_actor.join(10)
 
-    act1.q_in.put(msg)
 
-    assert act2.q_out.get() == msg
+# TODO: register with broker test
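For orientation, test_actor_registration_with_nexus above plays the Nexus side of the registration handshake: the actor process sends a pickled ActorStateMsg and waits for an ActorStateReplyMsg over a REQ/REP-style socket pair. A condensed sketch of that server loop (the port is a placeholder matching the conftest ports fixture; the reply arguments copy the test above):

    import zmq
    from improv.messaging import ActorStateMsg, ActorStateReplyMsg

    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.bind("tcp://*:30003")  # actor-in port

    msg = sock.recv_pyobj()  # blocks until an actor registers
    assert isinstance(msg, ActorStateMsg)
    sock.send_pyobj(ActorStateReplyMsg("test", "OK", ""))  # acknowledge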
""" - os.chdir("configs") - control_port, output_port, logging_port = ports + control_port, output_port, logging_port, actor_in_port = ports # start server server_opts = [ @@ -49,23 +42,18 @@ async def server(setdir, ports): "minimal.yaml", ] - server = subprocess.Popen( - server_opts, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL - ) - await asyncio.sleep(SERVER_WARMUP) + server = subprocess.Popen(server_opts) + time.sleep(SERVER_WARMUP) yield server - server.wait(SERVER_TIMEOUT) - try: - os.remove("testlog") - except FileNotFoundError: - pass + if server.poll() is None: + pytest.fail("Server did not shut down correctly.") @pytest.fixture -async def cli_args(setdir, ports): +def cli_args(setdir, ports): logfile = "tmp.log" - control_port, output_port, logging_port = ports - config_file = "configs/minimal.yaml" + control_port, output_port, logging_port, actor_in_port = ports + config_file = "minimal.yaml" Args = namedtuple( "cli_args", "control_port output_port logging_port logfile configfile actor_path", @@ -86,7 +74,7 @@ def test_configfile_required(setdir): cli.parse_cli_args(["server", "does_not_exist.yaml"]) -def test_multiple_actor_path(setdir): +def test_multiple_actor_path(set_dir_config_parent): args = cli.parse_cli_args( ["run", "-a", "actors", "-a", "configs", "configs/blank_file.yaml"] ) @@ -113,7 +101,7 @@ def test_multiple_actor_path(setdir): ], ) def test_can_override_ports(mode, flag, expected, setdir): - file = "configs/blank_file.yaml" + file = "blank_file.yaml" localhost = "127.0.0.1:" params = { "-c": "control_port", @@ -142,7 +130,7 @@ def test_can_override_ports(mode, flag, expected, setdir): ], ) def test_non_port_is_error(mode, flag, expected): - file = "configs/blank_file.yaml" + file = "blank_file.yaml" with pytest.raises(SystemExit): cli.parse_cli_args([mode, flag, expected, file]) @@ -182,27 +170,29 @@ def test_can_override_ip(mode, flag, expected): assert vars(args)[params[flag]] == expected -async def test_sigint_kills_server(server): +def test_sigint_kills_server(server): server.send_signal(signal.SIGINT) + server.wait(SERVER_TIMEOUT) -async def test_improv_list_nonempty(server): +def test_improv_list_nonempty(server): proc_list = cli.run_list("", printit=False) assert len(proc_list) > 0 server.send_signal(signal.SIGINT) + server.wait(SERVER_TIMEOUT) -async def test_improv_kill_empties_list(server): +def test_improv_kill_empties_list(server): proc_list = cli.run_list("", printit=False) assert len(proc_list) > 0 cli.run_cleanup("", headless=True) proc_list = cli.run_list("", printit=False) assert len(proc_list) == 0 + server.wait(SERVER_TIMEOUT) -async def test_improv_run_writes_stderr_to_log(setdir, ports): - os.chdir("configs") - control_port, output_port, logging_port = ports +def test_improv_run_writes_stderr_to_log(setdir, ports): + control_port, output_port, logging_port, actor_in_port = ports # start server server_opts = [ @@ -223,17 +213,19 @@ async def test_improv_run_writes_stderr_to_log(setdir, ports): server = subprocess.Popen( server_opts, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL ) - await asyncio.sleep(SERVER_WARMUP) + time.sleep(SERVER_WARMUP) server.kill() server.wait(SERVER_TIMEOUT) - with open("testlog") as log: + with open("improv-debug.log") as log: contents = log.read() + print(contents) assert "Traceback" in contents + os.remove("improv-debug.log") os.remove("testlog") cli.run_cleanup("", headless=True) -async def test_get_ports_from_logfile(setdir): +def test_get_ports_from_logfile(setdir): test_control_port = 53349 
test_output_port = 53350 test_logging_port = 53351 @@ -258,8 +250,8 @@ async def test_get_ports_from_logfile(setdir): assert logging_port == test_logging_port -async def test_no_server_start_in_logfile_raises_error(setdir, cli_args, capsys): - with open(cli_args.logfile, mode="w") as f: +def test_no_server_start_in_logfile_raises_error(setdir, cli_args, capsys): + with open("improv-debug.log", mode="w") as f: f.write("this is some placeholder text") cli.get_server_ports(cli_args, timeout=1) @@ -267,18 +259,18 @@ async def test_no_server_start_in_logfile_raises_error(setdir, cli_args, capsys) captured = capsys.readouterr() assert "Unable to read server start time" in captured.out - os.remove(cli_args.logfile) + os.remove("improv-debug.log") cli.run_cleanup("", headless=True) -async def test_no_ports_in_logfile_raises_error(setdir, cli_args, capsys): +def test_no_ports_in_logfile_raises_error(setdir, cli_args, capsys): curr_dt = datetime.datetime.now().replace(microsecond=0) - with open(cli_args.logfile, mode="w") as f: + with open("improv-debug.log", mode="w") as f: f.write(f"{curr_dt} Server running on (control, output, log) ports XXX\n") cli.get_server_ports(cli_args, timeout=1) captured = capsys.readouterr() - assert f"Unable to read ports from {cli_args.logfile}." in captured.out + assert f"Unable to read ports from {'improv-debug.log'}." in captured.out - os.remove(cli_args.logfile) + os.remove("improv-debug.log") cli.run_cleanup("", headless=True) diff --git a/test/test_config.py b/test/test_config.py index 12cc79bf..f0230dcc 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -6,13 +6,14 @@ # from importlib import import_module # from improv.config import RepeatedActorError -from improv.config import Config +from improv.config import Config, CannotCreateConfigException from improv.utils import checks import logging logger = logging.getLogger(__name__) + # set global variables @@ -34,61 +35,42 @@ def test_init(test_input, set_configdir): """ cfg = Config(test_input) - assert cfg.configFile == test_input - - -# def test_init_attributes(): -# """ Tests if config has correct default attributes on initialization. - -# Checks if actors, connection, and hasGUI are all empty or -# nonexistent. Detects errors by maintaining a list of errors, and -# then adding to it every time an unexpected behavior is encountered. - -# Asserts: -# If the default attributes are empty or nonexistent. + assert cfg.config_file == test_input -# """ -# cfg = config() -# errors = [] - -# if(cfg.actors != {}): -# errors.append("config.actors is not empty! ") -# if(cfg.connections != {}): -# errors.append("config.connections is not empty! ") -# if(cfg.hasGUI): -# errors.append("config.hasGUI already exists! ") - -# assert not errors, "The following errors occurred:\n{}".format( -# "\n".join(errors)) - - -def test_createConfig_settings(set_configdir): +def test_create_config_settings(set_configdir): """Check if the default way config creates config.settings is correct. Asserts: - If the default setting is the dictionary {"use_watcher": "None"} + If the default setting is the config as a dict. 
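Several of the tests above share one lifecycle: boot the server through the CLI in a subprocess, then shut it down with SIGINT. A stripped-down sketch of that pattern (port values and file names are placeholders taken from the fixtures; the full flag list follows the server fixture above):

    import signal
    import subprocess
    import time

    server_opts = [
        "improv", "server",
        "-c", "30000",  # control port
        "-o", "30001",  # output port
        "-l", "30002",  # logging port
        "-f", "testlog",
        "minimal.yaml",
    ]
    server = subprocess.Popen(server_opts)
    time.sleep(10)  # SERVER_WARMUP: let Nexus and the actors come up

    server.send_signal(signal.SIGINT)  # same path test_sigint_kills_server takes
    server.wait(timeout=15)  # SERVER_TIMEOUT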
""" cfg = Config("good_config.yaml") - cfg.createConfig() - assert cfg.settings == {"use_watcher": False} + cfg.parse_config() + cfg.create_config() + assert cfg.settings == { + "control_port": 5555, + "output_port": 5556, + "store_size": 250_000_000, + "actor_in_port": 0, + "harvest_data_from_memory": None, + } -# File with syntax error cannot pass the format check -# def test_createConfig_init_typo(set_configdir): -# """Tests if createConfig can catch actors with errors in init function. +def test_create_config_init_typo(set_configdir): + """Tests if createConfig can catch actors with errors in init function. -# Asserts: -# If createConfig raise any errors. -# """ + Asserts: + If createConfig raise any errors. + """ -# cfg = config("minimal_wrong_init.yaml") -# res = cfg.createConfig() -# assert res == -1 + cfg = Config("minimal_wrong_init.yaml") + cfg.parse_config() + res = cfg.create_config() + assert res == -1 -def test_createConfig_wrong_import(set_configdir): +def test_create_config_wrong_import(set_configdir): """Tests if createConfig can catch actors with errors during import. Asserts: @@ -96,11 +78,12 @@ def test_createConfig_wrong_import(set_configdir): """ cfg = Config("minimal_wrong_import.yaml") - res = cfg.createConfig() + cfg.parse_config() + res = cfg.create_config() assert res == -1 -def test_createConfig_clean(set_configdir): +def test_create_config_clean(set_configdir): """Tests if createConfig runs without error given a good config. Asserts: @@ -108,52 +91,57 @@ def test_createConfig_clean(set_configdir): """ cfg = Config("good_config.yaml") + cfg.parse_config() try: - cfg.createConfig() + cfg.create_config() except Exception as exc: pytest.fail(f"createConfig() raised an exception {exc}") -def test_createConfig_noActor(set_configdir): +def test_create_config_no_actor(set_configdir): """Tests if AttributeError is raised when there are no actors.""" cfg = Config("no_actor.yaml") + cfg.parse_config() with pytest.raises(AttributeError): - cfg.createConfig() + cfg.create_config() -def test_createConfig_ModuleNotFound(set_configdir): +def test_create_config_module_not_found(set_configdir): """Tests if an error is raised when the package can"t be found.""" cfg = Config("bad_package.yaml") - res = cfg.createConfig() + cfg.parse_config() + res = cfg.create_config() assert res == -1 -def test_createConfig_class_ImportError(set_configdir): +def test_create_config_class_import_error(set_configdir): """Tests if an error is raised when the class name is invalid.""" cfg = Config("bad_class.yaml") - res = cfg.createConfig() + cfg.parse_config() + res = cfg.create_config() assert res == -1 -def test_createConfig_AttributeError(set_configdir): +def test_create_config_attribute_error(set_configdir): """Tests if AttributeError is raised.""" cfg = Config("bad_class.yaml") - res = cfg.createConfig() + cfg.parse_config() + res = cfg.create_config() assert res == -1 -def test_createConfig_blank_file(set_configdir): +def test_create_config_blank_file(set_configdir): """Tests if a blank config file raises an error.""" - with pytest.raises(TypeError): + with pytest.raises(CannotCreateConfigException): Config("blank_file.yaml") -def test_createConfig_nonsense_file(set_configdir, caplog): +def test_create_config_nonsense_file(set_configdir, caplog): """Tests if an improperly formatted config raises an error.""" with pytest.raises(TypeError): @@ -170,12 +158,13 @@ def test_cyclic_graph(set_configdir): assert not checks.check_if_connections_acyclic(path) -def test_saveActors_clean(set_configdir): 
+def test_save_actors_clean(set_configdir): """Compares internal actor representation to what was saved in the file.""" cfg = Config("good_config.yaml") - cfg.createConfig() - cfg.saveActors() + cfg.parse_config() + cfg.create_config() + cfg.save_actors() with open("good_config_actors.yaml") as savedConfig: data = yaml.safe_load(savedConfig) @@ -185,9 +174,116 @@ def test_saveActors_clean(set_configdir): assert savedKeys == originalKeys + os.remove("good_config_actors.yaml") + def test_config_settings_read(set_configdir): cfg = Config("minimal_with_settings.yaml") - cfg.createConfig() + cfg.parse_config() + cfg.create_config() assert "store_size" in cfg.settings + + +def test_config_bad_actor_args(set_configdir): + cfg = Config("bad_args.yaml") + cfg.parse_config() + res = cfg.create_config() + assert res == -1 + + +def test_config_harvester_disabled(set_configdir): + cfg = Config("minimal.yaml") + cfg.config = dict() + cfg.config["settings"] = dict() + cfg.parse_config() + assert cfg.settings["harvest_data_from_memory"] is None + + +def test_config_harvester_enabled(set_configdir): + cfg = Config("minimal.yaml") + cfg.config = dict() + cfg.config["settings"] = dict() + cfg.config["settings"]["harvest_data_from_memory"] = True + cfg.parse_config() + assert cfg.settings["harvest_data_from_memory"] + + +def test_config_redis_aof_enabled_saving_not_specified(set_configdir): + cfg = Config("minimal.yaml") + cfg.config = dict() + cfg.config["redis_config"] = dict() + cfg.config["redis_config"]["aof_dirname"] = "test" + cfg.parse_config() + assert cfg.redis_config["aof_dirname"] == "test" + assert cfg.redis_config["enable_saving"] is True + + +def test_config_redis_ephemeral_dirname_enabled_saving_not_specified(set_configdir): + cfg = Config("minimal.yaml") + cfg.config = dict() + cfg.config["redis_config"] = dict() + cfg.config["redis_config"]["generate_ephemeral_aof_dirname"] = True + cfg.parse_config() + assert cfg.redis_config["generate_ephemeral_aof_dirname"] is True + assert cfg.redis_config["enable_saving"] is True + + +def test_config_redis_ephemeral_dirname_and_aof_dirname_specified(set_configdir): + cfg = Config("minimal.yaml") + cfg.config = dict() + cfg.config["redis_config"] = dict() + cfg.config["redis_config"]["generate_ephemeral_aof_dirname"] = True + cfg.config["redis_config"]["aof_dirname"] = "test" + with pytest.raises( + Exception, match="Cannot use unique dirname and use the one provided." 
+ ): + cfg.parse_config() + + +def test_config_redis_ephemeral_dirname_enabled_saving_disabled(set_configdir): + cfg = Config("minimal.yaml") + cfg.config = dict() + cfg.config["redis_config"] = dict() + cfg.config["redis_config"]["generate_ephemeral_aof_dirname"] = True + cfg.config["redis_config"]["enable_saving"] = False + with pytest.raises(Exception, match="Cannot persist to disk with saving disabled."): + cfg.parse_config() + + +def test_config_redis_aof_dirname_enabled_saving_disabled(set_configdir): + cfg = Config("minimal.yaml") + cfg.config = dict() + cfg.config["redis_config"] = dict() + cfg.config["redis_config"]["aof_dirname"] = "test" + cfg.config["redis_config"]["enable_saving"] = False + with pytest.raises(Exception, match="Cannot persist to disk with saving disabled."): + cfg.parse_config() + + +def test_config_redis_fsync_enabled_saving_disabled(set_configdir): + cfg = Config("minimal.yaml") + cfg.config = dict() + cfg.config["redis_config"] = dict() + cfg.config["redis_config"]["fsync_frequency"] = "always" + cfg.config["redis_config"]["enable_saving"] = False + with pytest.raises(Exception, match="Cannot persist to disk with saving disabled."): + cfg.parse_config() + + +def test_config_redis_unknown_fsync_freq(set_configdir): + cfg = Config("minimal.yaml") + cfg.config = dict() + cfg.config["redis_config"] = dict() + cfg.config["redis_config"]["fsync_frequency"] = "unknown" + with pytest.raises(Exception, match="Cannot use unknown fsync frequency unknown"): + cfg.parse_config() + + +def test_config_gui(set_configdir): + cfg = Config("minimal_gui.yaml") + cfg.parse_config() + cfg.create_config() + + assert cfg.hasGUI is True + assert cfg.gui.classname == "Generator" diff --git a/test/test_demos.py b/test/test_demos.py index bc6e96c6..97f16f0e 100644 --- a/test/test_demos.py +++ b/test/test_demos.py @@ -4,18 +4,15 @@ import os import asyncio import subprocess + import improv.tui as tui -import concurrent.futures import logging -from demos.sample_actors.zmqActor import ZmqActor -from test_nexus import ports - LOGGER = logging.getLogger(__name__) LOGGER.setLevel(logging.DEBUG) - SERVER_WARMUP = 10 +UNUSED_TCP_PORT = 10567 @pytest.fixture @@ -35,17 +32,25 @@ def ip(): return pytest.ip +@pytest.fixture +def unused_tcp_port(): + global UNUSED_TCP_PORT + pytest.unused_tcp_port = UNUSED_TCP_PORT + yield pytest.unused_tcp_port + UNUSED_TCP_PORT += 1 + + +@pytest.mark.asyncio @pytest.mark.parametrize( ("dir", "configfile", "logfile"), [ ("minimal", "minimal.yaml", "testlog"), - ("minimal", "minimal_plasma.yaml", "testlog"), ], ) async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports): os.chdir(dir) - control_port, output_port, logging_port = ports + control_port, output_port, logging_port, actor_in_port = ports # start server server_opts = [ @@ -81,21 +86,22 @@ async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports): assert not pilot.app._running # wait on server to fully shut down - server.wait(10) - os.remove(logfile) # later, might want to read this file and check for messages + server.wait(15) + # os.remove(logfile) # later, might want to read this file and check for messages +@pytest.mark.asyncio @pytest.mark.parametrize( ("dir", "configfile", "logfile", "datafile"), [ ("minimal", "minimal.yaml", "testlog", "sample_generator_data.npy"), - ("minimal", "minimal_spawn.yaml", "testlog", "sample_generator_data.npy"), + ("minimal", "minimal_persistence.yaml", "testlog", "test_persistence.csv"), ], ) async def test_stop_output(dir, 
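The Config tests above all follow the same two-phase API introduced by this patch: parse_config() reads and validates the YAML (settings, connections, redis_config), and create_config() then imports and instantiates the actors, returning -1 on failure. A minimal sketch of that flow in application code (the file name is a placeholder):

    from improv.config import Config

    cfg = Config("minimal.yaml")  # raises CannotCreateConfigException on a blank file
    cfg.parse_config()  # validate settings, connections, redis_config
    if cfg.create_config() == -1:  # import actor packages and classes
        raise RuntimeError("actor import or construction failed")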
diff --git a/test/test_demos.py b/test/test_demos.py
index bc6e96c6..97f16f0e 100644
--- a/test/test_demos.py
+++ b/test/test_demos.py
@@ -4,18 +4,15 @@
 import os
 import asyncio
 import subprocess
+
 import improv.tui as tui
-import concurrent.futures
 import logging
 
-from demos.sample_actors.zmqActor import ZmqActor
-from test_nexus import ports
-
 LOGGER = logging.getLogger(__name__)
 LOGGER.setLevel(logging.DEBUG)
 
-
 SERVER_WARMUP = 10
+UNUSED_TCP_PORT = 10567
 
 
 @pytest.fixture
@@ -35,17 +32,25 @@ def ip():
     return pytest.ip
 
 
+@pytest.fixture
+def unused_tcp_port():
+    global UNUSED_TCP_PORT
+    pytest.unused_tcp_port = UNUSED_TCP_PORT
+    yield pytest.unused_tcp_port
+    UNUSED_TCP_PORT += 1
+
+
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     ("dir", "configfile", "logfile"),
     [
         ("minimal", "minimal.yaml", "testlog"),
-        ("minimal", "minimal_plasma.yaml", "testlog"),
     ],
 )
 async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports):
     os.chdir(dir)
 
-    control_port, output_port, logging_port = ports
+    control_port, output_port, logging_port, actor_in_port = ports
 
     # start server
     server_opts = [
@@ -81,21 +86,22 @@ async def test_simple_boot_and_quit(dir, configfile, logfile, setdir, ports):
     assert not pilot.app._running
 
     # wait on server to fully shut down
-    server.wait(10)
-    os.remove(logfile)  # later, might want to read this file and check for messages
+    server.wait(15)
+    # os.remove(logfile)  # later, might want to read this file and check for messages
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     ("dir", "configfile", "logfile", "datafile"),
     [
         ("minimal", "minimal.yaml", "testlog", "sample_generator_data.npy"),
-        ("minimal", "minimal_spawn.yaml", "testlog", "sample_generator_data.npy"),
+        ("minimal", "minimal_persistence.yaml", "testlog", "test_persistence.csv"),
     ],
 )
 async def test_stop_output(dir, configfile, logfile, datafile, setdir, ports):
     os.chdir(dir)
 
-    control_port, output_port, logging_port = ports
+    control_port, output_port, logging_port, actor_in_port = ports
 
     # start server
     server_opts = [
@@ -143,61 +149,60 @@ async def test_stop_output(dir, configfile, logfile, datafile, setdir, ports):
     os.remove(logfile)  # later, might want to read this file and check for messages
 
 
-def test_zmq_ps(ip, unused_tcp_port):
-    """Tests if we can set the zmq PUB/SUB socket and send message."""
-    port = unused_tcp_port
-    LOGGER.info("beginning test")
-    act1 = ZmqActor("act1", type="PUB", ip=ip, port=port)
-    act2 = ZmqActor("act2", type="SUB", ip=ip, port=port)
-    LOGGER.info("ZMQ Actors constructed")
-    # Note these sockets must be set up for testing
-    # this is not needed for running in improv
-    act1.setSendSocket()
-    act2.setRecvSocket()
-
-    msg = "hello"
-    act1.put(msg)
-    LOGGER.info("sent message")
-    recvmsg = act2.get()
-    LOGGER.info("received message")
-    assert recvmsg == msg
-
-
-def test_zmq_rr(ip, unused_tcp_port):
-    """Tests if we can set the zmq REQ/REP socket and send message."""
-    port = unused_tcp_port
-    act1 = ZmqActor("act1", "/tmp/store", type="REQ", ip=ip, port=port)
-    act2 = ZmqActor("act2", "/tmp/store", type="REP", ip=ip, port=port)
-    msg = "hello"
-    reply = "world"
-
-    def handle_request():
-        return act1.put(msg)
-
-    def handle_reply():
-        return act2.get(reply)
-
-    # Use a ThreadPoolExecutor to run handle_request()
-    # and handle_reply() in separate threads.
-
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        future1 = executor.submit(handle_request)
-        future2 = executor.submit(handle_reply)
-
-        # Ensure the request is sent before the reply.
-        request_result = future1.result()
-        reply_result = future2.result()
-
-    # Check if the received message is equal to the original message.
-    assert reply_result == msg
-    # Check if the reply is correct.
-    assert request_result == reply
-
-
-def test_zmq_rr_timeout(ip, unused_tcp_port):
-    """Test for requestMsg where we timeout or fail to send"""
-    port = unused_tcp_port
-    act1 = ZmqActor("act1", "/tmp/store", type="REQ", ip=ip, port=port)
-    msg = "hello"
-    replymsg = act1.put(msg)
-    assert replymsg is None
+@pytest.mark.skip
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("dir", "configfile", "logfile", "datafile"),
+    [
+        ("minimal", "minimal_spawn.yaml", "testlog", "sample_generator_data.npy"),
+    ],
+)
+async def test_stop_output_spawn(dir, configfile, logfile, datafile, setdir, ports):
+    os.chdir(dir)
+
+    control_port, output_port, logging_port, actor_in_port = ports
+
+    # start server
+    server_opts = [
+        "improv",
+        "server",
+        "-c",
+        str(control_port),
+        "-o",
+        str(output_port),
+        "-l",
+        str(logging_port),
+        "-f",
+        logfile,
+        configfile,
+    ]
+
+    with open(logfile, mode="a+") as log:
+        server = subprocess.Popen(server_opts, stdout=log, stderr=log)
+    await asyncio.sleep(SERVER_WARMUP)
+
+    # initialize client
+    app = tui.TUI(control_port, output_port, logging_port)
+
+    # run client
+    async with app.run_test() as pilot:
+        print("running pilot")
+        await pilot.press(*"setup", "enter")
+        await pilot.pause(0.5)
+        await pilot.press(*"run", "enter")
+        await pilot.pause(1)
+        await pilot.press(*"stop", "enter")
+        await pilot.pause(2)
+        await pilot.press(*"quit", "enter")
+        await pilot.pause(3)
+        assert not pilot.app._running
+
+    # wait on server to fully shut down
+    server.wait(10)
+
+    # check that the file written by Generator's stop function got written
+    assert os.path.isfile(datafile)
+
+    # then remove that file and logfile
+    os.remove(datafile)
+    os.remove(logfile)  # later, might want to read this file and check for messages
diff --git a/test/test_harvester.py b/test/test_harvester.py
new file mode 100644
index 00000000..7f078573
--- /dev/null
+++ b/test/test_harvester.py
@@ -0,0 +1,224 @@
+import logging
+import signal
+import time
+
+import pytest
+import zmq
+from zmq import SocketOption
+
+from improv.harvester import RedisHarvester
+from improv.store import RedisStoreInterface
+
+from improv.link import ZmqLink
+from improv.messaging import HarvesterInfoReplyMsg
+from conftest import SignalManager
+
+
+def test_harvester_shuts_down_on_sigint(setup_store, harvester):
+    harvester_ports, broker_socket, p = harvester
+    ctx = zmq.Context()
+    s = ctx.socket(zmq.REP)
+    s.bind(f"tcp://*:{harvester_ports[3]}")
+    s.recv_pyobj()
+    reply = HarvesterInfoReplyMsg("harvester", "OK", "")
+    s.send_pyobj(reply)
+    time.sleep(2)
+    p.terminate()
+    p.join(5)
+    if p.exitcode is None:
+        p.kill()
+        pytest.fail("Harvester did not exit in time")
+    else:
+        assert True
+
+    s.close(linger=0)
+    ctx.destroy(linger=0)
+
+
+def test_harvester_relieves_memory_pressure(setup_store, harvester):
+    store_interface = RedisStoreInterface()
+    harvester_ports, broker_socket, p = harvester
+    broker_link = ZmqLink(broker_socket, "test", "test topic")
+    ctx = zmq.Context()
+    s = ctx.socket(zmq.REP)
+    s.bind(f"tcp://*:{harvester_ports[3]}")
+    s.recv_pyobj()
+    reply = HarvesterInfoReplyMsg("harvester", "OK", "")
+    s.send_pyobj(reply)
+    client = RedisStoreInterface()
+    for i in range(10):
+        message = [i for i in range(500000)]
+        key = client.put(message)
+        broker_link.put(key)
+    time.sleep(2)
+
+    db_info = store_interface.client.info()
+    max_memory = db_info["maxmemory"]
+    used_memory = db_info["used_memory"]
+    used_max_ratio = used_memory / max_memory
+    assert used_max_ratio <= 0.75
+
+    p.terminate()
+    p.join(5)
+    if p.exitcode is None:
+        p.kill()
+        pytest.fail("Harvester did not exit in time")
+
+    s.close(linger=0)
+    ctx.destroy(linger=0)
+
+
+def test_harvester_stop_logs_and_halts_running():
+    with SignalManager():
+        ctx = zmq.Context()
+        s = ctx.socket(zmq.PULL)
+        s.bind("tcp://*:0")
+        pull_port_string = s.getsockopt_string(SocketOption.LAST_ENDPOINT)
+        pull_port = int(pull_port_string.split(":")[-1])
+        harvester = RedisHarvester(
+            nexus_hostname="localhost",
+            nexus_comm_port=10000,  # never gets called in this test
+            redis_hostname="localhost",
+            redis_port=10000,  # never gets called in this test
+            broker_hostname="localhost",
+            broker_port=10000,  # never gets called in this test
+            logger_hostname="localhost",
+            logger_port=pull_port,
+        )
+        harvester.stop(signal.SIGINT, None)
+        assert not harvester.running
+        msg_available = s.poll(timeout=1000)
+        assert msg_available
+        record = s.recv_json()
+        assert (
+            record["message"]
+            == f"Harvester shutting down due to signal {signal.SIGINT}"
+        )
+
+        s.close(linger=0)
+        ctx.destroy(linger=0)
+
+
+# @pytest.mark.skip
+def test_harvester_relieves_memory_pressure_one_loop(ports, setup_store):
+    def harvest_and_quit(harvester_instance: RedisHarvester):
+        harvester_instance.collect()
+        harvester_instance.stop(signal.SIGINT, None)
+
+    with SignalManager():
+        try:
+            ctx = zmq.Context()
+            nex_s = ctx.socket(zmq.REP)
+            nex_s.bind(f"tcp://*:{ports[3]}")
+
+            broker_s = ctx.socket(zmq.PUB)
+            broker_s.bind("tcp://*:1234")
+            broker_link = ZmqLink(broker_s, "test", "test topic")
+
+            log_s = ctx.socket(zmq.PULL)
+            log_s.bind("tcp://*:0")
+            pull_port_string = log_s.getsockopt_string(SocketOption.LAST_ENDPOINT)
+            pull_port = int(pull_port_string.split(":")[-1])
+
+            harvester = RedisHarvester(
+                nexus_hostname="localhost",
+                nexus_comm_port=ports[3],  # never gets called in this test
+                redis_hostname="localhost",
+                redis_port=6379,  # never gets called in this test
+                broker_hostname="localhost",
+                broker_port=1234,  # never gets called in this test
+                logger_hostname="localhost",
+                logger_port=pull_port,
+            )
+
+            harvester.establish_connections()
+
+            client = RedisStoreInterface()
+            for i in range(9):
+                message = [i for i in range(500000)]
+                try:
+                    key = client.put(message)
+                    broker_link.put(key)
+                except Exception as e:
+                    logging.warning(e)
+                    logging.warning("Proceeding under the assumption Redis is full.")
+                    break
+            time.sleep(2)
+
+            harvester.serve(harvest_and_quit, harvester_instance=harvester)
+
+            db_info = client.client.info()
+            max_memory = db_info["maxmemory"]
+            used_memory = db_info["used_memory"]
+            used_max_ratio = used_memory / max_memory
+
+            assert used_max_ratio <= 0.5
+            assert not harvester.running
+            assert harvester.nexus_socket.closed
+            assert harvester.sub_socket.closed
+            assert harvester.zmq_context.closed
+            logging.info("Passed harvester one loop test")
+
+        except Exception as e:
+            logging.error(f"Encountered exception {e} during execution of test")
+            print(e)
+            pass
+
+        try:
+            nex_s.close(linger=0)
+            broker_s.close(linger=0)
+            log_s.close(linger=0)
+            ctx.destroy(linger=0)
+        except Exception as e:
+            logging.error(f"Encountered exception {e} during teardown")
+            print(e)
+            pass
+
+
+def test_harvester_loops_with_no_memory_pressure(ports, setup_store):
+    def harvest_and_quit(harvester_instance: RedisHarvester):
+        harvester_instance.collect()
+        harvester_instance.stop(signal.SIGINT, None)
+
+    with SignalManager():
+        ctx = zmq.Context()
+        nex_s = ctx.socket(zmq.REP)
+        nex_s.bind(f"tcp://*:{ports[3]}")
+
+        log_s = ctx.socket(zmq.PULL)
+        log_s.bind("tcp://*:0")
+        pull_port_string = log_s.getsockopt_string(SocketOption.LAST_ENDPOINT)
+        pull_port = int(pull_port_string.split(":")[-1])
+
+        harvester = RedisHarvester(
+            nexus_hostname="localhost",
+            nexus_comm_port=ports[3],  # never gets called in this test
+            redis_hostname="localhost",
+            redis_port=6379,  # never gets called in this test
+            broker_hostname="localhost",
+            broker_port=1234,  # never gets called in this test
+            logger_hostname="localhost",
+            logger_port=pull_port,
+        )
+
+        harvester.establish_connections()
+
+        client = RedisStoreInterface()
+
+        harvester.serve(harvest_and_quit, harvester_instance=harvester)
+
+        db_info = client.client.info()
+        max_memory = db_info["maxmemory"]
+        used_memory = db_info["used_memory"]
+        used_max_ratio = used_memory / max_memory
+        assert used_max_ratio <= 0.5
+        assert not harvester.running
+        assert harvester.nexus_socket.closed
+        assert harvester.sub_socket.closed
+        assert harvester.zmq_context.closed
+
+        nex_s.close(linger=0)
+        log_s.close(linger=0)
+        ctx.destroy(linger=0)
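Both memory-pressure tests above reduce to a single ratio computed from Redis INFO fields. A standalone sketch of that check (assumes the setup_store Redis instance with a maxmemory limit configured, as in these tests):

    from improv.store import RedisStoreInterface

    client = RedisStoreInterface()  # connects to the test Redis server
    info = client.client.info()  # raw INFO fields via redis-py

    # The harvester has kept up if usage stays well below the cap.
    used_max_ratio = info["used_memory"] / info["maxmemory"]
    print(f"store at {used_max_ratio:.0%} of maxmemory")
    assert used_max_ratio <= 0.75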
- """ + link_socket.poll(timeout=0) - # the links must be specified as an empty dictionary to avoid - # actors sharing a dictionary of links + link_pub_socket = ctx.socket(zmq.PUB) + link_pub_socket.connect(f"tcp://localhost:{link_socket_port}") - return [Actor("test " + str(i), "/tmp/store", links={}) for i in range(n)] + link_socket.poll(timeout=0) # open libzmq bug (not pyzmq) + link_socket.subscribe("test_topic") + time.sleep(0.5) + link_socket.poll(timeout=0) + link = ZmqLink(link_socket, "test_link", "test_topic") -@pytest.fixture -def example_link(setup_store): - """Fixture to provide a commonly used Link object.""" - act = init_actors(2) - lnk = Link("Example", act[0].name, act[1].name) - yield lnk - lnk = None + yield link, link_pub_socket + link.socket.close(linger=0) + link_pub_socket.close(linger=0) + ctx.destroy(linger=0) @pytest.fixture -def example_actor_system(setup_store): - """Fixture to provide a list of 4 connected actors.""" - - # store = setup_store - acts = init_actors(4) - - L01 = Link("L01", acts[0].name, acts[1].name) - L13 = Link("L13", acts[1].name, acts[3].name) - L12 = Link("L12", acts[1].name, acts[2].name) - L23 = Link("L23", acts[2].name, acts[3].name) - - links = [L01, L13, L12, L23] +def test_pub_link(): + """Fixture to provide a commonly used Link object.""" + ctx = zmq.Context() + link_socket = ctx.socket(zmq.PUB) + link_socket.bind("tcp://*:0") + link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) + link_socket_port = int(link_socket_string.split(":")[-1]) - acts[0].addLink("q_out_1", L01) - acts[1].addLink("q_out_1", L13) - acts[1].addLink("q_out_2", L12) - acts[2].addLink("q_out_1", L23) + link_sub_socket = ctx.socket(zmq.SUB) + link_sub_socket.connect(f"tcp://localhost:{link_socket_port}") + link_sub_socket.poll(timeout=0) + link_sub_socket.subscribe("test_topic") + time.sleep(0.5) + link_sub_socket.poll(timeout=0) - acts[1].addLink("q_in_1", L01) - acts[2].addLink("q_in_1", L12) - acts[3].addLink("q_in_1", L13) - acts[3].addLink("q_in_2", L23) + link = ZmqLink(link_socket, "test_link", "test_topic") - yield [acts, links] # also yield Links - acts = None + yield link, link_sub_socket + link.socket.close(linger=0) + link_sub_socket.close(linger=0) + ctx.destroy(linger=0) @pytest.fixture -def _kill_pytest_processes(): - """Kills all processes with "pytest" in their name. +def test_req_link(): + """Fixture to provide a commonly used Link object.""" + ctx = zmq.Context() + link_socket = ctx.socket(zmq.REQ) + link_socket.bind("tcp://*:0") + link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) + link_socket_port = int(link_socket_string.split(":")[-1]) - NOTE: - This fixture should only be used at the end of testing. 
- """ + link_rep_socket = ctx.socket(zmq.REP) + link_rep_socket.connect(f"tcp://localhost:{link_socket_port}") - subprocess.Popen( - ["kill", "`pgrep pytest`"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL - ) + link = ZmqLink(link_socket, "test_link") - -@pytest.mark.parametrize( - ("attribute", "expected"), - [ - ("name", "Example"), - ("real_executor", None), - ("cancelled_join", False), - ("status", "pending"), - ("result", None), - ], -) -def test_Link_init(setup_store, example_link, attribute, expected): - """Tests if the default initialization attributes are set.""" - - lnk = example_link - atr = getattr(lnk, attribute) - assert atr == expected + yield link, link_rep_socket + link.socket.close(linger=0) + link_rep_socket.close(linger=0) + ctx.destroy(linger=0) -def test_Link_init_start_end(setup_store): - """Tests if the initialization has the right actors.""" +@pytest.fixture +def test_rep_link(): + """Fixture to provide a commonly used Link object.""" + ctx = zmq.Context() + link_socket = ctx.socket(zmq.REP) + link_socket.bind("tcp://*:0") + link_socket_string = link_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) + link_socket_port = int(link_socket_string.split(":")[-1]) - act = init_actors(2) - lnk = Link("example_link", act[0].name, act[1].name) + link_req_socket = ctx.socket(zmq.REQ) + link_req_socket.connect(f"tcp://localhost:{link_socket_port}") - assert lnk.start == act[0].name - assert lnk.end == act[1].name + link = ZmqLink(link_socket, "test_link") + yield link, link_req_socket + link.socket.close(linger=0) + link_req_socket.close(linger=0) + ctx.destroy(linger=0) -def test_getstate(example_link): - """Tests if __getstate__ has the right values on initialization. - Gets the dictionary of the link, then compares them against known - default values. Does not compare store and actors. +def test_pub_put(test_pub_link): + """Tests if messages can be put into the link. TODO: - Compare store and actors. + Parametrize multiple test input types. """ - res = example_link.__getstate__() - errors = [] - errors.append(res["_real_executor"] is None) - errors.append(res["cancelled_join"] is False) - - assert all(errors) - - -@pytest.mark.parametrize( - "input", - [([None]), ([1]), ([i for i in range(5)]), ([str(i**i) for i in range(10)])], -) -def test_qsize_empty(example_link, input): - """Tests that the queue has the number of elements in "input".""" - - lnk = example_link - for i in input: - lnk.put(i) - - qsize = lnk.queue.qsize() - assert qsize == len(input) - - -def test_getStart(example_link): - """Tests if getStart returns the starting actor.""" - - lnk = example_link - - assert lnk.getStart() == Actor("test 0", "/tmp/store").name - - -def test_getEnd(example_link): - """Tests if getEnd returns the ending actor.""" - - lnk = example_link + link, link_sub_socket = test_pub_link + msg = "message" - assert lnk.getEnd() == Actor("test 1", "/tmp/store").name + link.put(msg) + res = link_sub_socket.recv_multipart() + assert res[0].decode("utf-8") == "test_topic" + assert json.loads(res[1].decode("utf-8")) == msg -def test_put(example_link): +def test_req_put(test_req_link): """Tests if messages can be put into the link. TODO: Parametrize multiple test input types. 
""" - lnk = example_link + link, link_rep_socket = test_req_link msg = "message" - lnk.put(msg) - assert lnk.get() == "message" + link.put(msg) + res = link_rep_socket.recv_pyobj() + assert res == msg -def test_put_unserializable(example_link, caplog, setup_store): +def test_put_unserializable(test_pub_link): """Tests if an unserializable object raises an error. Instantiates an actor, which is unserializable, and passes it into @@ -173,185 +140,199 @@ def test_put_unserializable(example_link, caplog, setup_store): Raises: SerializationCallbackError: Actor objects are unserializable. """ - # store = setup_store act = Actor("test", "/tmp/store") - lnk = example_link + link, link_sub_socket = test_pub_link sentinel = True try: - lnk.put(act) + link.put(act) except Exception: sentinel = False - assert sentinel, "Unable to put" - assert str(lnk.get()) == str(act) + assert not sentinel -def test_put_irreducible(example_link, setup_store): +def test_put_irreducible(test_pub_link, setup_store): """Tests if an irreducible object raises an error.""" - lnk = example_link + link, link_sub_socket = test_pub_link store = setup_store with pytest.raises(TypeError): - lnk.put(store) + link.put(store) -def test_put_nowait(example_link): +def test_put_nowait(test_pub_link): """Tests if messages can be put into the link without blocking. TODO: Parametrize multiple test input types. """ - lnk = example_link + link, link_sub_socket = test_pub_link msg = "message" t_0 = time.perf_counter() - lnk.put_nowait(msg) + link.put(msg) # put is already async even in synchronous zmq t_1 = time.perf_counter() t_net = t_1 - t_0 assert t_net < 0.005 # 5 ms -@pytest.mark.asyncio -async def test_put_async_success(example_link): - """Tests if put_async returns None. - - TODO: - Parametrize test input. - """ - - lnk = example_link - msg = "message" - res = await lnk.put_async(msg) - assert res is None - - -@pytest.mark.asyncio -async def test_put_async_multiple(example_link): +def test_put_multiple(test_pub_link): """Tests if async putting multiple objects preserves their order.""" + link, link_sub_socket = test_pub_link + messages = [str(i) for i in range(10)] messages_out = [] for msg in messages: - await example_link.put_async(msg) + link.put(msg) for msg in messages: - messages_out.append(example_link.get()) + messages_out.append( + json.loads(link_sub_socket.recv_multipart()[1].decode("utf-8")) + ) assert messages_out == messages @pytest.mark.asyncio -async def test_put_and_get_async(example_link): +async def test_put_and_get_async(test_pub_link): """Tests if async get preserves order after async put.""" messages = [str(i) for i in range(10)] messages_out = [] + link, link_sub_socket = test_pub_link + for msg in messages: - await example_link.put_async(msg) + await link.put_async(msg) for msg in messages: - messages_out.append(await example_link.get_async()) + messages_out.append( + json.loads(link_sub_socket.recv_multipart()[1].decode("utf-8")) + ) assert messages_out == messages -@pytest.mark.skip( - reason="This test needs additional work to cause an overflow in the datastore." 
-@pytest.mark.skip(
-    reason="This test needs additional work to cause an overflow in the datastore."
+@pytest.mark.parametrize(
+    "message",
+    [
+        "message",
+        "",
+        None,
+        [str(i) for i in range(5)],
+    ],
 )
-def test_put_overflow(setup_store, server_port_num, caplog):
-    """Tests if putting too large of an object raises an error."""
-
-    p = subprocess.Popen(
-        ["redis-server", "--port", str(server_port_num), "--maxmemory", str(1000)],
-        stdout=subprocess.DEVNULL,
-        stderr=subprocess.DEVNULL,
-    )
-
-    acts = init_actors(2)
-    lnk = Link("L1", acts[0], acts[1])
-
-    message = [i for i in range(10**6)]  # 24000 bytes
+@pytest.mark.parametrize(
+    "timeout",
+    [
+        None,
+        5,
+    ],
+)
+def test_sub_get(test_sub_link, message, timeout):
+    """Tests if get gets the correct element from the queue."""
 
-    lnk.put(message)
+    link, link_pub_socket = test_sub_link
 
-    p.kill()
-    p.wait()
+    time.sleep(1)
 
-    if caplog.records:
-        for record in caplog.records:
-            if "PlasmaStoreInterfaceFull" in record.msg:
-                assert True
+    if type(message) is list:
+        for i in message:
+            link_pub_socket.send_multipart(
+                [link.topic.encode("utf-8"), json.dumps(i).encode("utf-8")]
+            )
+        expected = message[0]
     else:
-        pytest.fail("expected an error!")
+        link_pub_socket.send_multipart(
+            [link.topic.encode("utf-8"), json.dumps(message).encode("utf-8")]
+        )
+        expected = message
+
+    assert link.get(timeout=timeout) == expected
 
 
 @pytest.mark.parametrize(
     "message",
     [
-        ("message"),
-        (""),
-        (None),
-        ([str(i) for i in range(5)]),
+        "message",
+        "",
+        None,
+        [str(i) for i in range(5)],
+    ],
+)
+@pytest.mark.parametrize(
+    "timeout",
+    [
+        None,
+        5,
     ],
 )
-def test_get(example_link, message):
+def test_rep_get(test_rep_link, message, timeout):
     """Tests if get gets the correct element from the queue."""
 
-    lnk = example_link
+    link, link_req_socket = test_rep_link
 
-    if type(message) is list:
-        for i in message:
-            lnk.put(i)
-        expected = message[0]
-    else:
-        lnk.put(message)
-        expected = message
+    link_req_socket.send_pyobj(message)
+
+    expected = message
 
-    assert lnk.get() == expected
+    assert link.get(timeout=timeout) == expected
 
 
-def test_get_empty(example_link):
+def test_get_empty(test_sub_link):
     """Tests if get blocks if the queue is empty."""
-    lnk = example_link
-    if lnk.queue.empty:
-        with pytest.raises(queue.Empty):
-            lnk.get(timeout=5.0)
-    else:
-        pytest.fail("expected a timeout!")
+    link, unused = test_sub_link
+
+    time.sleep(0.1)
+
+    with pytest.raises(TimeoutError):
+        link.get(timeout=0.5)
 
 
 @pytest.mark.parametrize(
     "message",
     [
-        ("message"),
-        (""),
+        "message",
+        "",
         ([str(i) for i in range(5)]),
     ],
 )
-def test_get_nowait(example_link, message):
+def test_get_nowait(test_sub_link, message):
    """Tests if get_nowait gets the correct element from the queue."""
 
-    lnk = example_link
+    link, link_pub_socket = test_sub_link
 
     if type(message) is list:
         for i in message:
-            lnk.put(i)
+            link_pub_socket.send_multipart(
+                [link.topic.encode("utf-8"), json.dumps(i).encode("utf-8")]
+            )
         expected = message[0]
     else:
-        lnk.put(message)
+        link_pub_socket.send_multipart(
+            [link.topic.encode("utf-8"), json.dumps(message).encode("utf-8")]
+        )
         expected = message
 
+    for i in range(20):
+        time.sleep(0.1)
+        avail = link.socket.poll(timeout=100)
+        if avail:
+            break
+    else:
+        pytest.fail("Message was not sent to link after 4s")
+
     t_0 = time.perf_counter()
-    res = lnk.get_nowait()
+    res = link.get_nowait()
 
     t_1 = time.perf_counter()
 
@@ -359,95 +340,31 @@
-    assert t_1 - t_0 < 0.005  # 5 msg
+    assert t_1 - t_0 < 0.005  # 5 ms
 
 
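The wait loop in test_get_nowait is a general pyzmq idiom: poll with a short timeout until a frame is queued, then receive without blocking. A sketch of the same idiom as a reusable helper (recv_when_ready is illustrative, not part of improv):

import zmq

def recv_when_ready(socket, tries=20, poll_ms=100):
    # poll() returns a non-zero event mask once a message is queued,
    # so the NOBLOCK receive below cannot raise zmq.Again.
    for _ in range(tries):
        if socket.poll(timeout=poll_ms):
            return socket.recv_multipart(flags=zmq.NOBLOCK)
    return None  # nothing arrived within tries * poll_ms milliseconds
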
-def test_get_nowait_empty(example_link):
+def test_get_nowait_empty(test_sub_link):
     """Tests if get_nowait raises an error when the
     queue is empty."""
-    lnk = example_link
-    if lnk.queue.empty():
-        with pytest.raises(queue.Empty):
-            lnk.get_nowait()
-    else:
-        pytest.fail("the queue is not empty")
+    link, unused = test_sub_link
+    with pytest.raises(TimeoutError):
+        link.get_nowait()
 
 
 @pytest.mark.asyncio
-async def test_get_async_success(example_link):
+async def test_get_async_success(test_sub_link):
     """Tests if async_get gets the correct element from the queue."""
 
-    lnk = example_link
+    link, link_pub_socket = test_sub_link
     msg = "message"
-    await lnk.put_async(msg)
-    res = await lnk.get_async()
+    link_pub_socket.send_multipart(
+        [link.topic.encode("utf-8"), json.dumps(msg).encode("utf-8")]
+    )
+    res = await link.get_async()
     assert res == "message"
 
 
-@pytest.mark.asyncio
-async def test_get_async_empty(example_link):
-    """Tests if get_async times out given an empty queue.
-
-    TODO:
-        Implement a way to kill the task after execution (subprocess)?
-    """
-
-    lnk = example_link
-    timeout = 5.0
-
-    with pytest.raises(asyncio.TimeoutError):
-        task = asyncio.create_task(lnk.get_async())
-        await asyncio.wait_for(task, timeout)
-        task.cancel()
-
-    lnk.put("exit")  # this is here to break out of get_async()
-
-
-@pytest.mark.skip(reason="unfinished")
-def test_cancel_join_thread(example_link):
-    """Tests cancel_join_thread. This test is unfinished
-
-    TODO:
-        Identify where and when cancel_join_thread is being called.
-    """
-
-    lnk = example_link
-    lnk.cancel_join_thread()
-
-    assert lnk._cancelled_join is True
-
-
-@pytest.mark.skip(reason="unfinished")
-@pytest.mark.asyncio
-async def test_join_thread(example_link):
-    """Tests join_thread. This test is unfinished
-
-    TODO:
-        Identify where and when join_thread is being called.
-    """
-    lnk = example_link
-    await lnk.put_async("message")
-    # msg = await lnk.get_async()
-    lnk.join_thread()
-    assert True
-
-
-@pytest.mark.asyncio
-async def test_multi_actor_system(example_actor_system, setup_store):
-    """Tests if async puts/gets with many actors have good messages."""
-
-    setup_store
-
-    graph = example_actor_system
-
-    acts = graph[0]
-
-    heavy_msg = [str(i) for i in range(10**6)]
-    light_msgs = ["message" + str(i) for i in range(3)]
-
-    await acts[0].links["q_out_1"].put_async(heavy_msg)
-    await acts[1].links["q_out_1"].put_async(light_msgs[0])
-    await acts[1].links["q_out_2"].put_async(light_msgs[1])
-    await acts[2].links["q_out_1"].put_async(light_msgs[2])
-
-    assert await acts[1].links["q_in_1"].get_async() == heavy_msg
-    assert await acts[2].links["q_in_1"].get_async() == light_msgs[1]
-    assert await acts[3].links["q_in_1"].get_async() == light_msgs[0]
-    assert await acts[3].links["q_in_2"].get_async() == light_msgs[2]
+def test_pub_put_no_topic():
+    """Tests that a PUB-side ZmqLink cannot be opened without a topic."""
+    ctx = zmq.Context()
+    s = ctx.socket(zmq.PUB)
+    with pytest.raises(Exception, match="Cannot open PUB link without topic"):
+        ZmqLink(s, "test")
+    s.close(linger=0)
+    ctx.destroy(linger=0)
diff --git a/test/test_messaging.py b/test/test_messaging.py
new file mode 100644
index 00000000..a8efff3f
--- /dev/null
+++ b/test/test_messaging.py
@@ -0,0 +1,107 @@
+import improv.messaging
+
+
+def test_actor_state_msg():
+    name = "test name"
+    status = "test status"
+    port = 12345
+    info = "test info"
+    msg = improv.messaging.ActorStateMsg(name, status, port, info)
+    assert msg.actor_name == name
+    assert msg.status == status
+    assert msg.nexus_in_port == port
+    assert msg.info == info
+
+
+def test_actor_state_reply_msg():
+    name = "test name"
+    status = "test status"
+    info = "test info"
+    msg = improv.messaging.ActorStateReplyMsg(name, 
status, info) + assert msg.actor_name == name + assert msg.status == status + assert msg.info == info + + +def test_actor_signal_msg(): + name = "test name" + signal = "test signal" + info = "test info" + msg = improv.messaging.ActorSignalMsg(name, signal, info) + assert msg.actor_name == name + assert msg.signal == signal + assert msg.info == info + + +def test_actor_signal_reply_msg(): + name = "test name" + status = "test status" + signal = "test_signal" + info = "test info" + msg = improv.messaging.ActorSignalReplyMsg(name, signal, status, info) + assert msg.actor_name == name + assert msg.status == status + assert msg.info == info + assert msg.signal == signal + + +def test_broker_info_msg(): + name = "test name" + pub_port = 54321 + sub_port = 12345 + info = "test info" + msg = improv.messaging.BrokerInfoMsg(name, pub_port, sub_port, info) + assert msg.name == name + assert msg.pub_port == pub_port + assert msg.sub_port == sub_port + assert msg.info == info + + +def test_broker_info_reply_msg(): + name = "test name" + status = "test status" + info = "test info" + msg = improv.messaging.BrokerInfoReplyMsg(name, status, info) + assert msg.name == name + assert msg.status == status + assert msg.info == info + + +def test_log_info_msg(): + name = "test name" + pull_port = 54321 + pub_port = 12345 + info = "test info" + msg = improv.messaging.LogInfoMsg(name, pull_port, pub_port, info) + assert msg.name == name + assert msg.pub_port == pub_port + assert msg.pull_port == pull_port + assert msg.info == info + + +def test_log_info_reply_msg(): + name = "test name" + status = "test status" + info = "test info" + msg = improv.messaging.LogInfoReplyMsg(name, status, info) + assert msg.name == name + assert msg.status == status + assert msg.info == info + + +def test_harvester_info_msg(): + name = "test name" + info = "test info" + msg = improv.messaging.HarvesterInfoMsg(name, info) + assert msg.name == name + assert msg.info == info + + +def test_harvester_info_reply_msg(): + name = "test name" + status = "test status" + info = "test info" + msg = improv.messaging.HarvesterInfoReplyMsg(name, status, info) + assert msg.name == name + assert msg.status == status + assert msg.info == info diff --git a/test/test_nexus.py b/test/test_nexus.py index e1a0e5f6..f4479f33 100644 --- a/test/test_nexus.py +++ b/test/test_nexus.py @@ -2,77 +2,21 @@ import shutil import time import os + import pytest import logging -import subprocess -import signal -import yaml - -from improv.nexus import Nexus -from improv.store import StoreInterface - -# from improv.actor import Actor -# from improv.store import StoreInterface - -SERVER_COUNTER = 0 - - -@pytest.fixture -def ports(): - global SERVER_COUNTER - CONTROL_PORT = 5555 - OUTPUT_PORT = 5556 - LOGGING_PORT = 5557 - yield ( - CONTROL_PORT + SERVER_COUNTER, - OUTPUT_PORT + SERVER_COUNTER, - LOGGING_PORT + SERVER_COUNTER, - ) - SERVER_COUNTER += 3 - - -@pytest.fixture -def setdir(): - prev = os.getcwd() - os.chdir(os.path.dirname(__file__) + "/configs") - yield None - os.chdir(prev) - - -@pytest.fixture -def sample_nex(setdir, ports): - nex = Nexus("test") - nex.createNexus( - file="good_config.yaml", - store_size=40000000, - control_port=ports[0], - output_port=ports[1], - ) - yield nex - nex.destroyNexus() - -# @pytest.fixture -# def setup_store(setdir): -# """ Fixture to set up the store subprocess with 10 mb. +import zmq -# This fixture runs a subprocess that instantiates the store with a -# memory of 10 megabytes. 
It specifies that "/tmp/store/" is the
-# location of the store socket.
-
-#     Yields:
-#         StoreInterface: An instance of the store.
-
-#     TODO:
-#         Figure out the scope.
-#     """
-#     setdir
-#     p = subprocess.Popen(
-#         ['plasma_store', '-s', '/tmp/store/', '-m', str(10000000)],\
-#         stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-#     store = StoreInterface(store_loc = "/tmp/store/")
-#     yield store
-#     p.kill()
+from improv.config import CannotCreateConfigException
+from improv.messaging import ActorStateMsg
+from improv.nexus import (
+    Nexus,
+    ConfigFileNotProvidedException,
+    ConfigFileNotValidException,
+)
+from improv.store import StoreInterface
+from conftest import SignalManager
 
 
 def test_init(setdir):
@@ -85,32 +29,24 @@
     "cfg_name",
     [
         "good_config.yaml",
-        "good_config_plasma.yaml",
     ],
 )
-def test_createNexus(setdir, ports, cfg_name):
+def test_create_nexus(setdir, ports, cfg_name):
     nex = Nexus("test")
-    nex.createNexus(file=cfg_name, control_port=ports[0], output_port=ports[1])
-    assert list(nex.comm_queues.keys()) == [
-        "GUI_comm",
-        "Acquirer_comm",
-        "Analysis_comm",
-    ]
-    assert list(nex.sig_queues.keys()) == ["Acquirer_sig", "Analysis_sig"]
-    assert list(nex.data_queues.keys()) == ["Acquirer.q_out", "Analysis.q_in"]
+    nex.create_nexus(file=cfg_name, control_port=ports[0], output_port=ports[1])
     assert list(nex.actors.keys()) == ["Acquirer", "Analysis"]
     assert list(nex.flags.keys()) == ["quit", "run", "load"]
     assert nex.processes == []
-    nex.destroyNexus()
+    nex.destroy_nexus()
     assert True
 
 
 def test_config_logged(setdir, ports, caplog):
     nex = Nexus("test")
-    nex.createNexus(
+    nex.create_nexus(
         file="minimal_with_settings.yaml", control_port=ports[0], output_port=ports[1]
     )
-    nex.destroyNexus()
+    nex.destroy_nexus()
     assert any(
         [
             "not_relevant: for testing purposes" in record.msg
@@ -119,54 +55,49 @@
     )
 
 
-def test_loadConfig(sample_nex):
+def test_load_config(sample_nex):
     nex = sample_nex
-    nex.loadConfig("good_config.yaml")
-    assert set(nex.comm_queues.keys()) == set(
-        ["Acquirer_comm", "Analysis_comm", "GUI_comm"]
+    assert any(
+        [
+            link_info.link_name == "q_out"
+            for link_info in nex.outgoing_topics["Acquirer"]
+        ]
+    )
+    assert any(
+        [link_info.link_name == "q_in" for link_info in nex.incoming_topics["Analysis"]]
     )
 
 
 def test_argument_config_precedence(setdir, ports):
     nex = Nexus("test")
-    nex.createNexus(
+    nex.create_nexus(
         file="minimal_with_settings.yaml",
         control_port=ports[0],
         output_port=ports[1],
         store_size=11_000_000,
-        use_watcher=True,
     )
     cfg = nex.config.settings
-    nex.destroyNexus()
+    nex.destroy_nexus()
 
     assert cfg["control_port"] == ports[0]
     assert cfg["output_port"] == ports[1]
-    assert cfg["store_size"] == 20_000_000
-    assert not cfg["use_watcher"]
+    assert cfg["store_size"] == 11_000_000
 
 
-def test_settings_override_random_ports(setdir, ports):
-    config_file = "minimal_with_settings.yaml"
-    nex = Nexus("test")
-    with open(config_file, "r") as ymlfile:
-        cfg = yaml.safe_load(ymlfile)["settings"]
-    control_port, output_port = nex.createNexus(
-        file=config_file, control_port=0, output_port=0
-    )
-    nex.destroyNexus()
-    assert control_port == cfg["control_port"]
-    assert output_port == cfg["output_port"]
+def test_start_nexus(sample_nex):
+    with SignalManager():
 
+        async def set_quit_flag(test_nex):
+            test_nex.flags["quit"] = True
 
-# delete this comment later
-@pytest.mark.skip(reason="unfinished")
-def test_startNexus(sample_nex):
-    nex = sample_nex
-    nex.startNexus()
-
assert [p.name for p in nex.processes] == ["Acquirer", "Analysis"] - nex.destroyNexus() + nex = sample_nex + nex.start_nexus(nex.poll_queues, poll_function=set_quit_flag, test_nex=nex) + assert [p.name for p in nex.processes] == ["Acquirer", "Analysis"] -# @pytest.mark.skip(reason="This test is unfinished") +@pytest.mark.skip( + reason="This test is unfinished - it does not validate link structure" +) @pytest.mark.parametrize( ("cfg_name", "actor_list", "link_list"), [ @@ -197,20 +128,17 @@ def test_config_construction(cfg_name, actor_list, link_list, setdir, ports): """ nex = Nexus("test") - nex.createNexus(file=cfg_name, control_port=ports[0], output_port=ports[1]) + nex.create_nexus(file=cfg_name, control_port=ports[0], output_port=ports[1]) logging.info(cfg_name) # Check for actors act_lst = list(nex.actors) - lnk_lst = list(nex.sig_queues) - nex.destroyNexus() + nex.destroy_nexus() assert actor_list == act_lst - assert link_list == lnk_lst act_lst = [] - lnk_lst = [] assert True @@ -218,165 +146,247 @@ def test_config_construction(cfg_name, actor_list, link_list, setdir, ports): "cfg_name", [ "single_actor.yaml", - "single_actor_plasma.yaml", ], ) def test_single_actor(setdir, ports, cfg_name): nex = Nexus("test") with pytest.raises(AttributeError): - nex.createNexus( + nex.create_nexus( file="single_actor.yaml", control_port=ports[0], output_port=ports[1] ) - nex.destroyNexus() + nex.destroy_nexus() def test_cyclic_graph(setdir, ports): nex = Nexus("test") - nex.createNexus( + nex.create_nexus( file="cyclic_config.yaml", control_port=ports[0], output_port=ports[1] ) assert True - nex.destroyNexus() + nex.destroy_nexus() def test_blank_cfg(setdir, caplog, ports): nex = Nexus("test") - with pytest.raises(TypeError): - nex.createNexus( + with pytest.raises(CannotCreateConfigException): + nex.create_nexus( file="blank_file.yaml", control_port=ports[0], output_port=ports[1] ) assert any( ["The config file is empty" in record.msg for record in list(caplog.records)] ) - nex.destroyNexus() + nex.destroy_nexus() -# def test_hasGUI_True(setdir): -# setdir -# nex = Nexus("test") -# nex.createNexus(file="basic_demo_with_GUI.yaml") +def test_start_store(caplog): + nex = Nexus("test") + nex._start_store_interface(100_000_000) # 100 MB store -# assert True -# nex.destroyNexus() + assert any( + "StoreInterface start successful" in record.msg for record in caplog.records + ) -# @pytest.mark.skip(reason="This test is unfinished.") -# def test_hasGUI_False(): -# assert True + nex._close_store_interface() + nex.destroy_nexus() + assert True -@pytest.mark.skip(reason="unfinished") -def test_queue_message(setdir, sample_nex): - nex = sample_nex - nex.startNexus() - time.sleep(20) - nex.setup() - time.sleep(20) - nex.run() - time.sleep(10) - acq_comm = nex.comm_queues["Acquirer_comm"] - acq_comm.put("Test Message") - - assert nex.comm_queues is None - nex.destroyNexus() - assert True +def test_close_store(caplog): + nex = Nexus("test") + nex._start_store_interface(10000) + nex._close_store_interface() + assert any( + "StoreInterface close successful" in record.msg for record in caplog.records + ) -@pytest.mark.asyncio -@pytest.mark.skip(reason="This test is unfinished.") -async def test_queue_readin(sample_nex, caplog): - nex = sample_nex - nex.startNexus() - # cqs = nex.comm_queues - # assert cqs == None - assert [record.msg for record in caplog.records] is None - # cqs["Acquirer_comm"].put('quit') - # assert "quit" == cqs["Acquirer_comm"].get() - # await nex.pollQueues() - assert True + # write to store 
+ with pytest.raises(AttributeError): + nex.p_StoreInterface.put("Message in", "Message in Label") -@pytest.mark.skip(reason="This test is unfinished.") -def test_queue_sendout(): + nex.destroy_nexus() assert True -@pytest.mark.skip(reason="This test is unfinished.") -def test_run_sig(): - assert True +def test_start_harvester(caplog, setdir, ports): + nex = Nexus("test") + try: + nex.create_nexus( + file="minimal_harvester.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) + time.sleep(3) -@pytest.mark.skip(reason="This test is unfinished.") -def test_setup_sig(): - assert True + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") + assert any("Harvester server started" in record.msg for record in caplog.records) -@pytest.mark.skip(reason="This test is unfinished.") -def test_quit_sig(): - assert True +def test_process_actor_state_update(caplog, setdir, ports): + nex = Nexus("test") + try: + nex.create_nexus( + file="minimal_harvester.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) -@pytest.mark.skip(reason="This test is unfinished.") -def test_usehdd_True(): - assert True + time.sleep(3) + new_actor_message = ActorStateMsg("test actor", "waiting", 1234, "test info") -@pytest.mark.skip(reason="This test is unfinished.") -def test_usehdd_False(): - assert True + nex.process_actor_state_update(new_actor_message) + assert "test actor" in nex.actor_states + assert nex.actor_states["test actor"].actor_name == new_actor_message.actor_name + assert ( + nex.actor_states["test actor"].nexus_in_port + == new_actor_message.nexus_in_port + ) + assert nex.actor_states["test actor"].status == new_actor_message.status + update_actor_message = ActorStateMsg("test actor", "waiting", 1234, "test info") -def test_startstore(caplog): - nex = Nexus("test") - nex._startStoreInterface(10000000) # 10 MB store + nex.process_actor_state_update(update_actor_message) + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") assert any( - "StoreInterface start successful" in record.msg for record in caplog.records + "Received state message from new actor test actor" in record.msg + for record in caplog.records + ) + assert any( + "Received state message from actor test actor" in record.msg + for record in caplog.records ) - nex._closeStoreInterface() - nex.destroyNexus() - assert True + +def test_process_actor_state_update_allows_run(caplog, setdir, ports): + nex = Nexus("test") + try: + nex.create_nexus( + file="minimal_harvester.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) + + time.sleep(3) + + nex.actor_states["test actor1"] = None + nex.actor_states["test actor2"] = None + + actor1_message = ActorStateMsg("test actor1", "ready", 1234, "test info") + + nex.process_actor_state_update(actor1_message) + assert "test actor1" in nex.actor_states + assert nex.actor_states["test actor1"].actor_name == actor1_message.actor_name + assert ( + nex.actor_states["test actor1"].nexus_in_port + == actor1_message.nexus_in_port + ) + assert nex.actor_states["test actor1"].status == actor1_message.status + + assert not nex.allowStart + + actor2_message = ActorStateMsg("test actor2", "ready", 5678, "test info2") + + nex.process_actor_state_update(actor2_message) + assert "test actor2" in nex.actor_states + assert nex.actor_states["test 
actor2"].actor_name == actor2_message.actor_name + assert ( + nex.actor_states["test actor2"].nexus_in_port + == actor2_message.nexus_in_port + ) + assert nex.actor_states["test actor2"].status == actor2_message.status + + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") + assert nex.allowStart -def test_closestore(caplog): +@pytest.mark.asyncio +async def test_process_actor_message(caplog, setdir, ports): nex = Nexus("test") - nex._startStoreInterface(10000) - nex._closeStoreInterface() + try: + nex.create_nexus( + file="minimal_harvester.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) - assert any( - "StoreInterface close successful" in record.msg for record in caplog.records - ) + time.sleep(3) - # write to store + nex.actor_states["test actor1"] = None + nex.actor_states["test actor2"] = None - with pytest.raises(AttributeError): - nex.p_StoreInterface.put("Message in", "Message in Label") + actor1_message = ActorStateMsg("test actor1", "ready", 1234, "test info") - nex.destroyNexus() - assert True + ctx = nex.zmq_context + s = ctx.socket(zmq.REQ) + s.connect(f"tcp://localhost:{nex.actor_in_socket_port}") + + s.send_pyobj(actor1_message) + + await nex.process_actor_message() + + nex.process_actor_state_update(actor1_message) + assert "test actor1" in nex.actor_states + assert nex.actor_states["test actor1"].actor_name == actor1_message.actor_name + assert ( + nex.actor_states["test actor1"].nexus_in_port + == actor1_message.nexus_in_port + ) + assert nex.actor_states["test actor1"].status == actor1_message.status + + s.close(linger=0) + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") + + assert not nex.allowStart def test_specified_free_port(caplog, setdir, ports): nex = Nexus("test") - nex.createNexus( - file="minimal_with_fixed_redis_port.yaml", - store_size=10000000, - control_port=ports[0], - output_port=ports[1], - ) + try: + nex.create_nexus( + file="minimal_with_fixed_redis_port.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) - store = StoreInterface(server_port_num=6378) - store.connect_to_server() - key = store.put("port 6378") - assert store.get(key) == "port 6378" + store = StoreInterface(server_port_num=6378) + store.connect_to_server() + key = store.put("port 6378") + assert store.get(key) == "port 6378" - assert any( - "Successfully connected to redis datastore on port 6378" in record.msg - for record in caplog.records - ) + assert any( + "Successfully connected to redis datastore on port 6378" in record.msg + for record in caplog.records + ) + + time.sleep(3) - nex.destroyNexus() + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") assert any( "StoreInterface start successful on port 6378" in record.msg @@ -386,32 +396,47 @@ def test_specified_free_port(caplog, setdir, ports): def test_specified_busy_port(caplog, setdir, ports, setup_store): nex = Nexus("test") - with pytest.raises(Exception, match="Could not start Redis on specified port."): - nex.createNexus( + try: + nex.create_nexus( file="minimal_with_fixed_default_redis_port.yaml", - store_size=10000000, + store_size=100_000_000, control_port=ports[0], output_port=ports[1], ) - nex.destroyNexus() + time.sleep(3) + + nex.destroy_nexus() + except Exception as 
e:
+        print(f"error caught in test harness: {e}")
+        logging.error(f"error caught in test harness: {e}")
 
     assert any(
-        "Could not start Redis on specified port number." in record.msg
+        "Could not connect to port 6379" in record.msg for record in caplog.records
+    )
+
+    assert any(
+        "StoreInterface start successful on port 6380" in record.msg
         for record in caplog.records
     )
 
 
 def test_unspecified_port_default_free(caplog, setdir, ports):
     nex = Nexus("test")
-    nex.createNexus(
-        file="minimal.yaml",
-        store_size=10000000,
-        control_port=ports[0],
-        output_port=ports[1],
-    )
+    try:
+        nex.create_nexus(
+            file="minimal.yaml",
+            store_size=100_000_000,
+            control_port=ports[0],
+            output_port=ports[1],
+        )
 
-    nex.destroyNexus()
+        time.sleep(3)
+
+        nex.destroy_nexus()
+    except Exception as e:
+        print(f"error caught in test harness: {e}")
+        logging.error(f"error caught in test harness: {e}")
 
     assert any(
         "StoreInterface start successful on port 6379" in record.msg
@@ -421,14 +446,20 @@
 
 def test_unspecified_port_default_busy(caplog, setdir, ports, setup_store):
     nex = Nexus("test")
-    nex.createNexus(
-        file="minimal.yaml",
-        store_size=10000000,
-        control_port=ports[0],
-        output_port=ports[1],
-    )
+    try:
+        nex.create_nexus(
+            file="minimal.yaml",
+            store_size=100_000_000,
+            control_port=ports[0],
+            output_port=ports[1],
+        )
+
+        time.sleep(3)
 
-    nex.destroyNexus()
+        nex.destroy_nexus()
+    except Exception as e:
+        print(f"error caught in test harness: {e}")
+        logging.error(f"error caught in test harness: {e}")
 
     assert any(
         "StoreInterface start successful on port 6380" in record.msg
         for record in caplog.records
@@ -436,15 +467,27 @@
 
 
 def test_no_aof_dir_by_default(caplog, setdir, ports):
-    nex = Nexus("test")
-    nex.createNexus(
-        file="minimal.yaml",
-        store_size=10000000,
-        control_port=ports[0],
-        output_port=ports[1],
-    )
+    try:
+        if "appendonlydir" in os.listdir("."):
+            shutil.rmtree("appendonlydir")
+        else:
+            logging.info("didn't find appendonlydir")
+
+        nex = Nexus("test")
+
+        nex.create_nexus(
+            file="minimal.yaml",
+            store_size=100_000_000,
+            control_port=ports[0],
+            output_port=ports[1],
+        )
 
-    nex.destroyNexus()
+        time.sleep(3)
+
+        nex.destroy_nexus()
+    except Exception as e:
+        print(f"error caught in test harness: {e}")
+        logging.error(f"error caught in test harness: {e}")
 
     assert "appendonlydir" not in os.listdir(".")
     assert all(["improv_persistence_" not in name for name in os.listdir(".")])
@@ -452,19 +495,23 @@
 
 
 def test_default_aof_dir_if_none_specified(caplog, setdir, ports, server_port_num):
     nex = Nexus("test")
-    nex.createNexus(
-        file="minimal_with_redis_saving.yaml",
-        store_size=10000000,
-        control_port=ports[0],
-        output_port=ports[1],
-    )
+    try:
+        nex.create_nexus(
+            file="minimal_with_redis_saving.yaml",
+            store_size=100_000_000,
+            control_port=ports[0],
+            output_port=ports[1],
+        )
 
-    store = StoreInterface(server_port_num=server_port_num)
-    store.put(1)
+        store = StoreInterface(server_port_num=server_port_num)
+        store.put(1)
 
-    time.sleep(3)
+        time.sleep(3)
 
-    nex.destroyNexus()
+        nex.destroy_nexus()
+    except Exception as e:
+        print(f"error caught in test harness: {e}")
+        logging.error(f"error caught in test harness: {e}")
 
     assert "appendonlydir" in os.listdir(".")
 
@@ -478,19 +525,23 @@
 
 def test_specify_static_aof_dir(caplog, 
setdir, ports, server_port_num): nex = Nexus("test") - nex.createNexus( - file="minimal_with_custom_aof_dirname.yaml", - store_size=10000000, - control_port=ports[0], - output_port=ports[1], - ) + try: + nex.create_nexus( + file="minimal_with_custom_aof_dirname.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) - store = StoreInterface(server_port_num=server_port_num) - store.put(1) + store = StoreInterface(server_port_num=server_port_num) + store.put(1) - time.sleep(3) + time.sleep(3) - nex.destroyNexus() + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") assert "custom_aof_dirname" in os.listdir(".") @@ -504,19 +555,23 @@ def test_specify_static_aof_dir(caplog, setdir, ports, server_port_num): def test_use_ephemeral_aof_dir(caplog, setdir, ports, server_port_num): nex = Nexus("test") - nex.createNexus( - file="minimal_with_ephemeral_aof_dirname.yaml", - store_size=10000000, - control_port=ports[0], - output_port=ports[1], - ) + try: + nex.create_nexus( + file="minimal_with_ephemeral_aof_dirname.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) - store = StoreInterface(server_port_num=server_port_num) - store.put(1) + store = StoreInterface(server_port_num=server_port_num) + store.put(1) - time.sleep(3) + time.sleep(3) - nex.destroyNexus() + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") assert any(["improv_persistence_" in name for name in os.listdir(".")]) @@ -527,18 +582,24 @@ def test_use_ephemeral_aof_dir(caplog, setdir, ports, server_port_num): def test_save_no_schedule(caplog, setdir, ports, server_port_num): nex = Nexus("test") - nex.createNexus( - file="minimal_with_no_schedule_saving.yaml", - store_size=10000000, - control_port=ports[0], - output_port=ports[1], - ) + try: + nex.create_nexus( + file="minimal_with_no_schedule_saving.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) - store = StoreInterface(server_port_num=server_port_num) + store = StoreInterface(server_port_num=server_port_num) - fsync_schedule = store.client.config_get("appendfsync") + fsync_schedule = store.client.config_get("appendfsync") - nex.destroyNexus() + time.sleep(3) + + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") assert "appendonlydir" in os.listdir(".") shutil.rmtree("appendonlydir") @@ -548,18 +609,24 @@ def test_save_no_schedule(caplog, setdir, ports, server_port_num): def test_save_every_second(caplog, setdir, ports, server_port_num): nex = Nexus("test") - nex.createNexus( - file="minimal_with_every_second_saving.yaml", - store_size=10000000, - control_port=ports[0], - output_port=ports[1], - ) + try: + nex.create_nexus( + file="minimal_with_every_second_saving.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) + + store = StoreInterface(server_port_num=server_port_num) - store = StoreInterface(server_port_num=server_port_num) + fsync_schedule = store.client.config_get("appendfsync") - fsync_schedule = store.client.config_get("appendfsync") + time.sleep(3) - nex.destroyNexus() + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") assert "appendonlydir" in os.listdir(".") 
shutil.rmtree("appendonlydir") @@ -569,18 +636,24 @@ def test_save_every_second(caplog, setdir, ports, server_port_num): def test_save_every_write(caplog, setdir, ports, server_port_num): nex = Nexus("test") - nex.createNexus( - file="minimal_with_every_write_saving.yaml", - store_size=10000000, - control_port=ports[0], - output_port=ports[1], - ) + try: + nex.create_nexus( + file="minimal_with_every_write_saving.yaml", + store_size=100_000_000, + control_port=ports[0], + output_port=ports[1], + ) + + store = StoreInterface(server_port_num=server_port_num) - store = StoreInterface(server_port_num=server_port_num) + fsync_schedule = store.client.config_get("appendfsync") - fsync_schedule = store.client.config_get("appendfsync") + time.sleep(3) - nex.destroyNexus() + nex.destroy_nexus() + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") assert "appendonlydir" in os.listdir(".") shutil.rmtree("appendonlydir") @@ -588,65 +661,121 @@ def test_save_every_write(caplog, setdir, ports, server_port_num): assert fsync_schedule["appendfsync"] == "always" -@pytest.mark.skip(reason="Nexus no longer deletes files on shutdown. Nothing to test.") -def test_store_already_deleted_issues_warning(caplog): - nex = Nexus("test") - nex._startStoreInterface(10000) - store_location = nex.store_loc - StoreInterface(store_loc=nex.store_loc) - os.remove(nex.store_loc) - nex.destroyNexus() - assert any( - "StoreInterface file {} is already deleted".format(store_location) in record.msg - for record in caplog.records - ) +# def test_sigint_exits_cleanly(ports, set_dir_config_parent): +# server_opts = [ +# "improv", +# "server", +# "-c", +# str(ports[0]), +# "-o", +# str(ports[1]), +# "-f", +# "global.log", +# "configs/minimal.yaml", +# ] +# +# env = os.environ.copy() +# env["PYTHONPATH"] += ":" + os.getcwd() +# +# server = subprocess.Popen( +# server_opts, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env +# ) +# +# time.sleep(5) +# +# server.send_signal(signal.SIGINT) +# +# server.wait(10) +# assert True -@pytest.mark.skip(reason="unfinished") -def test_actor_sub(setdir, capsys, monkeypatch, ports): - monkeypatch.setattr("improv.nexus.input", lambda: "setup\n") - cfg_file = "sample_config.yaml" +# def test_nexus_actor_in_port(ports, setdir, start_nexus_minimal_zmq): +# context = zmq.Context() +# nex_socket = context.socket(zmq.REQ) +# nex_socket.connect(f"tcp://localhost:{ports[3]}") # actor in port +# +# test_socket = context.socket(zmq.REP) +# test_socket.bind("tcp://*:0") +# in_port_string = test_socket.getsockopt_string(SocketOption.LAST_ENDPOINT) +# test_socket_port = int(in_port_string.split(":")[-1]) +# logging.info(f"Using port {test_socket_port}") +# +# logging.info("waiting to send") +# actor_state = ActorStateMsg( +# "test_actor", "test_status", test_socket_port, "test info string" +# ) +# nex_socket.send_pyobj(actor_state) +# logging.info("Sent") +# out = nex_socket.recv_pyobj() +# assert isinstance(out, ActorStateReplyMsg) +# assert out.actor_name == actor_state.actor_name +# assert out.status == "OK" + + +def test_nexus_create_nexus_no_cfg_file(ports): nex = Nexus("test") + with pytest.raises(ConfigFileNotProvidedException): + nex.create_nexus() - nex.createNexus( - file=cfg_file, store_size=4000, control_port=ports[0], output_port=ports[1] - ) - print("Nexus Created") - - nex.startNexus() - print("Nexus Started") - # time.sleep(5) - # print("Printing...") - # subprocess.Popen(["setup"], stdout=subprocess.DEVNULL, 
stderr=subprocess.DEVNULL) - # time.sleep(2) - # subprocess.Popen(["run"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - # time.sleep(5) - # subprocess.Popen(["quit"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - - nex.destroyNexus() - assert True +# +# @pytest.mark.skip(reason="Blocking comms so this won't work as-is") +# def test_nexus_actor_comm_setup(ports, setdir): +# filename = "minimal_zmq.yaml" +# nex = Nexus("test") +# nex.create_nexus( +# file=filename, +# store_size=10000000, +# control_port=ports[0], +# output_port=ports[1], +# actor_in_port=ports[2], +# ) +# +# actor = nex.actors["Generator"] +# actor.register_with_nexus() +# +# nex.process_actor_message() +# +# +# @pytest.mark.skip(reason="Test isn't meant to be used for coverage") +# def test_debug_nex(ports, setdir): +# filename = "minimal_zmq.yaml" +# conftest.nex_startup(ports, filename) +# +# +# @pytest.mark.skip(reason="Test isn't meant to be used for coverage") +# def test_nex_cfg(ports, setdir): +# filename = "minimal_zmq.yaml" +# nex = Nexus("test") +# nex.create_nexus( +# file=filename, +# store_size=100000000, +# control_port=ports[0], +# output_port=ports[1], +# actor_in_port=ports[2], +# ) +# nex.start_nexus() -@pytest.mark.skip( - reason="skipping to prevent issues with orphaned stores. TODO fix this" -) -def test_sigint_exits_cleanly(ports, tmp_path): - server_opts = [ - "improv", - "server", - "-c", - str(ports[0]), - "-o", - str(ports[1]), - "-f", - tmp_path / "global.log", - ] - - server = subprocess.Popen( - server_opts, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL - ) - server.send_signal(signal.SIGINT) +def test_nexus_bad_config_actor_args(setdir): + nex = Nexus("test") + # with pytest.raises(ConfigFileNotValidException): + try: + nex.create_nexus("bad_args.yaml") + except ConfigFileNotValidException: + assert True + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") + - server.wait(10) - assert True +def test_nexus_no_config_file(): + nex = Nexus("test") + # with pytest.raises(ConfigFileNotProvidedException): + try: + nex.create_nexus() + except ConfigFileNotProvidedException: + assert True + except Exception as e: + print(f"error caught in test harness: {e}") + logging.error(f"error caught in test harness: {e}") diff --git a/test/test_store_with_errors.py b/test/test_store_with_errors.py index b95781a0..36d1d0be 100644 --- a/test/test_store_with_errors.py +++ b/test/test_store_with_errors.py @@ -1,9 +1,7 @@ import pytest -from pyarrow import plasma -from improv.store import StoreInterface, RedisStoreInterface, PlasmaStoreInterface +from improv.store import StoreInterface, RedisStoreInterface, ObjectNotFoundError -from pyarrow._plasma import PlasmaObjectExists from scipy.sparse import csc_matrix import numpy as np import redis @@ -17,91 +15,30 @@ logger.setLevel(logging.DEBUG) -# TODO: add docstrings!!! -# TODO: clean up syntax - consistent capitalization, function names, etc. -# TODO: decide to keep classes -# TODO: increase coverage!!! SEE store.py - -# Separate each class as individual file - individual tests??? 
- - def test_connect(setup_store, server_port_num): store = StoreInterface(server_port_num=server_port_num) assert isinstance(store.client, redis.Redis) -def test_plasma_connect(setup_plasma_store, set_store_loc): - store = PlasmaStoreInterface(store_loc=set_store_loc) - assert isinstance(store.client, plasma.PlasmaClient) - - def test_redis_connect(setup_store, server_port_num): store = RedisStoreInterface(server_port_num=server_port_num) - assert isinstance(store.client, redis.Redis) assert store.client.ping() -def test_connect_incorrect_path(setup_plasma_store, set_store_loc): - # TODO: shorter name??? - # TODO: passes, but refactor --- see comments - store_loc = "asdf" - # Handle exception thrown - assert name == 'CannotConnectToStoreInterfaceError' - # and message == 'Cannot connect to store at {}'.format(str(store_loc)) - # with pytest.raises(Exception, match='CannotConnectToStoreInterfaceError') as cm: - # store.connect_store(store_loc) - # # Check that the exception thrown is a CannotConnectToStoreInterfaceError - # raise Exception('Cannot connect to store: {0}'.format(e)) - with pytest.raises(CannotConnectToStoreInterfaceError) as e: - store = PlasmaStoreInterface(store_loc=store_loc) - store.connect_store(store_loc) - # Check that the exception thrown is a CannotConnectToStoreInterfaceError - assert e.value.message == "Cannot connect to store at {}".format(str(store_loc)) - - -def test_redis_connect_wrong_port(setup_store, server_port_num): +def test_redis_connect_wrong_port(server_port_num): bad_port_num = 1234 with pytest.raises(CannotConnectToStoreInterfaceError) as e: RedisStoreInterface(server_port_num=bad_port_num) assert e.value.message == "Cannot connect to store at {}".format(str(bad_port_num)) -def test_connect_none_path(setup_plasma_store): - # BUT default should be store_loc = '/tmp/store' if not entered? - store_loc = None - # Handle exception thrown - assert name == 'CannotConnectToStoreInterfaceError' - # and message == 'Cannot connect to store at {}'.format(str(store_loc)) - # with pytest.raises(Exception) as cm: - # store.connnect_store(store_loc) - # Check that the exception thrown is a CannotConnectToStoreInterfaceError - # assert cm.exception.name == 'CannotConnectToStoreInterfaceError' - # with pytest.raises(Exception, match='CannotConnectToStoreInterfaceError') as cm: - # store.connect_store(store_loc) - # Check that the exception thrown is a CannotConnectToStoreInterfaceError - # raise Exception('Cannot connect to store: {0}'.format(e)) - with pytest.raises(CannotConnectToStoreInterfaceError) as e: - store = PlasmaStoreInterface(store_loc=store_loc) - store.connect_store(store_loc) - # Check that the exception thrown is a CannotConnectToStoreInterfaceError - assert e.value.message == "Cannot connect to store at {}".format(str(store_loc)) - - -# class StoreInterfaceGet(self): - - # TODO: @pytest.parameterize...store.get and store.getID for diff datatypes, -# pickleable and not, etc. 
-# Check raises...CannotGetObjectError (object never stored) def test_init_empty(setup_store, server_port_num): store = StoreInterface(server_port_num=server_port_num) # logger.info(store.client.config_get()) assert store.get_all() == [] -def test_plasma_init_empty(setup_plasma_store, set_store_loc): - store = PlasmaStoreInterface(store_loc=set_store_loc) - assert store.get_all() == {} - - def test_is_csc_matrix_and_put(setup_store, server_port_num): mat = csc_matrix((3, 4), dtype=np.int8) store = StoreInterface(server_port_num=server_port_num) @@ -109,21 +46,13 @@ def test_is_csc_matrix_and_put(setup_store, server_port_num): assert isinstance(store.get(x), csc_matrix) -def test_plasma_is_csc_matrix_and_put(setup_plasma_store, set_store_loc): - mat = csc_matrix((3, 4), dtype=np.int8) - store = PlasmaStoreInterface(store_loc=set_store_loc) - x = store.put(mat, "matrix") - assert isinstance(store.getID(x), csc_matrix) - - -@pytest.mark.skip def test_get_list_and_all(setup_store, server_port_num): store = StoreInterface(server_port_num=server_port_num) - # id = store.put(1, "one") - # id2 = store.put(2, "two") - # id3 = store.put(3, "three") - assert [1, 2] == store.getList(["one", "two"]) - assert [1, 2, 3] == store.get_all() + id = store.put(1) + id2 = store.put(2) + store.put(3) + assert [1, 2] == store.get_list([id, id2]) + assert [1, 2, 3] == sorted(store.get_all()) def test_reset(setup_store, server_port_num): @@ -133,13 +62,6 @@ def test_reset(setup_store, server_port_num): assert store.get(id) == 1 -def test_plasma_reset(setup_plasma_store, set_store_loc): - store = PlasmaStoreInterface(store_loc=set_store_loc) - store.reset() - id = store.put(1, "one") - assert store.get(id) == 1 - - def test_put_one(setup_store, server_port_num): store = StoreInterface(server_port_num=server_port_num) id = store.put(1) @@ -152,27 +74,10 @@ def test_redis_put_one(setup_store, server_port_num): assert 1 == store.get(key) -def test_plasma_put_one(setup_plasma_store, set_store_loc): - store = PlasmaStoreInterface(store_loc=set_store_loc) - id = store.put(1, "one") - assert 1 == store.get(id) - - -@pytest.mark.skip(reason="Error not being raised") -def test_put_twice(setup_store): - # store = StoreInterface() - with pytest.raises(PlasmaObjectExists) as e: - # id = store.put(2, "two") - # id2 = store.put(2, "two") - pass - # Check that the exception thrown is an PlasmaObjectExists - assert e.value.message == "Object already exists. Meant to call replace?" 
- - -def test_getOne(setup_store, server_port_num): - store = StoreInterface(server_port_num=server_port_num) - id = store.put(1) - assert 1 == store.get(id) +def test_redis_get_unknown_object(setup_store, server_port_num): + store = RedisStoreInterface(server_port_num=server_port_num) + with pytest.raises(ObjectNotFoundError): + store.get("unknown") def test_redis_get_one(setup_store, server_port_num): diff --git a/test/test_tui.py b/test/test_tui.py index d45bccfe..64b44ee5 100644 --- a/test/test_tui.py +++ b/test/test_tui.py @@ -6,8 +6,6 @@ from zmq import PUB, REP from zmq.log.handlers import PUBHandler -from test_nexus import ports - @pytest.fixture def logger(ports): @@ -31,7 +29,7 @@ async def sockets(ports): @pytest.fixture async def app(ports): - mock = tui.TUI(*ports) + mock = tui.TUI(*ports[:-1]) yield mock time.sleep(0.5) @@ -60,8 +58,13 @@ async def test_log_panel_receives_logging(app, logger): async def test_input_box_echoed_to_console(app): + ctx = zmq.Context() + mock_server_socket = ctx.socket(REP) + mock_server_socket.bind(f"tcp://*:{app.control_port.split(':')[1]}") async with app.run_test() as pilot: await pilot.press(*"foo", "enter") + await mock_server_socket.recv_string() + await mock_server_socket.send_string("test reply") console = pilot.app.get_widget_by_id("console") assert console.history[0] == "foo"
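
The last hunk shows the pattern for testing REQ/REP clients such as the TUI console: stand up a throwaway REP socket, receive the client's request, and send back any reply so the client's strict REQ state machine can advance to its next send. A single-threaded sketch of that exchange, with illustrative names and an inproc endpoint standing in for the TUI's control port:

import zmq

ctx = zmq.Context()

# Mock of the server side the console would normally talk to.
server = ctx.socket(zmq.REP)
server.bind("inproc://mock-control")

client = ctx.socket(zmq.REQ)
client.connect("inproc://mock-control")

client.send_string("foo")  # what pilot.press(*"foo", "enter") ends up sending
request = server.recv_string()  # REQ/REP lockstep: recv before send
server.send_string("test reply")  # unblocks the REQ socket for its next send

assert request == "foo"
assert client.recv_string() == "test reply"

client.close(linger=0)
server.close(linger=0)
ctx.destroy(linger=0)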