51 changes: 37 additions & 14 deletions inference_perf/datagen/shared_prefix_datagen.py
@@ -86,44 +86,67 @@ def get_data(self) -> Generator[InferenceAPIData, None, None]:
             yield LazyLoadInferenceAPIData(data_index=i, prefered_worker_id=prefered_worker_id)
             i += 1

-    def _generate_random_token_ids(self, length: int) -> List[int]:
+    def _generate_random_token_ids(self, length: int, prefix_token_ids: List[int]) -> List[int]:
         """Generates a list of random token IDs of a specified length."""
         if length == 0:
             return []
+        hf_tokenizer = self.tokenizer.get_tokenizer()
+        prefix_prompt_len = self.tokenizer.count_tokens(hf_tokenizer.decode(prefix_token_ids, skip_special_tokens=True))
+
+        if prefix_prompt_len > length:
+            raise ValueError(f"Prefix length ({prefix_prompt_len}) exceeds requested length ({length}).")
+
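+        # Over-generate by a small slack (5 tokens here): decoding random IDs and
+        # re-tokenizing the text can merge or split tokens, so the loop below trims
+        # back down to the target length.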
+        random_part_size = length - prefix_prompt_len + 5
+
         # np.random.randint's high parameter is exclusive
-        return np.random.randint(0, self.vocab_size, size=length, dtype=np.int64).tolist()  # type: ignore[no-any-return]
+        token_ids = prefix_token_ids + np.random.randint(0, self.vocab_size, size=random_part_size, dtype=np.int64).tolist()
+        prompt_text = hf_tokenizer.decode(token_ids, skip_special_tokens=True)
+
+        while length < self.tokenizer.count_tokens(prompt_text):
+            token_ids.pop()
+            prompt_text = hf_tokenizer.decode(token_ids, skip_special_tokens=True)
+
+        # If trimming removed too many tokens, regenerate the random tail and retry once.
+        if length > self.tokenizer.count_tokens(prompt_text):
+            token_ids = prefix_token_ids + np.random.randint(0, self.vocab_size, size=random_part_size, dtype=np.int64).tolist()
+            prompt_text = hf_tokenizer.decode(token_ids, skip_special_tokens=True)
+
+        return token_ids

     def _generate_prompts(self) -> None:
         """Pre-generates all prompts based on the configuration."""
         if self.tokenizer is None:
             # This check is defensive; __init__ should have already validated this.
             raise ValueError("Tokenizer is not available for generating prompts.")

         hf_tokenizer = self.tokenizer.get_tokenizer()

         for group_id in range(self.num_groups):
             # Generate a shared prefix (system prompt)
-            shared_prefix_token_ids = self._generate_random_token_ids(self.system_prompt_len)
+            shared_prefix_token_ids = self._generate_random_token_ids(self.system_prompt_len, prefix_token_ids=[])

             shared_prefix_text = hf_tokenizer.decode(shared_prefix_token_ids, skip_special_tokens=True)

             for prompt_id in range(self.num_prompts_per_group):
-                # Generate a unique question
-                question_token_ids = self._generate_random_token_ids(self.question_len)
-                question_text = hf_tokenizer.decode(question_token_ids, skip_special_tokens=True)
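+                # Generate the shared prefix and the question as one sequence so the
+                # trimming in _generate_random_token_ids targets the *total* decoded length.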
+                total_target_length = self.system_prompt_len + self.question_len
+
+                full_token_ids = self._generate_random_token_ids(
+                    length=total_target_length,
+                    prefix_token_ids=shared_prefix_token_ids,
+                )

                 if self.enable_multi_turn_chat:
-                    # multi turn chat, create user to keep conversation
+                    # Multi-turn chat: create a user session that carries the shared prefix as context.
+                    question_token_ids = full_token_ids[len(shared_prefix_token_ids):]
+                    question_text = hf_tokenizer.decode(question_token_ids, skip_special_tokens=True)

                     self.user_sessions.append(
                         LocalUserSession(
                             user_session_id=f"user_session_{self.num_prompts_per_group * group_id + prompt_id}",
                             context=shared_prefix_text,
                         )
                     )
                     self.prompts.append(question_text)
                 else:
-                    # Single turn chat, Combine shared prefix and question
-                    question_text = shared_prefix_text + " " + question_text
-
-                    self.prompts.append(question_text)
+                    # Single-turn chat: the decoded full sequence already combines prefix and question.
+                    full_prompt_text = hf_tokenizer.decode(full_token_ids, skip_special_tokens=True)
+                    self.prompts.append(full_prompt_text)

         # Shuffle the generated prompts to ensure randomness if served sequentially by different workers
         random.shuffle(self.prompts)
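The subtlety motivating this change is that random token IDs do not round-trip through decode/encode: re-tokenizing the decoded text can produce a different token count, which is why `_generate_random_token_ids` over-generates and then trims. A minimal sketch of the effect, not part of the PR, assuming any Hugging Face tokenizer (the `gpt2` checkpoint is only an illustrative choice):

```python
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocab_size = tokenizer.vocab_size

target_len = 32
ids = np.random.randint(0, vocab_size, size=target_len, dtype=np.int64).tolist()

# Decode the random IDs, then re-tokenize the resulting text.
text = tokenizer.decode(ids, skip_special_tokens=True)
recount = len(tokenizer.encode(text, add_special_tokens=False))

# Adjacent random tokens can merge (or split) on re-tokenization,
# so recount frequently differs from target_len.
print(target_len, recount)
```

In the patch, `self.tokenizer.count_tokens` plays the role of the `encode` call above, and the while loop pops trailing IDs until the decoded prompt re-tokenizes to at most `length` tokens.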