35 changes: 20 additions & 15 deletions trinity/common/models/model.py
@@ -129,35 +129,40 @@ def apply_chat_template(
     def _handle_prompt_truncation(self, prompt: str, **kwargs) -> Tuple[Sequence, bool]:
         """Handle prompt truncation if needed."""
         # Tokenize once without truncation to check if truncation is needed
-        token_ids = self.tokenizer(  # type: ignore
-            prompt,
-            truncation=False,
-            return_tensors="pt",
-        )[
-            "input_ids"
-        ][0].tolist()
+        prompt_token_ids = self.tokenizer(  # type: ignore
+            prompt, truncation=False, return_tensors="pt"
+        )["input_ids"][0].tolist()
 
         # Check if truncation is needed and apply it
         if (
             self.config.enable_prompt_truncation
             and self.config.max_prompt_tokens is not None
-            and len(token_ids) > self.config.max_prompt_tokens
+            and len(prompt_token_ids) > self.config.max_prompt_tokens
         ):
             self.logger.warning(f"Prompt was truncated to {self.config.max_prompt_tokens} tokens")
-            token_ids = token_ids[: self.config.max_prompt_tokens + 1]  # leave one for response
+
+            dummy_response = "[This experience is masked out due to overlong prompt]"
+            dummy_response_tokens = self.tokenizer(  # type: ignore
+                dummy_response, truncation=False, return_tensors="pt"
+            )["input_ids"][0].tolist()
+            dummy_response_tokens = dummy_response_tokens[
+                : min(len(dummy_response_tokens), self.config.max_response_tokens)  # type: ignore
+            ]
+
+            token_ids = prompt_token_ids[: self.config.max_prompt_tokens] + dummy_response_tokens
             return [
                 Experience(
                     tokens=token_ids,
-                    logprobs=torch.zeros(1, dtype=torch.float32),
-                    prompt_length=len(token_ids) - 1,
-                    prompt_text=self.tokenizer.decode(token_ids[:-1]),
-                    response_text=self.tokenizer.decode(token_ids[-1]),
+                    logprobs=torch.zeros(len(dummy_response_tokens), dtype=torch.float32),
+                    prompt_length=len(prompt_token_ids),
+                    prompt_text=self.tokenizer.decode(prompt_token_ids),
+                    response_text=dummy_response,
                     truncate_status="prompt_truncated",
                     reward=0.0,
                 )
                 for _ in range(kwargs.get("n", 1))
-            ], False
-        return token_ids, True
+            ], False  # If prompt truncation is activated, return a list of dummy experiences & False
+        return prompt_token_ids, True  # Otherwise, return prompt_token_ids & True
 
     async def convert_messages_to_experience(
         self,
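With this change, an overlong prompt is no longer hard-truncated and handed back for generation; it short-circuits into zero-reward dummy experiences whose response tokens can be masked out downstream. A minimal standalone sketch of the same `(sequence, is_valid)` contract, assuming the `transformers` package; the limits, the `gpt2` tokenizer, and the simplified return shape are illustrative stand-ins, not code from this PR:

```python
# Sketch of the (sequence, is_valid) contract introduced by the patch.
from transformers import AutoTokenizer

MAX_PROMPT_TOKENS = 8     # stand-in for config.max_prompt_tokens
MAX_RESPONSE_TOKENS = 16  # stand-in for config.max_response_tokens
DUMMY_RESPONSE = "[This experience is masked out due to overlong prompt]"

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer

def handle_prompt_truncation(prompt: str):
    """Return (prompt token ids, True) when valid, (dummy tokens, False) when overlong."""
    prompt_token_ids = tokenizer(prompt, truncation=False)["input_ids"]
    if len(prompt_token_ids) <= MAX_PROMPT_TOKENS:
        return prompt_token_ids, True  # safe to pass to the engine
    # Overlong: truncated prompt + capped dummy-response tokens, flagged invalid.
    dummy_tokens = tokenizer(DUMMY_RESPONSE)["input_ids"][:MAX_RESPONSE_TOKENS]
    return prompt_token_ids[:MAX_PROMPT_TOKENS] + dummy_tokens, False

tokens, is_valid = handle_prompt_truncation("lorem ipsum " * 50)
assert not is_valid  # the caller returns dummy experiences instead of generating
```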
5 changes: 3 additions & 2 deletions trinity/common/models/tinker_model.py
@@ -52,9 +52,10 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         if self.tokenizer is None:
             await self._initialize_tokenizer()
 
-        token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
+        returned_seq, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
         if not is_valid:
-            return token_ids
+            return returned_seq  # is_valid is False: returned_seq is a list of dummy experiences
+        token_ids = returned_seq  # is_valid is True: returned_seq is prompt's token_ids
 
         with_chat_completion = kwargs.get("with_chat_completion", False)
         if with_chat_completion:
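The rename from `token_ids` to `returned_seq` makes the union-typed return explicit at the call site before narrowing it back to token ids. A minimal sketch of the narrowing idiom, with a stub `Experience` class and an arbitrary limit standing in for the real types and config:

```python
# Sketch of the caller-side narrowing pattern (names are illustrative).
from typing import List, Sequence, Tuple, Union

class Experience:  # stand-in for trinity's Experience
    pass

def handle(ids: List[int], limit: int = 4) -> Tuple[Union[Sequence[Experience], List[int]], bool]:
    if len(ids) > limit:
        return [Experience()], False  # dummy experiences: do not generate
    return ids, True

def generate(ids: List[int]) -> Union[Sequence[Experience], List[int]]:
    returned_seq, is_valid = handle(ids)
    if not is_valid:
        return returned_seq       # Sequence[Experience]
    token_ids = returned_seq      # List[int]: the prompt's token ids
    return token_ids              # ...real code would continue generating

print(generate([1, 2, 3]))        # [1, 2, 3]
print(generate([1, 2, 3, 4, 5]))  # a one-element list of Experience stubs
```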
10 changes: 7 additions & 3 deletions trinity/common/models/vllm_model.py
@@ -219,10 +219,14 @@ async def generate(
             if self.tokenizer is None:
                 await self._initialize_tokenizer()
 
-            token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
+            returned_seq, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
             if not is_valid:
-                return token_ids
-            prompt = {"prompt_token_ids": token_ids}
+                return (
+                    returned_seq  # is_valid is False: returned_seq is a list of dummy experiences
+                )
+            prompt = {
+                "prompt_token_ids": returned_seq
+            }  # is_valid is True: returned_seq is token_ids
             multi_modal_inputs = None
         else:  # multi modal
             multi_modal_inputs = build_mm_input_for_training(self.processor, **prompt)
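On the vLLM side, wrapping the ids as `{"prompt_token_ids": ...}` matches vLLM's pre-tokenized `TokensPrompt` input form, so the engine skips re-tokenization. A minimal sketch, assuming a local vLLM install; the model name is a stand-in:

```python
# Sketch: passing pre-tokenized input to vLLM as a TokensPrompt-style dict.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # stand-in model
token_ids = llm.get_tokenizer().encode("Hello, world!")

outputs = llm.generate(
    {"prompt_token_ids": token_ids},  # same dict shape the patched code builds
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)
```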