From 6458871af74853bfd9184057055d970b28125f63 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Thu, 30 Oct 2025 14:20:21 +0900 Subject: [PATCH 1/8] refactor MultimodalInputProcessor Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- .../_torch/models/modeling_gemma3vl.py | 50 ++- .../_torch/models/modeling_hyperclovax.py | 108 +++--- tensorrt_llm/_torch/models/modeling_llama.py | 59 +++- .../_torch/models/modeling_llava_next.py | 53 ++- .../_torch/models/modeling_mistral.py | 49 ++- .../_torch/models/modeling_nanov2vlm.py | 78 +++-- tensorrt_llm/_torch/models/modeling_phi4mm.py | 52 ++- .../_torch/models/modeling_qwen2vl.py | 74 ++-- tensorrt_llm/_torch/models/modeling_vila.py | 55 ++- tensorrt_llm/inputs/__init__.py | 8 +- tensorrt_llm/inputs/registry.py | 318 +++++++++++------- 11 files changed, 581 insertions(+), 323 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py index 15e93ad0977..2fb32303e03 100644 --- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py +++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py @@ -10,7 +10,8 @@ BaseWeightMapper from ..._utils import nvtx_range -from ...inputs import (ExtraProcessedInputs, InputProcessor, +from ...inputs import (BaseMultimodalDummyInputsBuilder, + BaseMultimodalInputProcessor, ExtraProcessedInputs, MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) @@ -33,15 +34,46 @@ def _is_disagg() -> bool: return os.getenv(_MULTIMODAL_ENV_NAME, "0") == "1" -class Gemma3InputProcessor(InputProcessor): +from transformers import AutoTokenizer, PretrainedConfig - def __init__(self, model_path, model_config, tokenizer, trust_remote_code): - self.tokenizer = tokenizer - self.processor = AutoProcessor.from_pretrained( - model_path, trust_remote_code=trust_remote_code, use_fast=True) - self.model_config = model_config - self.device = 'cuda' +class Gemma3InputProcessor(BaseMultimodalInputProcessor, + BaseMultimodalDummyInputsBuilder): + + def __init__(self, + model_path: str, + config: PretrainedConfig, + tokenizer: AutoTokenizer, + trust_remote_code: bool = True): + super().__init__() + self._config = config + self._tokenizer = tokenizer + self._model_path = model_path + self._processor = AutoProcessor.from_pretrained( + model_path, + trust_remote_code=trust_remote_code, + use_fast=self.use_fast) + self._dtype = self.config.torch_dtype + + @property + def config(self) -> PretrainedConfig: + return self._config + + @property + def tokenizer(self) -> AutoTokenizer: + return self._tokenizer + + @property + def model_path(self) -> str: + return self._model_path + + @property + def processor(self) -> AutoProcessor: + return self._processor + + @property + def dtype(self) -> torch.dtype: + return self._dtype @nvtx_range("[Vision] preprocess") def _preprocess(self, inputs): @@ -59,7 +91,7 @@ def _preprocess(self, inputs): images=images, do_rescale=do_rescale, return_tensors="pt", - device=self.device).to(dtype=torch.bfloat16) + ).to(dtype=self.dtype) input_ids = processor_output["input_ids"] pixel_values = processor_output.get("pixel_values") diff --git a/tensorrt_llm/_torch/models/modeling_hyperclovax.py b/tensorrt_llm/_torch/models/modeling_hyperclovax.py index 692d24750dc..9dcf039d175 100644 --- a/tensorrt_llm/_torch/models/modeling_hyperclovax.py +++ b/tensorrt_llm/_torch/models/modeling_hyperclovax.py @@ -15,8 +15,9 @@ from tensorrt_llm.inputs.multimodal import 
MultimodalParams -from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs, - InputProcessor, MultimodalPlaceholderMetadata, +from ...inputs import (BaseMultimodalDummyInputsBuilder, + BaseMultimodalInputProcessor, ExtraProcessedInputs, + MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) from ...logger import logger @@ -564,33 +565,54 @@ def build_mlp( return nn.Sequential(*layers) -class HCXVisionInputProcessor(BaseMultimodalInputProcessor, InputProcessor): +class HCXVisionInputProcessor(BaseMultimodalDummyInputsBuilder, + BaseMultimodalInputProcessor): def __init__(self, model_path: str, - model_config: PretrainedConfig, + config: PretrainedConfig, tokenizer: AutoTokenizer, trust_remote_code: bool = True): - - self.pretrained_config = model_config - self.tokenizer = tokenizer - self.use_fast = True - if self.tokenizer is None: - self.tokenizer = AutoTokenizer.from_pretrained( - model_path, - trust_remote_code=trust_remote_code, - use_fast=self.use_fast) - self.processor = AutoProcessor.from_pretrained( + super().__init__() + self._config = config + self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=trust_remote_code, + use_fast=self.use_fast) + self._processor = AutoProcessor.from_pretrained( model_path, trust_remote_code=trust_remote_code, use_fast=self.use_fast) - self.tllm_multimodal_token_id = self.pretrained_config.language_config[ + self._model_path = model_path + self._dtype = self.config.torch_dtype + + self.tllm_multimodal_token_id = self.config.language_config[ "vocab_size"] + 1 self.vision_query_lengths = None self._vision_query_generator = None + @property + def config(self) -> PretrainedConfig: + return self._config + + @property + def tokenizer(self) -> AutoTokenizer: + return self._tokenizer + + @property + def model_path(self) -> str: + return self._model_path + + @property + def processor(self) -> AutoProcessor: + return self._processor + + @property + def dtype(self) -> torch.dtype: + return self._dtype + def get_vocab_size(self): - return self.pretrained_config.language_config["vocab_size"] + return self.config.language_config["vocab_size"] def get_num_tokens_per_image( self, @@ -656,8 +678,7 @@ def _post_process(self, vision_query_lengths = preprocessed_image.get("vision_query_lengths", None) non_vision_query_lengths = determine_non_vision_query_lengths( - input_ids, self.tokenizer.pad_token_id, - self.pretrained_config.img_start_id) + input_ids, self.tokenizer.pad_token_id, self.config.img_start_id) batch_size = input_ids.size(0) len_inputs_embeds = max([ @@ -666,11 +687,10 @@ def _post_process(self, non_vision_query_lengths, vision_query_lengths) ]) - len_inputs_embeds = min(self.pretrained_config.decoder_max_length, + len_inputs_embeds = min(self.config.decoder_max_length, len_inputs_embeds) - image_cnts = (input_ids == self.pretrained_config.img_start_id).sum( - dim=1).tolist() + image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist() fused_input_ids = torch.zeros([batch_size, len_inputs_embeds], dtype=input_ids.dtype) @@ -678,7 +698,7 @@ def _post_process(self, non_vision_query_length = non_vision_query_lengths[batch_idx] sample = sample[:non_vision_query_length + image_cnts[batch_idx]] - mask = (sample == self.pretrained_config.img_start_id) + mask = (sample == self.config.img_start_id) img_start_ids = mask.nonzero() input_start, temp_start = 0, 0 @@ -779,7 +799,7 @@ class HCXVisionModel(nn.Module): def 
__init__(self, model_config: ModelConfig[PretrainedConfig]): super().__init__() self.model_config = model_config - self.pretrained_config = model_config.pretrained_config + self.config = model_config.pretrained_config siglip_model_config = copy.deepcopy(self.model_config) siglip_model_config.pretrained_config = self.model_config.pretrained_config.vision_config self.visual_token_idx = 0 if "siglip" in self.model_config.pretrained_config.vision_config.model_type else 1 @@ -787,24 +807,22 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): self.vision_model = SiglipVisionModel(siglip_model_config).to( self.dtype) self.mm_projector = HCXVisionCAbstractor( - num_queries=self.pretrained_config.num_queries_vis_abstractor, - num_input_tokens=( - self.pretrained_config.vision_config.image_size // - self.pretrained_config.vision_config.patch_size)**2, - encoder_hidden_size=self.pretrained_config.vision_config. - hidden_size, - hidden_size=self.pretrained_config.vision_config.hidden_size, - output_hidden_size=self.pretrained_config.hidden_size, - pos_emb=self.pretrained_config.proj_pos_emb, - prenorm=self.pretrained_config.proj_prenorm, + num_queries=self.config.num_queries_vis_abstractor, + num_input_tokens=(self.config.vision_config.image_size // + self.config.vision_config.patch_size)**2, + encoder_hidden_size=self.config.vision_config.hidden_size, + hidden_size=self.config.vision_config.hidden_size, + output_hidden_size=self.config.hidden_size, + pos_emb=self.config.proj_pos_emb, + prenorm=self.config.proj_prenorm, ).to(self.dtype) self.image_newline = nn.Parameter(torch.empty( - self.pretrained_config.hidden_size, ), + self.config.hidden_size, ), requires_grad=False).to(self.dtype) - self.unpad = self.pretrained_config.unpad - self.use_nth_layer = self.pretrained_config.use_nth_layer - self.anyres = self.pretrained_config.anyres + self.unpad = self.config.unpad + self.use_nth_layer = self.config.use_nth_layer + self.anyres = self.config.anyres self.possible_resolutions = self._init_possible_resolutions() self.post_config() @@ -814,18 +832,18 @@ def post_config(self): def _init_possible_resolutions(self): possible_resolutions = [] - if self.pretrained_config.anyres: - assert self.pretrained_config.max_num_grids > 0 - for i in range(1, self.pretrained_config.max_num_grids + 1): - for j in range(1, self.pretrained_config.max_num_grids + 1): - if i == 1 and j == 1 and not self.pretrained_config.use_1x1_grid: + if self.config.anyres: + assert self.config.max_num_grids > 0 + for i in range(1, self.config.max_num_grids + 1): + for j in range(1, self.config.max_num_grids + 1): + if i == 1 and j == 1 and not self.config.use_1x1_grid: continue - if i * j <= self.pretrained_config.max_num_grids: + if i * j <= self.config.max_num_grids: possible_resolutions.append([i, j]) possible_resolutions = [[ - ys * self.pretrained_config.vision_config.image_size, - xs * self.pretrained_config.vision_config.image_size + ys * self.config.vision_config.image_size, + xs * self.config.vision_config.image_size ] for ys, xs in possible_resolutions] return possible_resolutions diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 19f70543de5..42eea7f9f6f 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -21,7 +21,8 @@ from tensorrt_llm.lora_manager import HfLoraLoader from tensorrt_llm.models.convert_utils import split_matrix_tp -from ...inputs import (ExtraProcessedInputs, InputProcessor, 
+from ...inputs import (BaseMultimodalDummyInputsBuilder, + BaseMultimodalInputProcessor, ExtraProcessedInputs, MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) @@ -1042,26 +1043,54 @@ def forward(self, multimodal_params: List[MultimodalParams]): return [image_features] -class Llama4InputProcessor(InputProcessor): +from transformers import AutoTokenizer, PretrainedConfig + + +class Llama4InputProcessor(BaseMultimodalInputProcessor, + BaseMultimodalDummyInputsBuilder): def __init__(self, - model_path, - model_config, - tokenizer, + model_path: str, + config: PretrainedConfig, + tokenizer: AutoTokenizer, trust_remote_code: bool = True): - self.use_fast = True - self.processor = AutoProcessor.from_pretrained( + super().__init__() + self._config = config + self._dtype = self._config.torch_dtype + self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained( + model_path) + self._model_path = model_path + self._processor = AutoProcessor.from_pretrained( model_path, - trust_remote_code=trust_remote_code, - use_fast=self.use_fast) - self.model_config = model_config - self.tokenizer = tokenizer - self.vocab_size = model_config.text_config.vocab_size - self.image_token_index = model_config.image_token_index + use_fast=self.use_fast, + trust_remote_code=trust_remote_code) + + self.vocab_size = self.config.text_config.vocab_size + self.image_token_index = self.config.image_token_index self.fake_image_token = self.processor.fake_image_token self.image_token = self.processor.img_patch_token - self.image_token_start_index = self.model_config.boi_token_index - self.image_token_end_index = self.model_config.eoi_token_index + self.image_token_start_index = self.config.boi_token_index + self.image_token_end_index = self.config.eoi_token_index + + @property + def config(self) -> PretrainedConfig: + return self._config + + @property + def tokenizer(self) -> AutoTokenizer: + return self._tokenizer + + @property + def model_path(self) -> str: + return self._model_path + + @property + def processor(self) -> AutoProcessor: + return self._processor + + @property + def dtype(self) -> torch.dtype: + return self._dtype def attach_multimodal_embeddings( self, inputs: TextPrompt, multimodal_embedding: Dict[str, diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index d0e7d2adfc5..aa9369ed046 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -14,8 +14,9 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams -from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs, - InputProcessor, MultimodalPlaceholderMetadata, +from ...inputs import (BaseMultimodalDummyInputsBuilder, + BaseMultimodalInputProcessor, ExtraProcessedInputs, + MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, register_input_processor, support_multimodal_disaggregated) @@ -34,29 +35,49 @@ DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1' -class LlavaNextInputProcessor(BaseMultimodalInputProcessor, InputProcessor): +class LlavaNextInputProcessor(BaseMultimodalInputProcessor, + BaseMultimodalDummyInputsBuilder): def __init__(self, model_path: str, - model_config: PretrainedConfig, + config: PretrainedConfig, tokenizer: AutoTokenizer, trust_remote_code: bool = True): - self.tokenizer = tokenizer - self.use_fast = True - if self.tokenizer is None: - self.tokenizer = 
AutoTokenizer.from_pretrained( - model_path, - trust_remote_code=trust_remote_code, - use_fast=self.use_fast) - self.processor = AutoProcessor.from_pretrained( + super().__init__() + self._config = config + self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained( model_path, trust_remote_code=trust_remote_code, use_fast=self.use_fast) - self.model_config = model_config + self._processor = AutoProcessor.from_pretrained( + model_path, + trust_remote_code=trust_remote_code, + use_fast=self.use_fast) + self._model_path = model_path + self._dtype = self.config.text_config.torch_dtype + + self.image_token_index = config.image_token_index + self.vocab_size = config.vocab_size + + @property + def config(self) -> PretrainedConfig: + return self._config + + @property + def tokenizer(self) -> AutoTokenizer: + return self._tokenizer + + @property + def model_path(self) -> str: + return self._model_path + + @property + def processor(self) -> AutoProcessor: + return self._processor - self.image_token_index = model_config.image_token_index - self.vocab_size = model_config.vocab_size - self.config = model_config.vision_config + @property + def dtype(self) -> torch.dtype: + return self._dtype def _postprocess( self, input_ids: torch.Tensor, mm_features: Union[torch.Tensor, diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 86abe856c41..9c4cf6b54c4 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -29,8 +29,9 @@ from tensorrt_llm._torch.speculative import SpecMetadata from tensorrt_llm._utils import nvtx_range from tensorrt_llm.functional import PositionEmbeddingType -from tensorrt_llm.inputs import (BaseMultimodalInputProcessor, - ExtraProcessedInputs, InputProcessor, +from tensorrt_llm.inputs import (BaseMultimodalDummyInputsBuilder, + BaseMultimodalInputProcessor, + ExtraProcessedInputs, MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) @@ -214,26 +215,46 @@ def __init__( ) -class Mistral3InputProcessor(BaseMultimodalInputProcessor, InputProcessor): +class Mistral3InputProcessor(BaseMultimodalInputProcessor, + BaseMultimodalDummyInputsBuilder): def __init__( self, model_path: str, - model_config: PretrainedConfig, + config: PretrainedConfig, tokenizer: Optional[AutoTokenizer], trust_remote_code: bool = False, ): - if tokenizer is None: - tokenizer = AutoTokenizer.from_pretrained(model_path, - use_fast=False) + super().__init__() + self._config = config + self._dtype = self._config.torch_dtype + self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained( + model_path) + self._model_path = model_path + self._processor = AutoProcessor.from_pretrained( + model_path, + use_fast=self.use_fast, + trust_remote_code=trust_remote_code) - # To abide by the `InputProcessor` interface. 
-        self.model_path = model_path
-        self.model_config = model_config
-        self.tokenizer = tokenizer
+    @property
+    def config(self) -> PretrainedConfig:
+        return self._config
+
+    @property
+    def tokenizer(self) -> AutoTokenizer:
+        return self._tokenizer
 
-        self._processor = AutoProcessor.from_pretrained(model_path,
-                                                         use_fast=False)
+    @property
+    def model_path(self) -> str:
+        return self._model_path
+
+    @property
+    def processor(self) -> AutoProcessor:
+        return self._processor
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self._dtype
 
     @torch.inference_mode()
     def __call__(
@@ -282,7 +303,7 @@ def get_vocab_size(self) -> int:
         """Return the vocab size of the model."""
         # Unlike some other VLMs, mistral3's vocab size is stored in its `text_config`, not the top-level
         # config.
-        return self.model_config.text_config.vocab_size
+        return self.config.text_config.vocab_size
 
     def get_mm_token_ids(self) -> torch.Tensor:
         """Get the IDs of all multimodal tokens (placeholders and special tokens alike)."""
diff --git a/tensorrt_llm/_torch/models/modeling_nanov2vlm.py b/tensorrt_llm/_torch/models/modeling_nanov2vlm.py
index e84660ccf81..5e8f2ecd078 100644
--- a/tensorrt_llm/_torch/models/modeling_nanov2vlm.py
+++ b/tensorrt_llm/_torch/models/modeling_nanov2vlm.py
@@ -11,8 +11,9 @@
 from tensorrt_llm._torch.models.checkpoints import NemotronHHfWeightMapper
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
-from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
-                       InputProcessor, MultimodalPlaceholderMetadata,
+from ...inputs import (BaseMultimodalDummyInputsBuilder,
+                       BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        compute_retained_tokens_count, compute_retention_mask,
                        register_input_processor)
@@ -254,49 +255,70 @@ def forward(
 
         return mm_embedding, num_tokens_in_videos
 
-class NanoV2VLInputProcessor(BaseMultimodalInputProcessor, InputProcessor):
+class NanoV2VLInputProcessor(BaseMultimodalInputProcessor,
+                             BaseMultimodalDummyInputsBuilder):
 
     def __init__(self,
                  model_path: str,
-                 model_config: transformers.PretrainedConfig,
+                 config: transformers.PretrainedConfig,
                  tokenizer: transformers.AutoTokenizer,
                  trust_remote_code: bool = True):
+        super().__init__()
         if not trust_remote_code:
             raise ValueError("trust_remote_code must be True for NanoV2VL")
+
+        self._config = config
+        self._tokenizer = tokenizer if tokenizer is not None else transformers.AutoTokenizer.from_pretrained(
+            model_path,
+            trust_remote_code=trust_remote_code,
+            use_fast=self.use_fast)
+        self._processor = transformers.AutoProcessor.from_pretrained(
+            model_path,
+            trust_remote_code=trust_remote_code,
+            use_fast=self.use_fast)
+        self._model_path = model_path
+        self._dtype = self.config.torch_dtype
+        self.device = 'cpu'
 
-        self.model_config = model_config
-        self.image_size = model_config.force_image_size
-        self.patch_size = model_config.patch_size
-        self.downsample_ratio = model_config.downsample_ratio
+        self.image_size = self.config.force_image_size
+        self.patch_size = self.config.patch_size
+        self.downsample_ratio = self.config.downsample_ratio
         self.spatial_merge_size = int(self.patch_size / self.downsample_ratio)
-        self.img_context_token_id = model_config.img_context_token_id
+        self.img_context_token_id = self.config.img_context_token_id
         self.num_image_token = int((self.image_size // self.patch_size)**2 *
                                    (self.downsample_ratio**2))
         self.video_pruning_ratio = 
VIDEO_PRUNING_RATIO - - self.device = 'cpu' - - self.tokenizer = tokenizer - self.use_fast = True - if self.tokenizer is None: - self.tokenizer = transformers.AutoTokenizer.from_pretrained( - model_path, trust_remote_code=True, use_fast=self.use_fast) - - self.processor = transformers.AutoImageProcessor.from_pretrained( - model_path, trust_remote_code=True, use_fast=self.use_fast) - - self.img_context_token = model_config.img_context_token - self.video_context_token = model_config.video_context_token - self.img_start_token = model_config.img_start_token - self.img_end_token = model_config.img_end_token - self.dtype = model_config.torch_dtype + self.img_context_token = self.config.img_context_token + self.video_context_token = self.config.video_context_token + self.img_start_token = self.config.img_start_token + self.img_end_token = self.config.img_end_token self.image_start_token_id = self.tokenizer.encode( self.img_start_token, add_special_tokens=False)[0] self.image_end_token_id = self.tokenizer.encode( self.img_end_token, add_special_tokens=False)[0] + @property + def config(self) -> transformers.PretrainedConfig: + return self._config + + @property + def tokenizer(self) -> transformers.AutoTokenizer: + return self._tokenizer + + @property + def model_path(self) -> str: + return self._model_path + + @property + def processor(self) -> transformers.AutoProcessor: + return self._processor + + @property + def dtype(self) -> torch.dtype: + return self._dtype + def get_vocab_size(self): - return self.model_config.llm_config.vocab_size + return self.config.llm_config.vocab_size def get_mm_special_token_ids(self) -> torch.Tensor: " Return multimodal special token ids for NanoV2VL. " diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py index f80d09da078..f04e6280b64 100644 --- a/tensorrt_llm/_torch/models/modeling_phi4mm.py +++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py @@ -32,8 +32,9 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams from ...executor.request import LoRARequest -from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs, - InputProcessor, MultimodalPlaceholderMetadata, +from ...inputs import (BaseMultimodalDummyInputsBuilder, + BaseMultimodalInputProcessor, ExtraProcessedInputs, + MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) from ...logger import logger @@ -755,31 +756,29 @@ def forward(self, multimodal_params: List[MultimodalParams], return self._encoding_batch_request(multimodal_params, mm_token_ids) -class Phi4MMInputProcessor(BaseMultimodalInputProcessor, InputProcessor): +class Phi4MMInputProcessor(BaseMultimodalInputProcessor, + BaseMultimodalDummyInputsBuilder): def __init__(self, model_path: str, - model_config: transformers.PretrainedConfig, + config: transformers.PretrainedConfig, tokenizer: transformers.AutoTokenizer, trust_remote_code: bool = True): + super().__init__() if not trust_remote_code: raise ValueError("trust_remote_code must be True for Phi4MM") - self.model_config = model_config - self.device = 'cpu' - - self.tokenizer = tokenizer - self.use_fast = True - if self.tokenizer is None: - self.tokenizer = transformers.AutoTokenizer.from_pretrained( - model_path, - trust_remote_code=trust_remote_code, - use_fast=self.use_fast) - - self.processor = transformers.AutoProcessor.from_pretrained( + self._config = config + self._tokenizer = tokenizer if tokenizer is not None else transformers.AutoTokenizer.from_pretrained( 
model_path, trust_remote_code=trust_remote_code, use_fast=self.use_fast) + self._processor = transformers.AutoProcessor.from_pretrained( + model_path, + trust_remote_code=trust_remote_code, + use_fast=self.use_fast) + self._model_path = model_path + self._dtype = self.config.torch_dtype # Bind the optimized methods to the image processor instance self.processor.image_processor.dynamic_preprocess = MethodType( dynamic_preprocess, @@ -789,8 +788,27 @@ def __init__(self, image_preprocess, self.processor.image_processor, ) + self.device = 'cpu' - self.dtype = model_config.torch_dtype + @property + def config(self) -> transformers.PretrainedConfig: + return self._config + + @property + def tokenizer(self) -> transformers.AutoTokenizer: + return self._tokenizer + + @property + def model_path(self) -> str: + return self._model_path + + @property + def processor(self) -> transformers.AutoProcessor: + return self._processor + + @property + def dtype(self) -> torch.dtype: + return self._dtype def get_mm_token_ids(self) -> Optional[torch.Tensor]: return torch.tensor([_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID], diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index b1ff8ed6238..f3805bfa84b 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -28,8 +28,8 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams from ..._utils import nvtx_range -from ...inputs import (BaseDummyInputsBuilder, BaseMultimodalInputProcessor, - ExtraProcessedInputs, InputProcessor, +from ...inputs import (BaseMultimodalDummyInputsBuilder, + BaseMultimodalInputProcessor, ExtraProcessedInputs, MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) @@ -88,36 +88,55 @@ def process_weights(weights: Dict, return filtered_weights -class Qwen2VLInputProcessorBase(BaseDummyInputsBuilder, - BaseMultimodalInputProcessor, InputProcessor): +class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor, + BaseMultimodalDummyInputsBuilder): def __init__(self, model_path: str, - model_config: PretrainedConfig, + config: PretrainedConfig, tokenizer: AutoTokenizer, trust_remote_code: bool = True): super().__init__() - self.model_config = model_config - self.vision_dtype = self.model_config.torch_dtype - self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained( + self._config = config + self._dtype = self._config.torch_dtype + self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained( model_path) - self.use_fast = True - self.model_path = model_path - self.processor = AutoProcessor.from_pretrained( + self._model_path = model_path + self._processor = AutoProcessor.from_pretrained( model_path, use_fast=self.use_fast, trust_remote_code=trust_remote_code) - self.tllm_multimodal_token_id = self.model_config.vocab_size + 1 + self.tllm_multimodal_token_id = self.get_vocab_size() + 1 # temporal patch size for video frames - self.temporal_patch_size = getattr(model_config.vision_config, + self.temporal_patch_size = getattr(self._config.vision_config, 'temporal_patch_size', 1) + @property + def config(self) -> PretrainedConfig: + return self._config + + @property + def tokenizer(self) -> AutoTokenizer: + return self._tokenizer + + @property + def model_path(self) -> str: + return self._model_path + + @property + def processor(self) -> AutoProcessor: + return self._processor + + @property + def dtype(self) -> 
torch.dtype: + return self._dtype + @classmethod def get_rope_index( cls, - model_config: PretrainedConfig, + config: PretrainedConfig, input_ids: Optional[torch.IntTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, @@ -131,7 +150,7 @@ def get_rope_index( The main difference between the two implementations is how temporal position IDs are calculated. Args: - model_config: The model configuration + config: The HF model configuration input_ids: Indices of input sequence tokens in the vocabulary image_grid_thw: The temporal, height and width of feature shape of each image in LLM video_grid_thw: The temporal, height and width of feature shape of each video in LLM @@ -142,10 +161,10 @@ def get_rope_index( position_ids: A tensor of shape (3, batch_size, sequence_length) mrope_position_deltas: A tensor of shape (batch_size) """ - spatial_merge_size = model_config.vision_config.spatial_merge_size - image_token_id = model_config.image_token_id - video_token_id = model_config.video_token_id - vision_start_token_id = model_config.vision_start_token_id + spatial_merge_size = config.vision_config.spatial_merge_size + image_token_id = config.image_token_id + video_token_id = config.video_token_id + vision_start_token_id = config.vision_start_token_id mrope_position_deltas = [] # Handle case with no vision inputs @@ -247,14 +266,14 @@ def get_rope_index( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) # Calculate temporal position IDs based on model type - if hasattr(model_config.vision_config, 'tokens_per_second'): + if hasattr(config.vision_config, 'tokens_per_second'): # Qwen2_5_VL style temporal position calculation if isinstance(second_per_grid_t, torch.Tensor): second_per_grid_t = second_per_grid_t.item() range_tensor = torch.arange(llm_grid_t).view(-1, 1) expanded_range = range_tensor.expand( -1, llm_grid_h * llm_grid_w) - time_tensor = expanded_range * second_per_grid_t * model_config.vision_config.tokens_per_second + time_tensor = expanded_range * second_per_grid_t * config.vision_config.tokens_per_second t_index = time_tensor.long().flatten() else: # Qwen2VL style temporal position calculation @@ -312,9 +331,9 @@ def _preprocess(self, text: dict[str, any], mm_data: dict[str, any], **mm_processor_kwargs) def _postprocess(self, input_ids: torch.IntTensor) -> torch.IntTensor: - masks = (input_ids == self.model_config.image_token_id) | ( - input_ids == self.model_config.vision_token_id) | ( - input_ids == self.model_config.video_token_id) + masks = (input_ids == self.config.image_token_id) | ( + input_ids == self.config.vision_token_id) | ( + input_ids == self.config.video_token_id) input_ids[masks] = self.tllm_multimodal_token_id return input_ids @@ -326,7 +345,7 @@ def get_mrope_config( attention_mask: torch.Tensor, second_per_grid_ts: torch.Tensor = None) -> dict[str, torch.Tensor]: mrope_position_ids, mrope_position_deltas = Qwen2VLInputProcessorBase.get_rope_index( - self.model_config, input_ids, image_grid_thw, video_grid_thw, + self.config, input_ids, image_grid_thw, video_grid_thw, attention_mask, second_per_grid_ts) mrope_config = {} @@ -352,15 +371,14 @@ def __call__( pixel_values = processed_inputs.get('pixel_values', None) if pixel_values is not None: multimodal_data["image"] = { - "pixel_values": pixel_values.to(self.vision_dtype), + "pixel_values": pixel_values.to(self.dtype), "image_grid_thw": processed_inputs.get('image_grid_thw') } pixel_values_videos = processed_inputs.get('pixel_values_videos', None) if 
pixel_values_videos is not None: multimodal_data["video"] = { - "pixel_values_videos": - pixel_values_videos.to(self.vision_dtype), + "pixel_values_videos": pixel_values_videos.to(self.dtype), "video_grid_thw": processed_inputs.get('video_grid_thw') } diff --git a/tensorrt_llm/_torch/models/modeling_vila.py b/tensorrt_llm/_torch/models/modeling_vila.py index e9e420dd64f..4b76ba122fc 100644 --- a/tensorrt_llm/_torch/models/modeling_vila.py +++ b/tensorrt_llm/_torch/models/modeling_vila.py @@ -35,7 +35,7 @@ PreTrainedModel) from ..._utils import nvtx_range -from ...inputs import (ExtraProcessedInputs, InputProcessor, +from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs, MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, register_input_processor) @@ -864,31 +864,50 @@ def _apply_chat_template(text, conv, tokenizer): return text -class VilaInputProcessor(InputProcessor): +from transformers import AutoProcessor + + +class VilaInputProcessor(BaseMultimodalInputProcessor): def __init__(self, - model_path, - model_config, - tokenizer, + model_path: str, + config: PretrainedConfig, + tokenizer: AutoTokenizer, trust_remote_code: bool = True): - self.model_config = model_config + super().__init__() + self._config = config llm_path, vision_tower_path, mm_projector_path = _get_model_paths( - self.model_config) + self.config) + self._dtype = self.config.model_dtype + self._tokenizer = init_tokenizer( + llm_path) if tokenizer is None else tokenizer + self.device = 'cuda' - self.model_dtype = _convert_dtype(self.model_config.model_dtype) self.conv_mode = _get_conversation_mode(llm_path) - - self.tokenizer = init_tokenizer( - llm_path) if tokenizer is None else tokenizer self.vision_tower, self.image_processor = init_vision_tower( - vision_tower_path, self.model_config) - self.mm_projector = init_mm_projector(mm_projector_path, - self.model_config) + vision_tower_path, self.config) + self.mm_projector = init_mm_projector(mm_projector_path, self.config) # must be fp16 self.vision_tower.to(device=self.device, dtype=torch.float16) self.mm_projector.to(device=self.device, dtype=torch.float16) + @property + def config(self) -> PretrainedConfig: + return self._config + + @property + def tokenizer(self) -> AutoTokenizer: + return self._tokenizer + + @property + def processor(self) -> AutoProcessor: + return None + + @property + def dtype(self) -> torch.dtype: + return self._dtype + @nvtx_range("[Vision] preprocess") def _preprocess(self, mm_data: dict[str, any], @@ -904,7 +923,7 @@ def _preprocess(self, images = mm_data["image"] return process_images(images, self.image_processor, - self.model_config, + self.config, enable_dynamic_res=True, enable_dynamic_s2=True, use_fast=use_fast, @@ -919,7 +938,7 @@ def _preprocess(self, mm_tensor, block_sizes = process_images( video, self.image_processor, - self.model_config, + self.config, enable_dynamic_res=False, enable_dynamic_s2=False, use_fast=use_fast, @@ -935,7 +954,7 @@ def _process(self, mm_tensor, block_sizes): """Extract multimodal features from multimodal input""" mm_tensor = mm_tensor.to(self.vision_tower.dtype) # must be fp16 - if getattr(self.model_config, "dynamic_s2", False): + if getattr(self.config, "dynamic_s2", False): # dynamic S2 logic in https://github.com/NVlabs/VILA/blob/main/llava/model/llava_arch.py::encoder_images() if block_sizes is None: block_sizes = [None] * len(mm_tensor) @@ -1012,7 +1031,7 @@ def _postprocess(self, input_ids, mm_features): raise ValueError( f"Invalid multimodal features type: 
{type(mm_features)}") mm_total_length = sum(mm_lengths_per_split) - assert mm_hidden_dim == self.model_config.hidden_size, "Multimodal embedding_dim must match model hidden_size" + assert mm_hidden_dim == self.config.hidden_size, "Multimodal embedding_dim must match model hidden_size" ## split input_ids into segments by isolating mm tokens vocab_size = len(self.tokenizer) # vocab including special tokens diff --git a/tensorrt_llm/inputs/__init__.py b/tensorrt_llm/inputs/__init__.py index 9f4251b6504..406a71d4f5c 100644 --- a/tensorrt_llm/inputs/__init__.py +++ b/tensorrt_llm/inputs/__init__.py @@ -1,9 +1,9 @@ from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs from .evs import compute_retained_tokens_count, compute_retention_mask from .multimodal import MultimodalInput -from .registry import (BaseDummyInputsBuilder, BaseMultimodalInputProcessor, - ExtraProcessedInputs, InputProcessor, - MultimodalPlaceholderMetadata, +from .registry import (BaseMultimodalDummyInputsBuilder, + BaseMultimodalInputProcessor, ExtraProcessedInputs, + InputProcessor, MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, create_input_processor, create_input_processor_with_hash, register_input_processor, @@ -35,7 +35,7 @@ "register_input_processor", "support_multimodal_disaggregated", "ExtraProcessedInputs", - "BaseDummyInputsBuilder", + "BaseMultimodalDummyInputsBuilder", "BaseMultimodalInputProcessor", "MultimodalPlaceholderMetadata", "MultimodalPlaceholderPlacement", diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py index 399da7b9ef8..4841369a1fb 100644 --- a/tensorrt_llm/inputs/registry.py +++ b/tensorrt_llm/inputs/registry.py @@ -4,6 +4,7 @@ from typing import (Any, Callable, Dict, List, Optional, Protocol, Tuple, Type, TypeVar) +import torch from PIL import Image from torch import Tensor, nn @@ -21,6 +22,11 @@ ExtraProcessedInputs = Dict[str, Any] +from abc import ABC, abstractmethod + +from transformers import (AutoProcessor, PretrainedConfig, + PreTrainedTokenizerBase) + class InputProcessor(Protocol): """ @@ -35,9 +41,8 @@ class InputProcessor(Protocol): """ model_path: any - model_config: any + config: any tokenizer: any - multimodal_hashing_supported: Optional[bool] = None def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams @@ -45,49 +50,77 @@ def __call__( ... -class BaseDummyInputsBuilder: - """ - Base class for generating dummy inputs. Specially for profiling - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.image_max_dim = 16384 - self.img_min_dim = 128 - - def get_dummy_image(self, max_width: int, max_height: int): - image = Image.new("RGB", (max_width, max_height), - color=random.randint(0, 256)) - return image - - def get_dummy_prompt(self, input_seq_len: int): - # TODO(yechank): We use the max resolution as starting point and keep reducing the resolution until the prompt length is less than the input sequence length. - # Need to find better way to calculate the dummy prompt length as this iteration may not be efficient. 
- while self.image_max_dim >= self.img_min_dim: - image = self.get_dummy_image(max_width=self.image_max_dim, - max_height=self.image_max_dim) +class DefaultInputProcessor(InputProcessor): + """Preprocess the inputs to the model.""" - test_mm_prompt = tensorrt_llm.inputs.utils.default_multimodal_input_loader( - tokenizer=self.tokenizer, - model_dir=self.model_path, - model_type=self.model_config.model_type, - modality="image", - prompts=[""], - media=[[image]], - image_data_format="pt")[0] + def __init__(self, + model_path, + config, + tokenizer, + trust_remote_code: bool = True) -> None: + self.tokenizer = tokenizer + self.config = config + self.model_path = model_path + self.multimodal_hashing_supported = None - prompt_token_ids_single_img, _ = self(test_mm_prompt, None) + def __call__( + self, inputs: TextPrompt, sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + """The default input processor handles only tokenization.""" + if self.tokenizer is None: + raise ValueError("tokenizer is required to tokenize string prompt") + kwargs = {} + if sampling_params.truncate_prompt_tokens is not None: + kwargs = dict(truncation=True, + max_length=sampling_params.truncate_prompt_tokens) + toktoken_special_tokens = { + "<|startoftext|>", + "<|endoftext|>", + "<|reserved_200000|>", + "<|reserved_200001|>", + "<|return|>", + "<|constrain|>", + "<|reserved_200004|>", + "<|channel|>", + "<|start|>", + "<|end|>", + "<|message|>", + "<|reserved_200009|>", + "<|reserved_200010|>", + "<|reserved_200011|>", + "<|call|>", + "<|reserved_200013|>", + } + with nvtx_range_debug("tokenize prompt"): + try: + token_ids = self.tokenizer.encode( + inputs["prompt"], + add_special_tokens=sampling_params.add_special_tokens, + **kwargs) + except: + # Tiktoken path + token_ids = self.tokenizer.encode( + inputs["prompt"], allowed_special=toktoken_special_tokens) - if len(prompt_token_ids_single_img) <= input_seq_len: - return test_mm_prompt + if "query" in inputs: + with nvtx_range_debug("tokenize query"): + try: + query_token_ids = self.tokenizer.encode( + inputs["query"], + add_special_tokens=sampling_params.add_special_tokens, + **kwargs) + except: + # Tiktoken path + query_token_ids = self.tokenizer.encode( + inputs["query"], + allowed_special=toktoken_special_tokens) - # reduce img resolution - self.image_max_dim = self.image_max_dim >> 1 + return token_ids, {"query_token_ids": query_token_ids} - return None + return token_ids, None -class BaseMultimodalInputProcessor: +class BaseMultimodalInputProcessor(InputProcessor, ABC): """ Base class for multimodal input processors with default implementations of get_num_tokens_per_image and get_num_tokens_per_video methods. @@ -98,32 +131,75 @@ class BaseMultimodalInputProcessor: def __init__(self, **kwargs): super().__init__(**kwargs) + self._use_fast: bool = kwargs.get('use_fast', True) + self._multimodal_hashing_supported: Optional[bool] = None + + @property + @abstractmethod + def processor(self) -> AutoProcessor: + """The HF AutoProcessor for this model.""" + ... + + @property + @abstractmethod + def tokenizer(self) -> PreTrainedTokenizerBase: + """The HF tokenizer for this model.""" + ... + + @property + @abstractmethod + def config(self) -> PretrainedConfig: + """The HF pretrained config for this model.""" + ... + + @property + @abstractmethod + def dtype(self) -> torch.dtype: + """The dtype for this model.""" + ... 
+ + @abstractmethod + def __call__( + self, inputs: TextPrompt, sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + ... - def get_processor(self) -> Optional[Any]: - """Return the processor object if available; otherwise raise NotImplementedError. + @property + def use_fast(self) -> bool: """ - if not hasattr(self, 'processor') and not hasattr(self, '_processor'): - raise NotImplementedError( - f"cannot find processor in {self.__class__.__name__}. " - "Please ensure the processor is stored under self.processor or self._processor." - ) - return getattr(self, 'processor', getattr(self, '_processor', None)) + Whether to use fast tokenizer for AutoProcessor. + Default is True for most multimodal models. + """ + return self._use_fast + + @property + def multimodal_hashing_supported(self) -> Optional[bool]: + """ + Whether multimodal hashing is supported for this processor. + + Returns None if unknown (will be detected at runtime), + True if supported, False if not supported. + """ + return self._multimodal_hashing_supported + + @multimodal_hashing_supported.setter + def multimodal_hashing_supported(self, value: Optional[bool]) -> None: + """Set the multimodal hashing support status (used for runtime detection).""" + self._multimodal_hashing_supported = value def get_vocab_size(self) -> Optional[int]: """Return the tokenizer/model vocabulary size if available; otherwise None. Resolution order: - 1) self.model_config.vocab_size + 1) self.config.vocab_size 2) self.tokenizer.vocab_size """ # 1) Model config - if hasattr(self, 'model_config') and getattr( - self.model_config, 'vocab_size', None) is not None: - return int(self.model_config.vocab_size) + if hasattr(self.config, 'vocab_size'): + return int(self.config.vocab_size) # 2) Direct tokenizer on self - if hasattr(self, 'tokenizer') and getattr(self.tokenizer, 'vocab_size', - None) is not None: + if hasattr(self.tokenizer, 'vocab_size'): return int(self.tokenizer.vocab_size) logger.debug( @@ -136,13 +212,11 @@ def get_mm_token_ids(self) -> Optional[Tensor]: The token IDs filtered by this method should be contiguous for each multimodal item, i.e. special tokens if any should be included. """ - processor = self.get_processor() - if processor is not None and getattr(processor, 'mm_token_ids', - None) is not None: - return processor.mm_token_ids + if hasattr(self.processor, 'mm_token_ids'): + return self.processor.mm_token_ids logger.debug( - f"Cannot determine mm_token_ids from {self.__class__.__name__}. " + f"Cannot find mm_token_ids in {self.__class__.__name__}.processor. " "If needed, please override this method to return multimodal token ids. " ) return None @@ -156,19 +230,15 @@ def get_mm_special_token_ids(self) -> Optional[Tensor]: (e.g., Mistral3, LLaMA4) mix special tokens with multimodal tokens, so they need to be returned separately. """ - processor = self.get_processor() - return getattr(processor, "mm_special_token_ids", - None) if processor else None + return getattr(self.processor, "mm_special_token_ids", None) @property def get_num_multimodal_tokens(self): """ Get the Hugging Face processor's '_get_num_multimodal_tokens' method. 
""" - processor = self.get_processor() - if processor is not None and hasattr(processor, - '_get_num_multimodal_tokens'): - return processor._get_num_multimodal_tokens + if hasattr(self.processor, '_get_num_multimodal_tokens'): + return self.processor._get_num_multimodal_tokens else: raise NotImplementedError( f"get_num_multimodal_tokens not implemented for {self.__class__.__name__}. " @@ -226,74 +296,65 @@ def get_num_tokens_per_video( return num_tokens_per_frame * num_frames // temporal_patch_size -class DefaultInputProcessor(InputProcessor): - """Preprocess the inputs to the model.""" +class BaseMultimodalDummyInputsBuilder(ABC): + """ + Base class for generating dummy inputs. Specially for profiling + """ - def __init__(self, - model_path, - model_config, - tokenizer, - trust_remote_code: bool = True) -> None: - self.tokenizer = tokenizer - self.model_config = model_config - self.model_path = model_path - self.multimodal_hashing_supported = None + DEFAULT_IMAGE_MAX_DIM = 16384 + DEFAULT_IMAGE_MIN_DIM = 128 - def __call__( - self, inputs: TextPrompt, sampling_params: SamplingParams - ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: - """The default input processor handles only tokenization.""" - if self.tokenizer is None: - raise ValueError("tokenizer is required to tokenize string prompt") - kwargs = {} - if sampling_params.truncate_prompt_tokens is not None: - kwargs = dict(truncation=True, - max_length=sampling_params.truncate_prompt_tokens) - toktoken_special_tokens = { - "<|startoftext|>", - "<|endoftext|>", - "<|reserved_200000|>", - "<|reserved_200001|>", - "<|return|>", - "<|constrain|>", - "<|reserved_200004|>", - "<|channel|>", - "<|start|>", - "<|end|>", - "<|message|>", - "<|reserved_200009|>", - "<|reserved_200010|>", - "<|reserved_200011|>", - "<|call|>", - "<|reserved_200013|>", - } - with nvtx_range_debug("tokenize prompt"): - try: - token_ids = self.tokenizer.encode( - inputs["prompt"], - add_special_tokens=sampling_params.add_special_tokens, - **kwargs) - except: - # Tiktoken path - token_ids = self.tokenizer.encode( - inputs["prompt"], allowed_special=toktoken_special_tokens) + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.image_max_dim = kwargs.get('image_max_dim', + self.DEFAULT_IMAGE_MAX_DIM) + self.img_min_dim = kwargs.get('img_min_dim', self.DEFAULT_IMAGE_MIN_DIM) - if "query" in inputs: - with nvtx_range_debug("tokenize query"): - try: - query_token_ids = self.tokenizer.encode( - inputs["query"], - add_special_tokens=sampling_params.add_special_tokens, - **kwargs) - except: - # Tiktoken path - query_token_ids = self.tokenizer.encode( - inputs["query"], - allowed_special=toktoken_special_tokens) + @property + @abstractmethod + def tokenizer(self) -> PreTrainedTokenizerBase: + pass - return token_ids, {"query_token_ids": query_token_ids} + @property + @abstractmethod + def config(self) -> PretrainedConfig: + pass - return token_ids, None + @property + @abstractmethod + def model_path(self) -> str: + pass + + def get_dummy_image(self, max_width: int, max_height: int) -> Image.Image: + image = Image.new("RGB", (max_width, max_height), + color=random.randint(0, 256)) + return image + + def get_dummy_prompt(self, input_seq_len: int): + # TODO(yechank): We use the max resolution as starting point and keep reducing the resolution until the prompt length is less than the input sequence length. + # Need to find better way to calculate the dummy prompt length as this iteration may not be efficient. 
+ while self.image_max_dim >= self.img_min_dim: + image = self.get_dummy_image(max_width=self.image_max_dim, + max_height=self.image_max_dim) + + test_mm_prompt = tensorrt_llm.inputs.utils.default_multimodal_input_loader( + tokenizer=self.tokenizer, + model_dir=self.model_path, + model_type=self.config.model_type, + modality="image", + prompts=[""], + media=[[image]], + image_data_format="pt")[0] + + prompt_token_ids_single_img, _ = self(test_mm_prompt, None) + + if len(prompt_token_ids_single_img) <= input_seq_len: + return test_mm_prompt + + # reduce img resolution + self.image_max_dim = self.image_max_dim >> 1 + + return None class MultimodalPlaceholderPlacement(enum.Enum): @@ -513,7 +574,6 @@ def create_input_processor( trust_remote_code=True) model_config = config.pretrained_config except (ValueError, EnvironmentError) as e: - config = None logger.debug( f"Unable to load HF config from {model_path_or_dir}: {e}. Falling back." ) @@ -539,7 +599,7 @@ def create_input_processor( def create_input_processor_with_hash( - input_processor: InputProcessor, + input_processor: BaseMultimodalInputProcessor, hash_lib=default_hasher, ) -> Callable[[TextPrompt, SamplingParams], Tuple[ List[int], Optional[ExtraProcessedInputs]]]: From 334c49a2b24b80f5dba15c7447bd5dd2d81d03c6 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Thu, 30 Oct 2025 14:22:44 +0900 Subject: [PATCH 2/8] import ordering Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_gemma3vl.py | 6 ++---- tensorrt_llm/_torch/models/modeling_vila.py | 7 ++----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py index 2fb32303e03..a5308a3b524 100644 --- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py +++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py @@ -4,7 +4,8 @@ from typing import List, Optional, Tuple import torch -from transformers import AutoProcessor, Gemma3Config, PreTrainedModel +from transformers import (AutoProcessor, AutoTokenizer, Gemma3Config, + PretrainedConfig, PreTrainedModel) from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper @@ -34,9 +35,6 @@ def _is_disagg() -> bool: return os.getenv(_MULTIMODAL_ENV_NAME, "0") == "1" -from transformers import AutoTokenizer, PretrainedConfig - - class Gemma3InputProcessor(BaseMultimodalInputProcessor, BaseMultimodalDummyInputsBuilder): diff --git a/tensorrt_llm/_torch/models/modeling_vila.py b/tensorrt_llm/_torch/models/modeling_vila.py index 4b76ba122fc..d8497e23408 100644 --- a/tensorrt_llm/_torch/models/modeling_vila.py +++ b/tensorrt_llm/_torch/models/modeling_vila.py @@ -31,8 +31,8 @@ from huggingface_hub.utils import HFValidationError from PIL import Image from transformers import (AutoConfig, AutoImageProcessor, AutoModel, - AutoTokenizer, LlavaConfig, PretrainedConfig, - PreTrainedModel) + AutoProcessor, AutoTokenizer, LlavaConfig, + PretrainedConfig, PreTrainedModel) from ..._utils import nvtx_range from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs, @@ -864,9 +864,6 @@ def _apply_chat_template(text, conv, tokenizer): return text -from transformers import AutoProcessor - - class VilaInputProcessor(BaseMultimodalInputProcessor): def __init__(self, From c70ff309b67ecf8bffa214ec7cdc16b4f0b5160f Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Tue, 28 Oct 
2025 10:21:42 +0900 Subject: [PATCH 3/8] config naming convention fix Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_llama.py | 6 +++--- tensorrt_llm/_torch/models/modeling_llava_next.py | 10 +++++----- tensorrt_llm/_torch/models/modeling_mistral.py | 14 +++++++------- tensorrt_llm/_torch/models/modeling_qwen2vl.py | 4 ++-- tensorrt_llm/inputs/registry.py | 14 +++++++------- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 42eea7f9f6f..38d487a7eae 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -5,8 +5,8 @@ import torch from PIL.Image import Image from torch import nn -from transformers import (AutoProcessor, Llama4Config, Llama4VisionModel, - LlamaConfig) +from transformers import (AutoProcessor, AutoTokenizer, Llama4Config, + Llama4VisionModel, LlamaConfig, PretrainedConfig) from transformers.modeling_utils import load_sharded_checkpoint from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector @@ -1150,7 +1150,7 @@ def attach_multimodal_embeddings( f"Missing required key in multimodal embedding: {e}") # Validate embedding dimensions - model_hidden_size = self.model_config.text_config.hidden_size + model_hidden_size = self.config.text_config.hidden_size for i, embedding in enumerate(mm_embeddings): if embedding.shape[-1] != model_hidden_size: raise ValueError( diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index aa9369ed046..aeb974a98ac 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -84,9 +84,9 @@ def _postprocess( List[torch.Tensor]] ) -> Tuple[torch.Tensor, torch.Tensor]: # Define model specific variables here before shared logic - mm_tokens = torch.tensor([self.model_config.image_token_index + mm_tokens = torch.tensor([self.config.image_token_index ]).to(input_ids.device) - model_hidden_size = self.model_config.text_config.hidden_size + model_hidden_size = self.config.text_config.hidden_size start_len = end_len = 0 # for llava, need not append start/end token around each image token # End model specific variables @@ -191,12 +191,12 @@ def get_prompt_token_ids( raise NotImplementedError( "Only one mm_handle is supported for LlavaNext for now") hidden_size = mm_handles[0]['tensor_size'][1] - assert hidden_size == self.model_config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size" + assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size" input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids[0] - vocab_size = self.model_config.text_config.vocab_size - image_token_index = self.model_config.image_token_index + vocab_size = self.config.text_config.vocab_size + image_token_index = self.config.image_token_index image_mask = input_ids == image_token_index image_positions = torch.where(image_mask)[0] diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 9c4cf6b54c4..ec650c804b2 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -261,13 +261,13 @@ def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams ) -> Tuple[List[int], 
Optional[ExtraProcessedInputs]]: images = inputs.get("multi_modal_data", {}).get("image") - do_rescale = self._processor.image_processor.do_rescale + do_rescale = self.processor.image_processor.do_rescale if images is not None and isinstance(images[0], torch.Tensor): # The default multimodal input loader will normalize images to [0, 1] when the requested # format is "pt" (pytorch tensors), but not for "pil" (PIL images). do_rescale = False - processed = self._processor( + processed = self.processor( text=inputs["prompt"], images=images, do_rescale=do_rescale, @@ -310,18 +310,18 @@ def get_mm_token_ids(self) -> torch.Tensor: return torch.tensor([ # This is the `[IMG]` token id inserted into the prompt that should be replaced with image # embeddings. - self._processor.image_token_id, + self.processor.image_token_id, # This is the `[IMG_BREAK]` token id at the end of every "row". - self._processor.image_break_token_id, + self.processor.image_break_token_id, # This is the `[IMG_END]` token id to signify the end of an image. - self._processor.image_end_token_id, + self.processor.image_end_token_id, ]) def get_mm_special_token_ids(self) -> torch.Tensor: """Get the IDs of special multimodal tokens (placeholders not included).""" return torch.tensor([ - self._processor.image_break_token_id, - self._processor.image_end_token_id, + self.processor.image_break_token_id, + self.processor.image_end_token_id, ]) diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index f3805bfa84b..23a7c3e14d8 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -110,7 +110,7 @@ def __init__(self, self.tllm_multimodal_token_id = self.get_vocab_size() + 1 # temporal patch size for video frames - self.temporal_patch_size = getattr(self._config.vision_config, + self.temporal_patch_size = getattr(self.config.vision_config, 'temporal_patch_size', 1) @property @@ -150,7 +150,7 @@ def get_rope_index( The main difference between the two implementations is how temporal position IDs are calculated. Args: - config: The HF model configuration + config: The HF's PretrainedConfig model configuration input_ids: Indices of input sequence tokens in the vocabulary image_grid_thw: The temporal, height and width of feature shape of each image in LLM video_grid_thw: The temporal, height and width of feature shape of each video in LLM diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py index 4841369a1fb..912c2ecfb28 100644 --- a/tensorrt_llm/inputs/registry.py +++ b/tensorrt_llm/inputs/registry.py @@ -566,13 +566,13 @@ def create_input_processor( from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models import get_model_architecture - model_config = None + config = None if checkpoint_format == "HF": try: - config = ModelConfig.from_pretrained(model_path_or_dir, - trust_remote_code=True) - model_config = config.pretrained_config + model_config = ModelConfig.from_pretrained(model_path_or_dir, + trust_remote_code=True) + config = model_config.pretrained_config except (ValueError, EnvironmentError) as e: logger.debug( f"Unable to load HF config from {model_path_or_dir}: {e}. Falling back." 
@@ -581,9 +581,9 @@ def create_input_processor( logger.debug( f"checkpoint_format={checkpoint_format}; skipping HF config load.") - if model_config is not None: + if config is not None: try: - model_cls, _ = get_model_architecture(model_config) + model_cls, _ = get_model_architecture(config) input_processor_cls = INPUT_PROCESSOR_REGISTRY._input_processors_cls_by_model_type \ .get(model_cls) except RuntimeError: # unregistered model @@ -591,7 +591,7 @@ def create_input_processor( input_processor_cls = None if input_processor_cls is not None: return input_processor_cls(model_path_or_dir, - model_config, + config, tokenizer, trust_remote_code=True) From ae3e7a87e7d3ee6f36b3027e5cc734f18328865c Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Thu, 30 Oct 2025 14:55:54 +0900 Subject: [PATCH 4/8] registry file import ordering Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/inputs/registry.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py index 912c2ecfb28..394ebd04532 100644 --- a/tensorrt_llm/inputs/registry.py +++ b/tensorrt_llm/inputs/registry.py @@ -1,5 +1,6 @@ import enum import random +from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import (Any, Callable, Dict, List, Optional, Protocol, Tuple, Type, TypeVar) @@ -7,6 +8,8 @@ import torch from PIL import Image from torch import Tensor, nn +from transformers import (AutoProcessor, PretrainedConfig, + PreTrainedTokenizerBase) import tensorrt_llm @@ -22,11 +25,6 @@ ExtraProcessedInputs = Dict[str, Any] -from abc import ABC, abstractmethod - -from transformers import (AutoProcessor, PretrainedConfig, - PreTrainedTokenizerBase) - class InputProcessor(Protocol): """ From af363d6b4960f067633ffda45884c44d7b8323d0 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:55:15 +0900 Subject: [PATCH 5/8] AD fix Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/auto_deploy/llm.py b/tensorrt_llm/_torch/auto_deploy/llm.py index 5062ee04054..205c1eb80dc 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm.py +++ b/tensorrt_llm/_torch/auto_deploy/llm.py @@ -22,7 +22,7 @@ class ADInputProcessor(DefaultInputProcessor): """ def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None): - super().__init__(model_path=None, model_config=None, tokenizer=tokenizer) + super().__init__(model_path=None, config=None, tokenizer=tokenizer) # NOTE: HF's tokenizer/processor that has the apply_chat_template method self.processor = processor or getattr(tokenizer, "tokenizer", None) From 83dbda8c3103e836aa178a3188b7301d4b8eea0b Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Mon, 3 Nov 2025 13:45:21 +0900 Subject: [PATCH 6/8] mistral test fix Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tests/unittest/_torch/modeling/test_modeling_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest/_torch/modeling/test_modeling_mistral.py b/tests/unittest/_torch/modeling/test_modeling_mistral.py index d3ea00c32ad..a79e9415bdb 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mistral.py +++
b/tests/unittest/_torch/modeling/test_modeling_mistral.py @@ -529,7 +529,7 @@ def test_processor_get_num_tokens_per_image( ) as mocked_auto_processor: input_processor = modeling_mistral.Mistral3InputProcessor( model_path=str(tmp_path), - model_config=mistral_3_config, + config=mistral_3_config, tokenizer=mock.MagicMock(), ) From 06d63ca0b5cec9a8d191e6f033ca69dbcac189c4 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:15:11 +0900 Subject: [PATCH 7/8] multimodal test fix Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tests/unittest/_torch/multimodal/test_external_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest/_torch/multimodal/test_external_embedding.py b/tests/unittest/_torch/multimodal/test_external_embedding.py index c2734b67765..dac9035ee5c 100644 --- a/tests/unittest/_torch/multimodal/test_external_embedding.py +++ b/tests/unittest/_torch/multimodal/test_external_embedding.py @@ -76,7 +76,7 @@ def processor_setup(request): mock_auto_processor.from_pretrained.return_value = mock_processor processor = config["processor_class"](model_path="dummy_path", - model_config=mock_config, + config=mock_config, tokenizer=mock_tokenizer, trust_remote_code=True) From ed6f6b3e608cd55615ab652faa86e91c3f69ed9f Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Mon, 3 Nov 2025 17:06:54 +0900 Subject: [PATCH 8/8] multimodal test fix2 Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- .../_torch/multimodal/test_find_num_image_tokens.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py b/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py index 6894237d9b9..688e3d0ef4d 100644 --- a/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py +++ b/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py @@ -84,13 +84,13 @@ def test_get_num_tokens_per_image(model_key, multimodal_model_configs): if model_type == 'llava_next': input_processor = LlavaNextInputProcessor( model_path=encoder_model_dir, - model_config=model_config_dict, + config=model_config_dict, tokenizer=tokenizer, trust_remote_code=True) elif model_type == 'qwen2_5_vl': input_processor = Qwen2VLInputProcessorBase( model_path=encoder_model_dir, - model_config=model_config_dict, + config=model_config_dict, tokenizer=tokenizer, trust_remote_code=True) else: @@ -191,13 +191,13 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs): if model_type == 'llava_next': input_processor = LlavaNextInputProcessor( model_path=encoder_model_dir, - model_config=model_config_dict, + config=model_config_dict, tokenizer=tokenizer, trust_remote_code=True) elif model_type == 'qwen2_5_vl': input_processor = Qwen2VLInputProcessorBase( model_path=encoder_model_dir, - model_config=model_config_dict, + config=model_config_dict, tokenizer=tokenizer, trust_remote_code=True) else: