Add RPS based load option
dagrayvid committed Oct 21, 2024
1 parent 94dd326 commit bf19e71
Showing 10 changed files with 94 additions and 31 deletions.
9 changes: 5 additions & 4 deletions config.yaml
@@ -6,15 +6,16 @@ storage: # TODO
   type: local
 dataset:
   file: "datasets/openorca_large_subset_011.jsonl"
-  max_queries: 1000
+  max_queries: 4000
   min_input_tokens: 0
   max_input_tokens: 1024
   min_output_tokens: 0
-  max_output_tokens: 1024
+  max_output_tokens: 512
   max_sequence_tokens: 2048
 load_options:
-  type: constant #Future options: loadgen, stair-step
-  concurrency: 1
+  type: rps #Options: concurrency, rps, loadgen, stair-step
+  concurrency: 4
+  rps: 16
   duration: 20 # In seconds. Maybe in future support "100s" "10m", etc...
 plugin: "openai_plugin"
 plugin_options:
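The heart of the change is the new load_options.type switch. With type: rps the main process becomes an open-loop generator, scheduling one request every 1/rps seconds no matter how quickly responses come back, while concurrency still bounds how many user processes are available to pick those requests up; with type: concurrency the tool keeps its original closed-loop behavior, where a fixed pool of users drives requests for the duration. A minimal sketch of the two configurations, using only the fields visible in this diff (values are illustrative):

load_options:
  type: concurrency   # closed loop: a fixed pool of users drives the load
  concurrency: 4
  duration: 20

load_options:
  type: rps           # open loop: requests are paced at a fixed rate
  concurrency: 4      # still caps the number of concurrent user processes
  rps: 16
  duration: 20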
51 changes: 42 additions & 9 deletions load_test.py
@@ -14,15 +14,31 @@
 import utils
 
 
-def run_main_process(concurrency, duration, dataset, schedule_q, stop_q):
+def run_main_process(rps, duration, dataset, schedule_q, stop_q):
     """Run the main process."""
     logging.info("Test from main process")
 
     start_time = time.time()
-    current_time = start_time
+    end_time = start_time + duration
+    if rps is not None:
+        main_loop_rps_mode(schedule_q, rps, start_time, end_time)
+    else:
+        main_loop_concurrency_mode(schedule_q, start_time, end_time)
+
+    logging.info("Timer ended, stopping processes")
+
+    # Signal users to stop sending requests
+    stop_q.put(None)
+
+    return
+
+
+def main_loop_concurrency_mode(schedule_q, start_time, end_time):
+    logging.info("Test from main process")
 
     schedule_q.put(start_time)
-    while (current_time - start_time) < duration:
-        # Keep the dataset queue full for duration
+    current_time = start_time
+    while current_time < end_time:
         time.sleep(0.1)
         current_time = time.time()
@@ -31,7 +47,23 @@ def run_main_process(concurrency, duration, dataset, schedule_q, stop_q):
     # Signal users to stop sending requests
     stop_q.put(None)
 
     return
 
+
+def main_loop_rps_mode(schedule_q, rps, start_time, end_time):
+    interval = 1 / rps
+
+    next_req_time = start_time
+    current_time = start_time
+    while current_time < end_time:
+        if next_req_time <= current_time:
+            logging.info("Scheduling request")
+            schedule_q.put(next_req_time)
+            next_req_time = next_req_time + interval
+        sleep_time = (next_req_time - current_time) - 0.01  # Sleep until 10ms before next_req_time
+        if sleep_time > 0:
+            time.sleep(sleep_time)
+        # else spin until next_req_time <= current_time
+
+        current_time = time.time()
 
 
 def gather_results(results_pipes):
@@ -79,10 +111,10 @@ def main(args):

     # Parse config
     logging.debug("Parsing YAML config file %s", args.config)
-    concurrency, duration, plugin = 0, 0, None
+    rps, concurrency, duration, plugin = None, 0, 0, None
     try:
         config = utils.yaml_load(args.config)
-        concurrency, duration, plugin = utils.parse_config(config)
+        rps, concurrency, duration, plugin = utils.parse_config(config)
     except Exception as e:
         logging.error("Exiting due to invalid input: %s", repr(e))
         exit_gracefully(procs, stop_q, logger_q, log_reader_thread, 1)
@@ -93,7 +125,7 @@ def main(args):
     model_name = config.get("plugin_options", {}).get("model_name", "")
     dataset = Dataset(model_name=model_name, **config["dataset"])
 
-    logging.debug("Creating %s Users and corresponding processes", concurrency)
+    logging.info("Creating %s Users and corresponding processes", concurrency)
     for idx in range(concurrency):
         send_results, recv_results = mp_ctx.Pipe()
         results_pipes.append(recv_results)
@@ -107,14 +139,15 @@ def main(args):
             logger_q=logger_q,
             log_level=args.log_level,
             run_duration=duration,
+            rate_limited=(rps is not None)
         )
         proc = mp_ctx.Process(target=user.run_user_process)
         procs.append(proc)
         logging.info("Starting %s", proc)
         proc.start()
 
     logging.debug("Running main process")
-    run_main_process(concurrency, duration, dataset, schedule_q, stop_q)
+    run_main_process(rps, duration, dataset, schedule_q, stop_q)
 
     results_list = gather_results(results_pipes)

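main_loop_rps_mode above is the interesting part of this file: it computes each deadline as next_req_time + 1/rps (fixed increments, so a late wakeup does not permanently lower the achieved rate), sleeps until roughly 10 ms before the deadline, and busy-waits for the rest because time.sleep() alone can overshoot by several milliseconds. A self-contained sketch of the same pacing pattern, runnable on its own; the names here are illustrative, not part of the repo:

import time

def pace(rps, duration):
    """Open-loop pacer: fire once every 1/rps seconds for `duration` seconds."""
    interval = 1.0 / rps
    start = time.time()
    next_fire = start
    fired = 0
    while time.time() < start + duration:
        now = time.time()
        if next_fire <= now:
            fired += 1                 # the real loop enqueues a schedule time here
            next_fire += interval      # fixed increments: lateness is not compounded
        slack = (next_fire - now) - 0.01
        if slack > 0:
            time.sleep(slack)          # coarse sleep until ~10 ms before the deadline
        # else: spin, re-checking the clock, until the deadline passes
    return fired

print(pace(rps=100, duration=2.0))     # ~200 on an unloaded machine

The sleep-then-spin split trades a little CPU for scheduling accuracy; sleeping all the way to each deadline would let the achieved rate drift below the requested one.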
5 changes: 0 additions & 5 deletions plugins/caikit_client_plugin.py
@@ -79,7 +79,6 @@ def request_grpc(self, query, user_id, test_end_time: float=0):
         result.output_tokens_before_timeout = result.output_tokens
         result.output_text = response
 
-        result.calculate_results()
         return result
 
     def streaming_request_grpc(self, query, user_id, test_end_time: float=0):
@@ -113,8 +112,6 @@ def streaming_request_grpc(self, query, user_id, test_end_time: float=0):
         # TODO: Calculate correct output tokens before test timeout duration for streaming requests
         result.output_tokens_before_timeout = result.output_tokens
 
-        result.calculate_results()
-
         return result
 
     def request_http(self, query, user_id):
@@ -138,7 +135,6 @@ def request_http(self, query, user_id):
         result.output_tokens_before_timeout = result.output_tokens
         result.output_text = response
 
-        result.calculate_results()
         return result
 
     def streaming_request_http(self, query, user_id):
@@ -171,5 +167,4 @@ def streaming_request_http(self, query, user_id):
         # TODO: Calculate correct output tokens before test timeout duration for streaming requests
         result.output_tokens_before_timeout = result.output_tokens
 
-        result.calculate_results()
         return result
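caikit_client_plugin.py and the four plugin files that follow all change in the same way: every result.calculate_results() call is removed from the request methods. Metric derivation now happens exactly once, in User.make_request (see user.py below), after scheduled_start_time has been stamped onto the result, which is necessary because the new client_wait_time metric needs both timestamps. A cut-down, runnable sketch of the new ordering; the Result class here is simplified for illustration, the real one lives in result.py:

import time

class Result:
    """Simplified stand-in for result.Result."""
    def __init__(self):
        self.scheduled_start_time = None
        self.start_time = None
        self.end_time = None
        self.response_time = None
        self.client_wait_time = None

    def calculate_results(self):
        self.response_time = self.end_time - self.start_time
        if self.scheduled_start_time is not None and self.start_time is not None:
            self.client_wait_time = self.start_time - self.scheduled_start_time

def plugin_request():
    """Stand-in for a plugin request_func: it no longer calls calculate_results()."""
    result = Result()
    result.start_time = time.time()
    time.sleep(0.05)                  # pretend to wait on the model server
    result.end_time = time.time()
    return result

req_schedule_time = time.time() - 0.2            # the pacer wanted this request 200 ms ago
result = plugin_request()
result.scheduled_start_time = req_schedule_time  # stamped by the user layer
result.calculate_results()                       # derived metrics computed once, centrally
print(f"wait {result.client_wait_time:.2f}s, response {result.response_time:.2f}s")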
2 changes: 0 additions & 2 deletions plugins/dummy_plugin.py
@@ -35,7 +35,6 @@ def request_http(self, query, user_id, test_end_time: float=0):

         result.end_time = time.time()
 
-        result.calculate_results()
 
         return result
 
@@ -63,5 +62,4 @@ def streaming_request_http(self, query, user_id, test_end_time: float=0):
         # TODO: Calculate correct output tokens before test timeout duration for streaming requests
         result.output_tokens_before_timeout = result.output_tokens
 
-        result.calculate_results()
         return result
1 change: 0 additions & 1 deletion plugins/hf_tgi_plugin.py
@@ -114,5 +114,4 @@ def streaming_request_http(self, query, user_id, test_end_time: float=0):
         # TODO: Calculate correct output tokens before test timeout duration for streaming requests
         result.output_tokens_before_timeout = result.output_tokens
 
-        result.calculate_results()
         return result
2 changes: 0 additions & 2 deletions plugins/openai_plugin.py
@@ -173,7 +173,6 @@ def request_http(self, query: dict, user_id: int, test_end_time: float = 0):

         # For non-streaming requests we are keeping output_tokens_before_timeout and output_tokens same.
         result.output_tokens_before_timeout = result.output_tokens
-        result.calculate_results()
 
         return result
 
@@ -356,5 +355,4 @@ def streaming_request_http(self, query: dict, user_id: int, test_end_time: float
         if expected_output_tokens and result.output_tokens != expected_output_tokens:
             logger.warning(f"Received {result.output_tokens} tokens but expected {expected_output_tokens} tokens")
 
-        result.calculate_results()
         return result
2 changes: 0 additions & 2 deletions plugins/tgis_grpc_plugin.py
@@ -121,7 +121,6 @@ def make_request(self, query: dict, user_id: int, test_end_time: float = 0):
         else:
             result.output_tokens = query["output_tokens"]
 
-        result.calculate_results()
         return result
 
     def make_request_stream(self, query: dict, user_id: int, test_end_time: float):
@@ -199,5 +198,4 @@ def make_request_stream(self, query: dict, user_id: int, test_end_time: float):
             logger.warning("Output token count not found in response, using dataset expected output tokens")
             result.output_tokens = len(tokens)
 
-        result.calculate_results()
         return result
7 changes: 7 additions & 0 deletions result.py
@@ -14,10 +14,12 @@ def __init__(self, user_id, input_id, input_tokens=None):
         self.output_text: Optional[str] = None
         self.output_tokens: Optional[int] = None
         self.output_tokens_before_timeout: Optional[int] = None
+        self.scheduled_start_time: Optional[float] = None
         self.start_time: Optional[float] = None
         self.ack_time: Optional[float] = None
         self.first_token_time: Optional[float] = None
         self.end_time: Optional[float] = None
+        self.client_wait_time: Optional[float] = None
         self.response_time: Optional[float] = None
         self.tt_ack: Optional[float] = None
         self.ttft: Optional[float] = None
@@ -59,3 +61,8 @@ def calculate_results(self):
         self.tpot = (
             self.response_time / self.output_tokens
         )  # Time per output token in ms
+
+        if self.scheduled_start_time is not None and self.start_time is not None:
+            self.client_wait_time = (
+                self.start_time - self.scheduled_start_time
+            )
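client_wait_time captures how long a scheduled request waited for a free user process: the gap between when the pacer wanted the request sent and when a worker actually sent it. A steadily growing value over a run means the worker pool (bounded by concurrency) cannot sustain the requested rate. With illustrative numbers:

scheduled_start_time = 100.00   # pacer enqueued the request at t = 100 s
start_time = 100.25             # a free user process sent it at t = 100.25 s
client_wait_time = start_time - scheduled_start_time
print(client_wait_time)         # 0.25 s spent waiting for a worker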
25 changes: 22 additions & 3 deletions user.py
@@ -19,6 +19,7 @@ def __init__(
         logger_q,
         log_level,
         run_duration,
+        rate_limited,
     ):
         """Initialize object."""
         self.user_id = user_id
@@ -34,14 +35,21 @@ def __init__(
         # Must get reset in user process to use the logger created in _init_user_process_logging
         self.logger = logging.getLogger("user")
         self.run_duration = run_duration
+        self.rate_limited = rate_limited
 
-    def make_request(self, test_end_time=0):
+    def make_request(self, test_end_time=0, req_schedule_time=None):
         """Make a request."""
         query = self.dataset[self.dataset_idx]
         self.dataset_idx = (self.dataset_idx + 1) % len(self.dataset)
 
         self.logger.info("User %s making request", self.user_id)
         result = self.plugin.request_func(query, self.user_id, test_end_time)
 
+        if req_schedule_time:
+            result.scheduled_start_time = req_schedule_time
+
+        result.calculate_results()
+
         return result
 
     def _init_user_process_logging(self):
@@ -60,14 +68,25 @@ def run_user_process(self):
         self._init_user_process_logging()
 
         # Waits for all processes to actually be started
-        while self.schedule_q.empty():
+        while not self.rate_limited and self.schedule_q.empty():
             time.sleep(0.1)
 
         test_end_time = time.time() + self.run_duration
         self.logger.info("User %s starting request loop", self.user_id)
 
         while self.stop_q.empty():
-            result = self.make_request(test_end_time)
+            try:
+                req_schedule_time = self.schedule_q.get(timeout=2)
+                if not self.stop_q.empty():
+                    break
+            except queue.Empty:
+                # if timeout passes, queue.Empty will be thrown
+                # User should check if stop_q has been set, else poll again
+                # self.debug.info("User waiting for a request to be scheduled")
+                continue
+
+            result = self.make_request(test_end_time, req_schedule_time=req_schedule_time)
 
             if result is not None:
                 self.results_list.append(result)
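The user loop changes from closed-loop (send the next request as soon as the previous one returns) to a blocking consumer of schedule_q. The get(timeout=2) is what keeps workers responsive to shutdown: without it, a user could block forever on an empty schedule queue after the main process signals stop. A self-contained sketch of the scheduler/worker handoff, assuming standard multiprocessing.Queue semantics; all names are illustrative:

import multiprocessing
import queue
import time

def worker(schedule_q, stop_q, results_q):
    while stop_q.empty():
        try:
            scheduled = schedule_q.get(timeout=2)  # wake up periodically to re-check stop_q
        except queue.Empty:
            continue
        results_q.put(time.time() - scheduled)     # this request's client wait time

if __name__ == "__main__":
    mp_ctx = multiprocessing.get_context("spawn")
    schedule_q, stop_q, results_q = mp_ctx.Queue(), mp_ctx.Queue(), mp_ctx.Queue()
    procs = [mp_ctx.Process(target=worker, args=(schedule_q, stop_q, results_q)) for _ in range(4)]
    for proc in procs:
        proc.start()
    for _ in range(8):                 # schedule 8 requests at ~16 rps
        schedule_q.put(time.time())
        time.sleep(1 / 16)
    stop_q.put(None)                   # workers see a non-empty stop_q and exit
    for proc in procs:
        proc.join()
    while not results_q.empty():
        print(f"client wait: {results_q.get():.4f}s")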
21 changes: 18 additions & 3 deletions utils.py
@@ -73,8 +73,15 @@ def parse_config(config):
     logging.info("load_options config: %s", config["load_options"])
 
     load_options = config.get("load_options")
-    concurrency = load_options.get("concurrency")
     duration = load_options.get("duration")
+    if load_options.get("type") == "concurrency":
+        concurrency = load_options.get("concurrency")
+        rps = None
+    elif load_options.get("type") == "rps":
+        concurrency = load_options.get("concurrency")
+        rps = load_options.get("rps")
+    else:
+        logging.error("Unknown load_options type %s", load_options.get("type"))
 
     plugin_type = config.get("plugin")
     if plugin_type == "openai_plugin":
@@ -93,7 +100,7 @@ def parse_config(config):
         logging.error("Unknown plugin type %s", plugin_type)
         raise ValueError(f"Unknown plugin type {plugin_type}")
 
-    return concurrency, duration, plugin
+    return rps, concurrency, duration, plugin
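Callers now unpack a 4-tuple, with rps doubling as the mode flag (None means concurrency mode). A quick illustration of the branch logic above, with a hand-built dict standing in for the parsed YAML; plugin construction is elided because it needs plugin_options:

load_options = {"type": "rps", "concurrency": 4, "rps": 16, "duration": 20}

if load_options.get("type") == "concurrency":
    concurrency, rps = load_options.get("concurrency"), None
elif load_options.get("type") == "rps":
    concurrency, rps = load_options.get("concurrency"), load_options.get("rps")

duration = load_options.get("duration")
print(rps, concurrency, duration)   # 16 4 20 -- rps would be None in concurrency mode

Note that the unknown-type branch only logs; unlike the unknown-plugin case it raises nothing itself, so a typo in type surfaces as a NameError on the return statement, which main() catches in its try/except.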


def yaml_load(file):
@@ -118,7 +125,7 @@ def write_output(config, results_list):
         logging.warning("Output path %s does not exist, creating it!", path)
         path.mkdir(parents=True, exist_ok=True)
 
-    concurrency, duration, _ = parse_config(config)
+    rps, concurrency, duration, _ = parse_config(config)
     outfile_name = output_options.get("file").format(
         concurrency=concurrency, duration=duration
     )
@@ -179,6 +186,14 @@ def write_output(config, results_list):
     df_test_duration = df[df["output_tokens"] == df["output_tokens_before_timeout"]]
     req_completed_within_test_duration = len(df_test_duration)
 
+    if rps is not None:
+        rps_scheduled = req_count / duration
+        rps_completed = req_completed_within_test_duration / duration
+        print(f"Actual requests per second scheduled: {rps_scheduled}")
+        print(f"Actual requests per second completed during run: {rps_completed}")
+        average_client_wait_time = df["client_wait_time"].mean()
+        print(f"Avg. client wait time per request: {average_client_wait_time}")
+
     # Time per output token summary
     output_obj = get_summary(df_test_duration, output_obj, "tpot")

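The new summary makes saturation visible at a glance: scheduled RPS tracks the pacer, completed RPS counts only requests that finished before the test cutoff, and a widening gap between the two together with a rising average client_wait_time means the configured rate exceeds what the server (or the worker pool) can sustain. With illustrative numbers:

duration = 20                              # seconds, from load_options
req_count = 320                            # requests recorded over the run (16 rps * 20 s)
req_completed_within_test_duration = 310   # ten responses landed after the cutoff

print(req_count / duration)                            # 16.0 scheduled rps
print(req_completed_within_test_duration / duration)   # 15.5 completed rps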
