Skip to content

Commit

Permalink
Split dataset before process creation
Browse files Browse the repository at this point in the history
  • Loading branch information
dagrayvid committed Oct 19, 2024
1 parent 955c120 commit 94dd326
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 45 deletions.
11 changes: 9 additions & 2 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Dataset:
def __init__(self,
file,
model_name="",
max_queries=3000,
max_queries=8000,
min_input_tokens=0,
max_input_tokens=16000,
min_output_tokens=0,
Expand All @@ -36,6 +36,13 @@ def __init__(self,
logging.warning("Total dataset is %s elements, check filters!", len(self.dataset_list))
self.index = 0

def user_subset(self, user_id, num_users):
    """Return this user's round-robin share of the dataset.

    Queries are dealt out round-robin so that the ``num_users`` slices
    are disjoint, together cover ``self.dataset_list`` exactly, and
    differ in length by at most one element.

    Args:
        user_id: Zero-based index of the user; must be in [0, num_users).
        num_users: Total number of users splitting the dataset; must be > 0.

    Returns:
        list: Every ``num_users``-th element of ``self.dataset_list``,
        starting at index ``user_id``.

    Raises:
        ValueError: If ``user_id`` is not in [0, num_users).
    """
    if not 0 <= user_id < num_users:
        # Previously this condition was only logged and execution continued,
        # silently handing the caller an empty/incorrect subset. Fail fast
        # instead, keeping the log line for existing log consumers.
        logging.error("Unexpected inputs, user_id must be < num_users")
        raise ValueError(
            f"user_id must be in [0, {num_users}), got {user_id}"
        )
    return self.dataset_list[user_id::num_users]


def get_next_n_queries(self, n):
"""Get the N next queries."""
max_index = len(self.dataset_list)
Expand All @@ -47,7 +54,7 @@ def get_next_n_queries(self, n):
def initialize_dataset(
filename,
model_name="",
max_queries=3000,
max_queries=8000,
min_input_tokens=0,
max_input_tokens=16000,
min_output_tokens=0,
Expand Down
43 changes: 13 additions & 30 deletions load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,15 @@
import utils


def run_main_process(concurrency, duration, dataset, dataset_q, stop_q):
def run_main_process(concurrency, duration, dataset, schedule_q, stop_q):
"""Run the main process."""
logging.info("Test from main process")

# Initialize the dataset_queue with 4*concurrency requests
for query in dataset.get_next_n_queries(2 * concurrency):
dataset_q.put(query)

start_time = time.time()
current_time = start_time
schedule_q.put(start_time)
while (current_time - start_time) < duration:
# Keep the dataset queue full for duration
if dataset_q.qsize() < int(0.5*concurrency + 1):
logging.info("Adding %d entries to dataset queue", concurrency)
for query in dataset.get_next_n_queries(concurrency):
dataset_q.put(query)
time.sleep(0.1)
current_time = time.time()

Expand All @@ -38,11 +31,6 @@ def run_main_process(concurrency, duration, dataset, dataset_q, stop_q):
# Signal users to stop sending requests
stop_q.put(None)

# Empty the dataset queue
while not dataset_q.empty():
logging.debug("Removing element from dataset_q")
dataset_q.get()

return


Expand All @@ -57,17 +45,12 @@ def gather_results(results_pipes):
return results_list


def exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, code):
def exit_gracefully(procs, stop_q, logger_q, log_reader_thread, code):
"""Exit gracefully."""
# Signal users to stop sending requests
if stop_q.empty():
stop_q.put(None)

if dataset_q is not None and not dataset_q.empty():
logging.warning("Removing more elements from dataset_q after gathering results!")
while not dataset_q.empty():
dataset_q.get()


logging.debug("Calling join() on all user processes")
for proc in procs:
proc.join()
Expand All @@ -89,8 +72,8 @@ def main(args):
log_reader_thread = logging_utils.init_logging(args.log_level, logger_q)

# Create processes and their Users
schedule_q = mp_ctx.Queue(1)
stop_q = mp_ctx.Queue(1)
dataset_q = mp_ctx.Queue()
procs = []
results_pipes = []

Expand All @@ -102,7 +85,7 @@ def main(args):
concurrency, duration, plugin = utils.parse_config(config)
except Exception as e:
logging.error("Exiting due to invalid input: %s", repr(e))
exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, 1)
exit_gracefully(procs, stop_q, logger_q, log_reader_thread, 1)

try:
logging.debug("Creating dataset with configuration %s", config["dataset"])
Expand All @@ -113,9 +96,11 @@ def main(args):
logging.debug("Creating %s Users and corresponding processes", concurrency)
for idx in range(concurrency):
send_results, recv_results = mp_ctx.Pipe()
results_pipes.append(recv_results)
user = User(
idx,
dataset_q=dataset_q,
dataset=dataset.user_subset(idx, concurrency),
schedule_q=schedule_q,
stop_q=stop_q,
results_pipe=send_results,
plugin=plugin,
Expand All @@ -127,10 +112,9 @@ def main(args):
procs.append(proc)
logging.info("Starting %s", proc)
proc.start()
results_pipes.append(recv_results)

logging.debug("Running main process")
run_main_process(concurrency, duration, dataset, dataset_q, stop_q)
run_main_process(concurrency, duration, dataset, schedule_q, stop_q)

results_list = gather_results(results_pipes)

Expand All @@ -139,13 +123,12 @@ def main(args):
# Terminate queues immediately on ^C
except KeyboardInterrupt:
stop_q.cancel_join_thread()
dataset_q.cancel_join_thread()
exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, 130)
exit_gracefully(procs, stop_q, logger_q, log_reader_thread, 130)
except Exception:
logging.exception("Unexpected exception in main process")
exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, 1)
exit_gracefully(procs, stop_q, logger_q, log_reader_thread, 1)

exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, 0)
exit_gracefully(procs, stop_q, logger_q, log_reader_thread, 0)


if __name__ == "__main__":
Expand Down
29 changes: 16 additions & 13 deletions user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class User:
def __init__(
self,
user_id,
dataset_q,
dataset,
schedule_q,
stop_q,
results_pipe,
plugin,
Expand All @@ -22,7 +23,9 @@ def __init__(
"""Initialize object."""
self.user_id = user_id
self.plugin = plugin
self.dataset_q = dataset_q
self.dataset = dataset
self.dataset_idx = 0
self.schedule_q = schedule_q
self.stop_q = stop_q
self.results_list = []
self.results_pipe = results_pipe
Expand All @@ -34,15 +37,8 @@ def __init__(

def make_request(self, test_end_time=0):
"""Make a request."""
try:
query = self.dataset_q.get(timeout=2)
except queue.Empty:
# if timeout passes, queue.Empty will be thrown
# User should continue to poll for inputs
return None
except ValueError:
self.logger.warn("dataset q does not exist!")
return None
query = self.dataset[self.dataset_idx]
self.dataset_idx = (self.dataset_idx + 1) % len(self.dataset)

self.logger.info("User %s making request", self.user_id)
result = self.plugin.request_func(query, self.user_id, test_end_time)
Expand All @@ -63,13 +59,20 @@ def run_user_process(self):
"""Run a process."""
self._init_user_process_logging()

# Waits for all processes to actually be started
while self.schedule_q.empty():
time.sleep(0.1)

test_end_time = time.time() + self.run_duration
self.logger.info("User %s starting request loop", self.user_id)

while self.stop_q.empty():
result = self.make_request(test_end_time)
# NOTE: make_request no longer polls a shared dataset_q with a timeout;
# each User now cycles through its own pre-assigned dataset slice.

if result is not None:
self.results_list.append(result)
else:
self.logger.info("Unexpected None result from User.make_request()")

self.results_pipe.send(self.results_list)

Expand Down

0 comments on commit 94dd326

Please sign in to comment.