diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 90f2e57571b..cf85f7c4687 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -129,35 +129,40 @@ def apply_chat_template(
     def _handle_prompt_truncation(self, prompt: str, **kwargs) -> Tuple[Sequence, bool]:
         """Handle prompt truncation if needed."""
         # Tokenize once without truncation to check if truncation is needed
-        token_ids = self.tokenizer(  # type: ignore
-            prompt,
-            truncation=False,
-            return_tensors="pt",
-        )[
-            "input_ids"
-        ][0].tolist()
+        prompt_token_ids = self.tokenizer(  # type: ignore
+            prompt, truncation=False, return_tensors="pt"
+        )["input_ids"][0].tolist()
 
         # Check if truncation is needed and apply it
         if (
             self.config.enable_prompt_truncation
             and self.config.max_prompt_tokens is not None
-            and len(token_ids) > self.config.max_prompt_tokens
+            and len(prompt_token_ids) > self.config.max_prompt_tokens
         ):
             self.logger.warning(f"Prompt was truncated to {self.config.max_prompt_tokens} tokens")
-            token_ids = token_ids[: self.config.max_prompt_tokens + 1]  # leave one for response
+
+            dummy_response = "[This experience is masked out due to overlong prompt]"
+            dummy_response_tokens = self.tokenizer(  # type: ignore
+                dummy_response, truncation=False, return_tensors="pt"
+            )["input_ids"][0].tolist()
+            dummy_response_tokens = dummy_response_tokens[
+                : min(len(dummy_response_tokens), self.config.max_response_tokens)  # type: ignore
+            ]
+
+            token_ids = prompt_token_ids[: self.config.max_prompt_tokens] + dummy_response_tokens
             return [
                 Experience(
                     tokens=token_ids,
-                    logprobs=torch.zeros(1, dtype=torch.float32),
-                    prompt_length=len(token_ids) - 1,
-                    prompt_text=self.tokenizer.decode(token_ids[:-1]),
-                    response_text=self.tokenizer.decode(token_ids[-1]),
+                    logprobs=torch.zeros(len(dummy_response_tokens), dtype=torch.float32),
+                    prompt_length=len(prompt_token_ids),
+                    prompt_text=self.tokenizer.decode(prompt_token_ids),
+                    response_text=dummy_response,
                     truncate_status="prompt_truncated",
                     reward=0.0,
                 )
                 for _ in range(kwargs.get("n", 1))
-            ], False
-        return token_ids, True
+            ], False  # If prompt truncation is activated, return a list of dummy experiences & False
+        return prompt_token_ids, True  # Otherwise, return prompt_token_ids & True
 
     async def convert_messages_to_experience(
         self,
diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py
index de381d2200a..1629e1e9f95 100644
--- a/trinity/common/models/tinker_model.py
+++ b/trinity/common/models/tinker_model.py
@@ -52,9 +52,10 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         if self.tokenizer is None:
             await self._initialize_tokenizer()
 
-        token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
+        returned_seq, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
         if not is_valid:
-            return token_ids
+            return returned_seq  # is_valid is False: returned_seq is a list of dummy experiences
+        token_ids = returned_seq  # is_valid is True: returned_seq is prompt's token_ids
 
         with_chat_completion = kwargs.get("with_chat_completion", False)
         if with_chat_completion:
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 09968b7971e..f3537638d63 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -219,10 +219,14 @@ async def generate(
             if self.tokenizer is None:
                 await self._initialize_tokenizer()
 
-            token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
+            returned_seq, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
             if not is_valid:
-                return token_ids
-            prompt = {"prompt_token_ids": token_ids}
+                return (
+                    returned_seq  # is_valid is False: returned_seq is a list of dummy experiences
+                )
+            prompt = {
+                "prompt_token_ids": returned_seq
+            }  # is_valid is True: returned_seq is token_ids
             multi_modal_inputs = None
         else:  # multi modal
             multi_modal_inputs = build_mm_input_for_training(self.processor, **prompt)
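
For clarity, here is a minimal, self-contained sketch of the two-mode return contract that the callers above rely on. It is illustrative only: `handle_prompt_truncation`, the trimmed-down `Experience` dataclass, and the literal token lists are stand-ins rather than the actual Trinity classes, and the tokenizer step is omitted.

```python
# Minimal sketch (not the actual Trinity API) of the (sequence, is_valid)
# contract introduced by this change: either the prompt's token ids with
# is_valid=True, or a list of dummy experiences with is_valid=False.
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class Experience:
    tokens: List[int]
    response_text: str
    truncate_status: str = "ok"
    reward: float = 0.0


def handle_prompt_truncation(
    prompt_token_ids: List[int], max_prompt_tokens: int, n: int = 1
) -> Tuple[Union[List[int], List[Experience]], bool]:
    """Return (prompt token ids, True), or (dummy experiences, False) on overflow."""
    if len(prompt_token_ids) <= max_prompt_tokens:
        return prompt_token_ids, True
    dummy_response = "[This experience is masked out due to overlong prompt]"
    return [
        Experience(
            tokens=prompt_token_ids[:max_prompt_tokens],
            response_text=dummy_response,
            truncate_status="prompt_truncated",
            reward=0.0,
        )
        for _ in range(n)
    ], False


# Caller-side pattern shared by tinker_model and vllm_model:
returned_seq, is_valid = handle_prompt_truncation(
    list(range(2048)), max_prompt_tokens=1024, n=2
)
if not is_valid:
    experiences = returned_seq  # list of dummy experiences, returned to the caller as-is
else:
    token_ids = returned_seq  # plain prompt token ids, passed on to generation
```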