2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/llm.py
@@ -22,7 +22,7 @@ class ADInputProcessor(DefaultInputProcessor):
"""

def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None):
super().__init__(model_path=None, model_config=None, tokenizer=tokenizer)
super().__init__(model_path=None, config=None, tokenizer=tokenizer)
# NOTE: HF's tokenizer/processor that has the apply_chat_template method
self.processor = processor or getattr(tokenizer, "tokenizer", None)

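Note: the only change in this file is the keyword rename from `model_config` to `config` in the call to `DefaultInputProcessor.__init__`. A minimal sketch of the call pattern after the rename, using toy stand-in classes rather than the real TensorRT-LLM base class:

```python
from typing import Any, Optional


class ToyDefaultInputProcessor:
    """Stand-in for DefaultInputProcessor; assumes the base now accepts `config`."""

    def __init__(self, model_path: Optional[str], config: Optional[Any],
                 tokenizer: Optional[Any]):
        self.model_path = model_path
        self.config = config
        self.tokenizer = tokenizer


class ToyADInputProcessor(ToyDefaultInputProcessor):
    def __init__(self, tokenizer: Optional[Any], processor: Optional[Any] = None):
        # Keyword must match the new base signature: `config=`, not `model_config=`.
        super().__init__(model_path=None, config=None, tokenizer=tokenizer)
        self.processor = processor or getattr(tokenizer, "tokenizer", None)
```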
50 changes: 40 additions & 10 deletions tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -4,13 +4,15 @@
from typing import List, Optional, Tuple

import torch
from transformers import AutoProcessor, Gemma3Config, PreTrainedModel
from transformers import (AutoProcessor, AutoTokenizer, Gemma3Config,
PretrainedConfig, PreTrainedModel)

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper

from ..._utils import nvtx_range
from ...inputs import (ExtraProcessedInputs, InputProcessor,
from ...inputs import (BaseMultimodalDummyInputsBuilder,
BaseMultimodalInputProcessor, ExtraProcessedInputs,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
@@ -33,15 +35,43 @@ def _is_disagg() -> bool:
return os.getenv(_MULTIMODAL_ENV_NAME, "0") == "1"


class Gemma3InputProcessor(InputProcessor):
class Gemma3InputProcessor(BaseMultimodalInputProcessor,
BaseMultimodalDummyInputsBuilder):

def __init__(self, model_path, model_config, tokenizer, trust_remote_code):
def __init__(self,
model_path: str,
config: PretrainedConfig,
tokenizer: AutoTokenizer,
trust_remote_code: bool = True):
super().__init__()
self._config = config
self._tokenizer = tokenizer
self._model_path = model_path
self._processor = AutoProcessor.from_pretrained(
model_path,
trust_remote_code=trust_remote_code,
use_fast=self.use_fast)
self._dtype = self.config.torch_dtype

@property
def config(self) -> PretrainedConfig:
return self._config

@property
def tokenizer(self) -> AutoTokenizer:
return self._tokenizer

self.tokenizer = tokenizer
self.processor = AutoProcessor.from_pretrained(
model_path, trust_remote_code=trust_remote_code, use_fast=True)
self.model_config = model_config
self.device = 'cuda'
@property
def model_path(self) -> str:
return self._model_path

@property
def processor(self) -> AutoProcessor:
return self._processor

@property
def dtype(self) -> torch.dtype:
return self._dtype

@nvtx_range("[Vision] preprocess")
def _preprocess(self, inputs):
@@ -59,7 +89,7 @@ def _preprocess(self, inputs):
images=images,
do_rescale=do_rescale,
return_tensors="pt",
device=self.device).to(dtype=torch.bfloat16)
).to(dtype=self.dtype)

input_ids = processor_output["input_ids"]
pixel_values = processor_output.get("pixel_values")
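Note: the pattern introduced here (private `_config`, `_tokenizer`, `_model_path`, `_processor`, and `_dtype` attributes exposed through read-only properties, with the image-tensor cast driven by `config.torch_dtype` instead of a hard-coded `torch.bfloat16` and `device='cuda'`) is the same one applied to the other processors in this PR. A minimal, self-contained sketch of the idea, using a toy config object rather than the real Gemma3 classes:

```python
import torch


class ToyConfig:
    """Stand-in for a HF PretrainedConfig; only the field used below."""
    torch_dtype = torch.float16


class ToyInputProcessor:
    def __init__(self, model_path: str, config: ToyConfig, tokenizer=None):
        self._model_path = model_path
        self._config = config
        self._tokenizer = tokenizer
        # dtype follows the checkpoint config instead of a hard-coded bfloat16 cast
        self._dtype = config.torch_dtype

    @property
    def config(self) -> ToyConfig:
        return self._config

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype


proc = ToyInputProcessor("toy/model", ToyConfig())
pixel_values = torch.zeros(1, 3, 4, 4).to(dtype=proc.dtype)  # cast tracks config.torch_dtype
```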
108 changes: 63 additions & 45 deletions tensorrt_llm/_torch/models/modeling_hyperclovax.py
@@ -15,8 +15,9 @@

from tensorrt_llm.inputs.multimodal import MultimodalParams

from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
InputProcessor, MultimodalPlaceholderMetadata,
from ...inputs import (BaseMultimodalDummyInputsBuilder,
BaseMultimodalInputProcessor, ExtraProcessedInputs,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
@@ -564,33 +565,54 @@ def build_mlp(
return nn.Sequential(*layers)


class HCXVisionInputProcessor(BaseMultimodalInputProcessor, InputProcessor):
class HCXVisionInputProcessor(BaseMultimodalDummyInputsBuilder,
BaseMultimodalInputProcessor):

def __init__(self,
model_path: str,
model_config: PretrainedConfig,
config: PretrainedConfig,
tokenizer: AutoTokenizer,
trust_remote_code: bool = True):

self.pretrained_config = model_config
self.tokenizer = tokenizer
self.use_fast = True
if self.tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=trust_remote_code,
use_fast=self.use_fast)
self.processor = AutoProcessor.from_pretrained(
super().__init__()
self._config = config
self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=trust_remote_code,
use_fast=self.use_fast)
self._processor = AutoProcessor.from_pretrained(
model_path,
trust_remote_code=trust_remote_code,
use_fast=self.use_fast)
self.tllm_multimodal_token_id = self.pretrained_config.language_config[
self._model_path = model_path
self._dtype = self.config.torch_dtype

self.tllm_multimodal_token_id = self.config.language_config[
"vocab_size"] + 1
self.vision_query_lengths = None
self._vision_query_generator = None

@property
def config(self) -> PretrainedConfig:
return self._config

@property
def tokenizer(self) -> AutoTokenizer:
return self._tokenizer

@property
def model_path(self) -> str:
return self._model_path

@property
def processor(self) -> AutoProcessor:
return self._processor

@property
def dtype(self) -> torch.dtype:
return self._dtype

def get_vocab_size(self):
return self.pretrained_config.language_config["vocab_size"]
return self.config.language_config["vocab_size"]

def get_num_tokens_per_image(
self,
@@ -656,8 +678,7 @@ def _post_process(self,
vision_query_lengths = preprocessed_image.get("vision_query_lengths",
None)
non_vision_query_lengths = determine_non_vision_query_lengths(
input_ids, self.tokenizer.pad_token_id,
self.pretrained_config.img_start_id)
input_ids, self.tokenizer.pad_token_id, self.config.img_start_id)
batch_size = input_ids.size(0)

len_inputs_embeds = max([
@@ -666,19 +687,18 @@
non_vision_query_lengths, vision_query_lengths)
])

len_inputs_embeds = min(self.pretrained_config.decoder_max_length,
len_inputs_embeds = min(self.config.decoder_max_length,
len_inputs_embeds)

image_cnts = (input_ids == self.pretrained_config.img_start_id).sum(
dim=1).tolist()
image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist()

fused_input_ids = torch.zeros([batch_size, len_inputs_embeds],
dtype=input_ids.dtype)
for batch_idx, sample in enumerate(input_ids):
non_vision_query_length = non_vision_query_lengths[batch_idx]
sample = sample[:non_vision_query_length + image_cnts[batch_idx]]

mask = (sample == self.pretrained_config.img_start_id)
mask = (sample == self.config.img_start_id)
img_start_ids = mask.nonzero()
input_start, temp_start = 0, 0

@@ -779,32 +799,30 @@ class HCXVisionModel(nn.Module):
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
super().__init__()
self.model_config = model_config
self.pretrained_config = model_config.pretrained_config
self.config = model_config.pretrained_config
siglip_model_config = copy.deepcopy(self.model_config)
siglip_model_config.pretrained_config = self.model_config.pretrained_config.vision_config
self.visual_token_idx = 0 if "siglip" in self.model_config.pretrained_config.vision_config.model_type else 1
self.dtype = self.model_config.pretrained_config.vision_config.torch_dtype
self.vision_model = SiglipVisionModel(siglip_model_config).to(
self.dtype)
self.mm_projector = HCXVisionCAbstractor(
num_queries=self.pretrained_config.num_queries_vis_abstractor,
num_input_tokens=(
self.pretrained_config.vision_config.image_size //
self.pretrained_config.vision_config.patch_size)**2,
encoder_hidden_size=self.pretrained_config.vision_config.
hidden_size,
hidden_size=self.pretrained_config.vision_config.hidden_size,
output_hidden_size=self.pretrained_config.hidden_size,
pos_emb=self.pretrained_config.proj_pos_emb,
prenorm=self.pretrained_config.proj_prenorm,
num_queries=self.config.num_queries_vis_abstractor,
num_input_tokens=(self.config.vision_config.image_size //
self.config.vision_config.patch_size)**2,
encoder_hidden_size=self.config.vision_config.hidden_size,
hidden_size=self.config.vision_config.hidden_size,
output_hidden_size=self.config.hidden_size,
pos_emb=self.config.proj_pos_emb,
prenorm=self.config.proj_prenorm,
).to(self.dtype)
self.image_newline = nn.Parameter(torch.empty(
self.pretrained_config.hidden_size, ),
self.config.hidden_size, ),
requires_grad=False).to(self.dtype)

self.unpad = self.pretrained_config.unpad
self.use_nth_layer = self.pretrained_config.use_nth_layer
self.anyres = self.pretrained_config.anyres
self.unpad = self.config.unpad
self.use_nth_layer = self.config.use_nth_layer
self.anyres = self.config.anyres
self.possible_resolutions = self._init_possible_resolutions()
self.post_config()

@@ -814,18 +832,18 @@ def post_config(self):

def _init_possible_resolutions(self):
possible_resolutions = []
if self.pretrained_config.anyres:
assert self.pretrained_config.max_num_grids > 0
for i in range(1, self.pretrained_config.max_num_grids + 1):
for j in range(1, self.pretrained_config.max_num_grids + 1):
if i == 1 and j == 1 and not self.pretrained_config.use_1x1_grid:
if self.config.anyres:
assert self.config.max_num_grids > 0
for i in range(1, self.config.max_num_grids + 1):
for j in range(1, self.config.max_num_grids + 1):
if i == 1 and j == 1 and not self.config.use_1x1_grid:
continue
if i * j <= self.pretrained_config.max_num_grids:
if i * j <= self.config.max_num_grids:
possible_resolutions.append([i, j])

possible_resolutions = [[
ys * self.pretrained_config.vision_config.image_size,
xs * self.pretrained_config.vision_config.image_size
ys * self.config.vision_config.image_size,
xs * self.config.vision_config.image_size
] for ys, xs in possible_resolutions]
return possible_resolutions

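Note: a quick worked example of the grid enumeration in `_init_possible_resolutions`, since only the loop is visible in the diff. The values `max_num_grids=3`, `image_size=384`, and `use_1x1_grid=False` are assumptions for illustration, not the model's real config:

```python
# Enumerate (rows, cols) grids the same way _init_possible_resolutions does.
max_num_grids, image_size, use_1x1_grid = 3, 384, False  # assumed values

grids = []
for i in range(1, max_num_grids + 1):
    for j in range(1, max_num_grids + 1):
        if i == 1 and j == 1 and not use_1x1_grid:
            continue  # skip the trivial 1x1 grid unless explicitly enabled
        if i * j <= max_num_grids:
            grids.append([i, j])

# grids == [[1, 2], [1, 3], [2, 1], [3, 1]]
possible_resolutions = [[ys * image_size, xs * image_size] for ys, xs in grids]
# possible_resolutions == [[384, 768], [384, 1152], [768, 384], [1152, 384]]
```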
65 changes: 47 additions & 18 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -5,8 +5,8 @@
import torch
from PIL.Image import Image
from torch import nn
from transformers import (AutoProcessor, Llama4Config, Llama4VisionModel,
LlamaConfig)
from transformers import (AutoProcessor, AutoTokenizer, Llama4Config,
Llama4VisionModel, LlamaConfig, PretrainedConfig)
from transformers.modeling_utils import load_sharded_checkpoint
from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector

@@ -21,7 +21,8 @@
from tensorrt_llm.lora_manager import HfLoraLoader
from tensorrt_llm.models.convert_utils import split_matrix_tp

from ...inputs import (ExtraProcessedInputs, InputProcessor,
from ...inputs import (BaseMultimodalDummyInputsBuilder,
BaseMultimodalInputProcessor, ExtraProcessedInputs,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
@@ -1042,26 +1043,54 @@ def forward(self, multimodal_params: List[MultimodalParams]):
return [image_features]


class Llama4InputProcessor(InputProcessor):
from transformers import AutoTokenizer, PretrainedConfig


class Llama4InputProcessor(BaseMultimodalInputProcessor,
BaseMultimodalDummyInputsBuilder):

def __init__(self,
model_path,
model_config,
tokenizer,
model_path: str,
config: PretrainedConfig,
tokenizer: AutoTokenizer,
trust_remote_code: bool = True):
self.use_fast = True
self.processor = AutoProcessor.from_pretrained(
super().__init__()
self._config = config
self._dtype = self._config.torch_dtype
self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
model_path)
self._model_path = model_path
self._processor = AutoProcessor.from_pretrained(
model_path,
trust_remote_code=trust_remote_code,
use_fast=self.use_fast)
self.model_config = model_config
self.tokenizer = tokenizer
self.vocab_size = model_config.text_config.vocab_size
self.image_token_index = model_config.image_token_index
use_fast=self.use_fast,
trust_remote_code=trust_remote_code)

self.vocab_size = self.config.text_config.vocab_size
self.image_token_index = self.config.image_token_index
self.fake_image_token = self.processor.fake_image_token
self.image_token = self.processor.img_patch_token
self.image_token_start_index = self.model_config.boi_token_index
self.image_token_end_index = self.model_config.eoi_token_index
self.image_token_start_index = self.config.boi_token_index
self.image_token_end_index = self.config.eoi_token_index

@property
def config(self) -> PretrainedConfig:
return self._config

@property
def tokenizer(self) -> AutoTokenizer:
return self._tokenizer

@property
def model_path(self) -> str:
return self._model_path

@property
def processor(self) -> AutoProcessor:
return self._processor

@property
def dtype(self) -> torch.dtype:
return self._dtype

def attach_multimodal_embeddings(
self, inputs: TextPrompt, multimodal_embedding: Dict[str,
@@ -1121,7 +1150,7 @@ def attach_multimodal_embeddings(
f"Missing required key in multimodal embedding: {e}")

# Validate embedding dimensions
model_hidden_size = self.model_config.text_config.hidden_size
model_hidden_size = self.config.text_config.hidden_size
for i, embedding in enumerate(mm_embeddings):
if embedding.shape[-1] != model_hidden_size:
raise ValueError(
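Note: the tail of this hunk changes the embedding-width check in `attach_multimodal_embeddings` to read `hidden_size` through `self.config` instead of `self.model_config`. A minimal sketch of that validation with toy tensors; the hidden size below is an assumed value, not Llama 4's real one:

```python
import torch

model_hidden_size = 5120  # assumed; the real value comes from config.text_config.hidden_size
mm_embeddings = [torch.zeros(16, model_hidden_size), torch.zeros(8, model_hidden_size)]

for i, embedding in enumerate(mm_embeddings):
    if embedding.shape[-1] != model_hidden_size:
        # Mirrors the check in the diff: reject embeddings whose width does not match the LM.
        raise ValueError(
            f"embedding {i}: got hidden size {embedding.shape[-1]}, expected {model_hidden_size}")
```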