WIP: Add RPS based load option #65

Open · wants to merge 6 commits into main (showing changes from 4 commits)
9 changes: 5 additions & 4 deletions config.yaml
@@ -6,15 +6,16 @@ storage: # TODO
type: local
dataset:
file: "datasets/openorca_large_subset_011.jsonl"
max_queries: 1000
max_queries: 4000
min_input_tokens: 0
max_input_tokens: 1024
min_output_tokens: 0
max_output_tokens: 1024
max_output_tokens: 512
max_sequence_tokens: 2048
load_options:
type: constant #Future options: loadgen, stair-step
concurrency: 1
type: rps #Options: concurrency, rps, loadgen, stair-step
Collaborator: Any reason you replaced the constant load type with concurrency? IMO, constant sounds closer to Constant Load, which is a continuous stream of requests.

Collaborator Author: I was thinking that constant is ambiguous, as RPS can also be constant. My other thought is that we might later add dynamically changing RPS or dynamically changing concurrency, so either RPS or concurrency could be constant or dynamic.

concurrency: 4
rps: 16
duration: 20 # In seconds. Maybe in future support "100s" "10m", etc...
plugin: "openai_plugin"
plugin_options:
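
For reference, the two load modes this PR distinguishes might be configured roughly as follows. This is only a sketch based on the keys visible in this diff; the exact accepted values and defaults are defined in utils.parse_config, which is not part of the change shown here.

# RPS mode: dispatch requests on a fixed schedule of 16 requests/second,
# using a pool of 4 user processes to send them.
load_options:
  type: rps
  concurrency: 4
  rps: 16
  duration: 20   # seconds

# Concurrency mode: each of the 4 users sends requests back-to-back for the
# whole duration; the rps key is omitted here (assumption: it is ignored in
# this mode, since run_main_process only checks whether rps is None).
load_options:
  type: concurrency
  concurrency: 4
  duration: 20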
11 changes: 9 additions & 2 deletions dataset.py
@@ -12,7 +12,7 @@ class Dataset:
def __init__(self,
file,
model_name="",
max_queries=3000,
max_queries=8000,
min_input_tokens=0,
max_input_tokens=16000,
min_output_tokens=0,
@@ -36,6 +36,13 @@ def __init__(self,
logging.warning("Total dataset is %s elements, check filters!", len(self.dataset_list))
self.index = 0

def user_subset(self, user_id, num_users):
if user_id >= num_users:
logging.error("Unexpected inputs, user_id must be < num_users")

return self.dataset_list[user_id::num_users]


def get_next_n_queries(self, n):
"""Get the N next queries."""
max_index = len(self.dataset_list)
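
The new user_subset helper above partitions the dataset round-robin by slicing with a stride, so each user gets a disjoint subset. A quick standalone illustration of the slicing behaviour (toy data, not the real dataset):

# dataset_list[user_id::num_users] takes every num_users-th element starting
# at index user_id, so the subsets are disjoint and together cover the list.
dataset_list = ["q0", "q1", "q2", "q3", "q4", "q5", "q6"]
num_users = 3
for user_id in range(num_users):
    print(user_id, dataset_list[user_id::num_users])
# 0 ['q0', 'q3', 'q6']
# 1 ['q1', 'q4']
# 2 ['q2', 'q5']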
@@ -47,7 +54,7 @@ def get_next_n_queries(self, n):
def initialize_dataset(
filename,
model_name="",
max_queries=3000,
max_queries=8000,
min_input_tokens=0,
max_input_tokens=16000,
min_output_tokens=0,
99 changes: 65 additions & 34 deletions load_test.py
@@ -14,22 +14,32 @@
import utils


def run_main_process(concurrency, duration, dataset, dataset_q, stop_q):
def run_main_process(rps, duration, dataset, schedule_q, stop_q):
"""Run the main process."""
logging.info("Test from main process")

# Initialize the dataset_queue with 4*concurrency requests
for query in dataset.get_next_n_queries(2 * concurrency):
dataset_q.put(query)

start_time = time.time()
end_time = start_time + duration
if rps is not None:
main_loop_rps_mode(schedule_q, rps, start_time, end_time)
else:
main_loop_concurrency_mode(schedule_q, start_time, end_time)

logging.info("Timer ended, stopping processes")

# Signal users to stop sending requests
stop_q.put(None)

return
Comment on lines +37 to +38
Member: Drop this return?

Suggested change: remove the trailing return statement.

Collaborator: It doesn't, but I wonder if adding a dedicated try/except block in this function is worth it. We currently catch all the cascading exceptions with the generic Exception class in the main function, but that's probably not the cleanest way to handle exceptions, IMO.

Not suggesting this should be addressed in this PR, but a follow-up PR to clean up our exception handling might be good.
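
As a rough illustration of that follow-up idea (explicitly not part of this PR), a dedicated handler in run_main_process might look something like the sketch below. Which exceptions to catch and how to recover are open questions, so treat the details as assumptions.

import logging
import time

def run_main_process(rps, duration, dataset, schedule_q, stop_q):
    """Sketch: scheduling loop with its own exception handling."""
    start_time = time.time()
    end_time = start_time + duration
    try:
        if rps is not None:
            main_loop_rps_mode(schedule_q, rps, start_time, end_time)
        else:
            main_loop_concurrency_mode(schedule_q, start_time, end_time)
    except Exception:
        # Hypothetical: log the scheduling failure here instead of letting it
        # bubble up to the generic handler in main().
        logging.exception("Unexpected exception in the scheduling loop")
    finally:
        logging.info("Timer ended, stopping processes")
        stop_q.put(None)  # always signal users to stop sending requests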


def main_loop_concurrency_mode(schedule_q, start_time, end_time):
"""Let all users send requests repeatedly until end_time"""
logging.info("Test from main process")
Collaborator: Do we still need this logging statement here?

Collaborator Author: No, I'll remove this, thanks!


schedule_q.put(start_time)

current_time = start_time
while (current_time - start_time) < duration:
# Keep the dataset queue full for duration
if dataset_q.qsize() < int(0.5*concurrency + 1):
logging.info("Adding %d entries to dataset queue", concurrency)
for query in dataset.get_next_n_queries(concurrency):
dataset_q.put(query)
while current_time < end_time:
time.sleep(0.1)
current_time = time.time()

@@ -38,12 +48,37 @@ def run_main_process(concurrency, duration, dataset, dataset_q, stop_q):
# Signal users to stop sending requests
stop_q.put(None)

# Empty the dataset queue
while not dataset_q.empty():
logging.debug("Removing element from dataset_q")
dataset_q.get()

return
def request_schedule_constant_rps(rps, start_time, end_time):
"""Returns a list of timestamps for request schedule with constant RPS"""
interval = 1 / rps
next_req_time = start_time
while next_req_time < end_time:
yield(next_req_time)
next_req_time = next_req_time + interval


# This function should support non-constant RPS in the future
def main_loop_rps_mode(schedule_q, rps, start_time, end_time):
"""Dispatch requests with constant RPS, via schedule_q"""
req_times = request_schedule_constant_rps(rps, start_time, end_time)

current_time = time.time()
for next_req_time in req_times:
while next_req_time > current_time:
# Wait or spin until next req needs to be dispatched
sleep_time = (next_req_time - current_time) - 0.01 # Sleep until 10ms before next_req_time
if sleep_time > 0:
time.sleep(sleep_time)
# else spin
current_time = time.time()

logging.info(f"Scheduling request time {next_req_time}")
schedule_q.put(next_req_time)

if current_time >= end_time:
return
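
To make the constant-RPS schedule concrete, here is a small standalone sketch of how the generator above spaces dispatch times (timestamps and counts below are illustrative):

import time

def request_schedule_constant_rps(rps, start_time, end_time):
    """Yield one dispatch timestamp per request, spaced 1/rps seconds apart."""
    interval = 1 / rps
    next_req_time = start_time
    while next_req_time < end_time:
        yield next_req_time
        next_req_time += interval

start = time.time()
# 4 requests/second over a 1-second window: dispatch times roughly 0.25s apart.
for t in request_schedule_constant_rps(4, start, start + 1):
    print(f"dispatch at +{t - start:.2f}s")
# dispatch at +0.00s, +0.25s, +0.50s, +0.75s (approximately)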



def gather_results(results_pipes):
@@ -57,17 +92,12 @@ def gather_results(results_pipes):
return results_list


def exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, code):
def exit_gracefully(procs, stop_q, logger_q, log_reader_thread, code):
"""Exit gracefully."""
# Signal users to stop sending requests
if stop_q.empty():
stop_q.put(None)

if dataset_q is not None and not dataset_q.empty():
logging.warning("Removing more elements from dataset_q after gathering results!")
while not dataset_q.empty():
dataset_q.get()

logging.debug("Calling join() on all user processes")
for proc in procs:
proc.join()
@@ -89,48 +119,50 @@ def main(args):
log_reader_thread = logging_utils.init_logging(args.log_level, logger_q)

# Create processes and their Users
schedule_q = mp_ctx.Queue(1)
Member:

Suggested change:
schedule_q = mp_ctx.Queue(1)
schedule_q.cancel_join_thread()

Call cancel_join_thread() here to avoid the queue blocking on exit.

stop_q = mp_ctx.Queue(1)
dataset_q = mp_ctx.Queue()
procs = []
results_pipes = []

# Parse config
logging.debug("Parsing YAML config file %s", args.config)
concurrency, duration, plugin = 0, 0, None
rps, concurrency, duration, plugin = None, 0, 0, None
try:
config = utils.yaml_load(args.config)
concurrency, duration, plugin = utils.parse_config(config)
rps, concurrency, duration, plugin = utils.parse_config(config)
except Exception as e:
logging.error("Exiting due to invalid input: %s", repr(e))
exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, 1)
exit_gracefully(procs, stop_q, logger_q, log_reader_thread, 1)

try:
logging.debug("Creating dataset with configuration %s", config["dataset"])
# Get model_name if set for prompt formatting
model_name = config.get("plugin_options", {}).get("model_name", "")
dataset = Dataset(model_name=model_name, **config["dataset"])

logging.debug("Creating %s Users and corresponding processes", concurrency)
logging.info("Creating %s Users and corresponding processes", concurrency)
for idx in range(concurrency):
send_results, recv_results = mp_ctx.Pipe()
results_pipes.append(recv_results)
user = User(
idx,
dataset_q=dataset_q,
dataset=dataset.user_subset(idx, concurrency),
schedule_q=schedule_q,
stop_q=stop_q,
results_pipe=send_results,
plugin=plugin,
logger_q=logger_q,
log_level=args.log_level,
run_duration=duration,
rate_limited=(rps is not None)
)
proc = mp_ctx.Process(target=user.run_user_process)
procs.append(proc)
logging.info("Starting %s", proc)
proc.start()
results_pipes.append(recv_results)

logging.debug("Running main process")
run_main_process(concurrency, duration, dataset, dataset_q, stop_q)
run_main_process(rps, duration, dataset, schedule_q, stop_q)

results_list = gather_results(results_pipes)

@@ -139,13 +171,12 @@ def main(args):
# Terminate queues immediately on ^C
except KeyboardInterrupt:
stop_q.cancel_join_thread()
dataset_q.cancel_join_thread()
exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, 130)
exit_gracefully(procs, stop_q, logger_q, log_reader_thread, 130)
except Exception:
logging.exception("Unexpected exception in main process")
exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, 1)
exit_gracefully(procs, stop_q, logger_q, log_reader_thread, 1)

exit_gracefully(procs, dataset_q, stop_q, logger_q, log_reader_thread, 0)
exit_gracefully(procs, stop_q, logger_q, log_reader_thread, 0)


if __name__ == "__main__":
5 changes: 0 additions & 5 deletions plugins/caikit_client_plugin.py
@@ -79,7 +79,6 @@ def request_grpc(self, query, user_id, test_end_time: float=0):
result.output_tokens_before_timeout = result.output_tokens
result.output_text = response

result.calculate_results()
Collaborator: I wonder if it's time to deprecate the caikit_client_plugin?

return result

def streaming_request_grpc(self, query, user_id, test_end_time: float=0):
@@ -113,8 +112,6 @@ def streaming_request_grpc(self, query, user_id, test_end_time: float=0):
# TODO: Calculate correct output tokens before test timeout duration for streaming requests
result.output_tokens_before_timeout = result.output_tokens

result.calculate_results()

return result

def request_http(self, query, user_id):
@@ -138,7 +135,6 @@ def request_http(self, query, user_id):
result.output_tokens_before_timeout = result.output_tokens
result.output_text = response

result.calculate_results()
return result

def streaming_request_http(self, query, user_id):
@@ -171,5 +167,4 @@ def streaming_request_http(self, query, user_id):
# TODO: Calculate correct output tokens before test timeout duration for streaming requests
result.output_tokens_before_timeout = result.output_tokens

result.calculate_results()
return result
2 changes: 0 additions & 2 deletions plugins/dummy_plugin.py
@@ -35,7 +35,6 @@ def request_http(self, query, user_id, test_end_time: float=0):

result.end_time = time.time()

result.calculate_results()
Collaborator: When we do a cleanup, we should probably remove this file.

Collaborator Author: Yeah, this was originally added with the thought that it could be used in some test cases, but we may want to remove it depending on how we decide to handle testing (unit tests, e2e tests, etc.).


return result

@@ -63,5 +62,4 @@ def streaming_request_http(self, query, user_id, test_end_time: float=0):
# TODO: Calculate correct output tokens before test timeout duration for streaming requests
result.output_tokens_before_timeout = result.output_tokens

result.calculate_results()
return result
1 change: 0 additions & 1 deletion plugins/hf_tgi_plugin.py
@@ -114,5 +114,4 @@ def streaming_request_http(self, query, user_id, test_end_time: float=0):
# TODO: Calculate correct output tokens before test timeout duration for streaming requests
result.output_tokens_before_timeout = result.output_tokens

result.calculate_results()
return result
2 changes: 0 additions & 2 deletions plugins/openai_plugin.py
@@ -173,7 +173,6 @@ def request_http(self, query: dict, user_id: int, test_end_time: float = 0):

# For non-streaming requests we are keeping output_tokens_before_timeout and output_tokens same.
result.output_tokens_before_timeout = result.output_tokens
result.calculate_results()

return result

@@ -356,5 +355,4 @@ def streaming_request_http(self, query: dict, user_id: int, test_end_time: float
if expected_output_tokens and result.output_tokens != expected_output_tokens:
logger.warning(f"Received {result.output_tokens} tokens but expected {expected_output_tokens} tokens")

result.calculate_results()
return result
2 changes: 0 additions & 2 deletions plugins/tgis_grpc_plugin.py
@@ -121,7 +121,6 @@ def make_request(self, query: dict, user_id: int, test_end_time: float = 0):
else:
result.output_tokens = query["output_tokens"]

result.calculate_results()
return result

def make_request_stream(self, query: dict, user_id: int, test_end_time: float):
@@ -199,5 +198,4 @@ def make_request_stream(self, query: dict, user_id: int, test_end_time: float):
logger.warning("Output token count not found in response, using dataset expected output tokens")
result.output_tokens = len(tokens)

result.calculate_results()
return result
7 changes: 7 additions & 0 deletions result.py
@@ -14,10 +14,12 @@ def __init__(self, user_id, input_id, input_tokens=None):
self.output_text: Optional[str] = None
self.output_tokens: Optional[int] = None
self.output_tokens_before_timeout: Optional[int] = None
self.scheduled_start_time: Optional[float] = None
self.start_time: Optional[float] = None
self.ack_time: Optional[float] = None
self.first_token_time: Optional[float] = None
self.end_time: Optional[float] = None
self.client_wait_time: Optional[float] = None
self.response_time: Optional[float] = None
self.tt_ack: Optional[float] = None
self.ttft: Optional[float] = None
@@ -59,3 +61,8 @@ def calculate_results(self):
self.tpot = (
self.response_time / self.output_tokens
) # Time per output token in ms

if self.scheduled_start_time is not None and self.start_time is not None:
self.client_wait_time = (
self.start_time - self.scheduled_start_time
)
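
The new client_wait_time field is simply the gap between when the main loop scheduled a request and when the user process actually started sending it. A standalone arithmetic sketch with hypothetical timestamps (not using the Result class):

scheduled_start_time = 100.000   # timestamp the main loop put on schedule_q
start_time = 100.250             # timestamp the user actually began the request
client_wait_time = start_time - scheduled_start_time
print(f"client_wait_time = {client_wait_time:.3f}s")  # 0.250s of dispatch lag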