VLM e2e test pipeline #337

Merged
7 commits merged on Apr 30, 2025
8 changes: 5 additions & 3 deletions QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -161,8 +161,8 @@ def forward(
value_states_new = torch.index_put(value_states_old, indices, value_states)

# Select old or new image KV states based on q_len
key_states = torch.where(q_len == 1, key_states_old, key_states_new)
value_states = torch.where(q_len == 1, value_states_old, value_states_new)
key_states = torch.where(torch.tensor(q_len == 1), key_states_old, key_states_new)
value_states = torch.where(torch.tensor(q_len == 1), value_states_old, value_states_new)

# Update the image cache
past_key_value.key_cache[self.layer_idx] = key_states
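
The change above wraps the Python boolean q_len == 1 in torch.tensor(...) so that torch.where receives a Tensor condition; presumably this keeps the old-vs-new image KV selection as an operation the tracer/ONNX export can capture, rather than a value decided eagerly. A minimal sketch of the pattern, with dummy shapes that are not taken from the model:

import torch

# Dummy stand-ins for the cached ("old") and freshly computed ("new") image KV states.
key_states_old = torch.zeros(1, 8, 16, 64)
key_states_new = torch.ones(1, 8, 16, 64)
q_len = 1  # plain Python int at this point in the forward pass

# Wrapping the comparison in torch.tensor gives torch.where a 0-d bool Tensor,
# which broadcasts against both operands when selecting between them.
key_states = torch.where(torch.tensor(q_len == 1), key_states_old, key_states_new)
print(key_states.shape)  # torch.Size([1, 8, 16, 64])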
@@ -924,7 +924,7 @@ def forward(
return_dict=return_dict,
cache_position=cache_position,
)

outputs["pixel_values"] = pixel_values
return outputs

def get_dummy_inputs(self, kv_offload: bool = False):
@@ -1092,6 +1092,8 @@ def get_output_names(self, kv_offload: bool = False):
"logits",
*[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]],
]
if not kv_offload:
lang_output_names.append("pixel_values_RetainedState")

output_names = {}
if kv_offload:
2 changes: 2 additions & 0 deletions QEfficient/utils/__init__.py
@@ -13,8 +13,10 @@
check_and_assign_cache_dir,
dump_qconfig,
get_num_layers_from_config,
get_num_layers_vlm,
get_onnx_dir_name,
get_padding_shape_from_config,
get_padding_shape_vlm,
get_qpc_dir_path,
hf_download,
load_hf_processor,
46 changes: 46 additions & 0 deletions QEfficient/utils/_utils.py
@@ -357,6 +357,52 @@ def get_num_layers_from_config(config):
return n_layer


def get_num_layers_vlm(config):
"""
Gets the number of layers from the model config of a VLM
--------

:config: AutoConfig from pretrained model.

Return:
number of layers of the text and vision parts
"""

if hasattr(config, "llm_config") and hasattr(config, "vision_config"): # Intern
n_layers_text = config.llm_config.num_hidden_layers
n_layers_vision = config.vision_config.num_hidden_layers
elif hasattr(config, "text_config") and hasattr(config, "vision_config"): # Llava, Mllama
n_layers_text = config.text_config.num_hidden_layers
n_layers_vision = config.vision_config.num_hidden_layers

return (n_layers_text, n_layers_vision)
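
A quick usage sketch of the helper above, using a hypothetical stand-in for an AutoConfig (the layer counts are made up):

from types import SimpleNamespace

from QEfficient.utils import get_num_layers_vlm

# Hypothetical VLM config exposing text_config / vision_config, as Llava or Mllama do.
cfg = SimpleNamespace(
    text_config=SimpleNamespace(num_hidden_layers=32),
    vision_config=SimpleNamespace(num_hidden_layers=24),
)
print(get_num_layers_vlm(cfg))  # (32, 24)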


def get_padding_shape_vlm(config, ctx_len, batch_size=1):
"""
Gets padding dims for VLM models - number of kv heads and d_head -
and returns the padding shape (batch_size, number of kv heads, ctx_len, d_head)
required for initialization of past_key_values
--------

:config: AutoConfig from pretrained model.
:batch_size: int. number of input prompts used to create inputs
:ctx_len: int. context length to run the model for.

Return:
List[int, int, int, int]
"""
if hasattr(config, "text_config"):
n_heads = config.text_config.num_key_value_heads
d_head = config.text_config.hidden_size // config.text_config.num_attention_heads
padding_shape = [batch_size, n_heads, ctx_len, d_head]
elif hasattr(config, "llm_config"):
n_heads = config.llm_config.num_key_value_heads
d_head = config.llm_config.hidden_size // config.llm_config.num_attention_heads
padding_shape = [batch_size, n_heads, ctx_len, d_head]
return padding_shape
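
A worked example of the shape computation, with invented config values (8 KV heads, hidden_size 4096, 32 attention heads) and ctx_len 1024:

from types import SimpleNamespace

from QEfficient.utils import get_padding_shape_vlm

# Hypothetical text_config; d_head = 4096 // 32 = 128.
cfg = SimpleNamespace(
    text_config=SimpleNamespace(num_key_value_heads=8, hidden_size=4096, num_attention_heads=32)
)
print(get_padding_shape_vlm(cfg, ctx_len=1024))  # [1, 8, 1024, 128]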


def execute_command(process: str, command: str, output_file_path: Optional[str] = None):
"""
Executes the given command using subprocess.
253 changes: 252 additions & 1 deletion QEfficient/utils/generate_inputs.py
@@ -4,11 +4,16 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
from typing import List

import numpy as np
import torch

from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix
from QEfficient.utils import (
get_num_layers_from_config,
get_padding_shape_from_config,
padding_check_and_fix,
)


class InputHandler:
@@ -198,3 +203,249 @@ def update_ort_outputs(self, ort_outputs):
outputs["logits"] = ort_outputs["logits"]

return outputs


class InputHandlerVLM:
def __init__(
self, batch_size, config, image, conversation, processor, prompt, prompt_len, ctx_len, max_gen_len, n_layer
):
self.ctx_len = ctx_len
self.prompt_len = prompt_len
self.max_gen_len = max_gen_len
self.config = config
self.image = image
self.prompt = prompt
self.batch_size = batch_size
self.n_layer = n_layer
self.processor = processor
self.conversation = conversation

def prepare_pytorch_inputs(self):
"""
Function responsible for creating Prefill stage tensor inputs for PyTorch model.

Return:
:Dict: input_ids, position_ids, past_key_values
"""
inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt")
if hasattr(self.config, "text_config"):
txt_cfg = self.config.text_config
else:
txt_cfg = self.config.llm_config

num_hidden_layers = txt_cfg.num_hidden_layers
num_key_value_heads = txt_cfg.num_key_value_heads
head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
if hasattr(txt_cfg, "cross_attention_layers"):
cross_attention_layers = txt_cfg.cross_attention_layers

vis_cfg = self.config.vision_config
num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1
image_tokens_len = vis_cfg.max_num_tiles * num_patches

inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
inputs["past_key_values"] = []
for i in range(num_hidden_layers):
# Specific to mllama as of now
if hasattr(txt_cfg, "cross_attention_layers") and i in cross_attention_layers:
idx = cross_attention_layers.index(i)
assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}"
inputs["past_key_values"].append(
(
torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim),
torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim),
)
)
else:
inputs["past_key_values"].append(
(
torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
)
)

return inputs
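
Two details in prepare_pytorch_inputs are worth spelling out: position_ids are derived from the attention mask via cumsum(1) - 1, and for mllama-style cross-attention layers the KV length is max_num_tiles * num_patches. A small sketch with illustrative numbers, not taken from any particular checkpoint:

import torch

# position_ids from an all-ones attention mask: cumulative sum minus one.
attention_mask = torch.tensor([[1, 1, 1, 1, 1]])
position_ids = attention_mask.cumsum(1) - 1  # tensor([[0, 1, 2, 3, 4]])

# Cross-attention KV length for an assumed vision config (values are illustrative).
image_size, patch_size, max_num_tiles = 448, 14, 4
num_patches = (image_size // patch_size) ** 2 + 1  # 1024 patches + 1 CLS token = 1025
image_tokens_len = max_num_tiles * num_patches     # 4 * 1025 = 4100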

def prepare_vlm_ort_inputs(self):
if hasattr(self.config, "text_config"):
txt_cfg = self.config.text_config
else:
txt_cfg = self.config.llm_config
num_hidden_layers = txt_cfg.num_hidden_layers
num_key_value_heads = txt_cfg.num_key_value_heads
head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads
if hasattr(txt_cfg, "cross_attention_layers"):
cross_attention_layers = txt_cfg.cross_attention_layers
vis_cfg = self.config.vision_config
num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1
image_tokens_len = vis_cfg.max_num_tiles * num_patches

inputs = self.processor(images=self.image, text=self.prompt, return_tensors="np")
if "attention_mask" in inputs.keys():
inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
inputs["past_key_values"] = []

vision_inputs = {
k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
}

for i in range(num_hidden_layers):
if hasattr(txt_cfg, "cross_attention_layers") and i in cross_attention_layers:
idx = cross_attention_layers.index(i)
assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}"
inputs["past_key." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32
)
inputs["past_value." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32
)
else:
inputs["past_key." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
)
inputs["past_value." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
)
lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
return vision_inputs, lang_inputs

def update_vlm_ort_outputs(self, ort_outputs):
"""
Function responsible for updating ONNXRT session outputs.

``Mandatory`` Args:
:ort_outputs (Dict): Numpy outputs of Onnx model from current iteration

Return:
updated_outputs (Dict): Updated past_key_values, logits, pixel_values
"""
present_key_values = []
for i in range(self.n_layer[0]):
if "past_key." + str(i) + "_RetainedState" in ort_outputs:
present_key_values.append(ort_outputs["past_key." + str(i) + "_RetainedState"])
if "past_value." + str(i) + "_RetainedState" in ort_outputs:
present_key_values.append(ort_outputs["past_value." + str(i) + "_RetainedState"])

outputs = {}
outputs["past_key_values"] = present_key_values
outputs["logits"] = ort_outputs["logits"]
outputs["pixel_values_RetainedState"] = (
ort_outputs["pixel_values_RetainedState"] if "pixel_values_RetainedState" in ort_outputs else None
)
outputs["image_features_RetainedState"] = (
ort_outputs["image_features_RetainedState"] if "image_features_RetainedState" in ort_outputs else None
)
return outputs

def update_vlm_ort_inputs(self, inputs, ort_outputs):
"""
Function responsible for updating Prefill stage inputs into decode stage inputs for the ONNX model to be run on ONNXRT.

``Mandatory`` Args:
:inputs (Dict): NumPy inputs of Onnx model from previous iteration
:ort_outputs (Dict): Numpy outputs of Onnx model from previous iteration

Return:
:Dict: Updated input_ids, position_ids, pixel_values and past_key_values
"""
updated_inputs = {}
updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1)
updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1
for i in range(self.n_layer[0]):
updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2]
updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1]
if "pixel_values_RetainedState" in ort_outputs.keys():
updated_inputs["pixel_values"] = ort_outputs["pixel_values_RetainedState"]
if "image_features_RetainedState" in ort_outputs.keys():
updated_inputs["image_features"] = ort_outputs["image_features_RetainedState"]

if "cross_attention_mask" in inputs.keys():
bs, _, num_images, img_tiles = inputs["cross_attention_mask"].shape
updated_inputs["cross_attention_mask"] = torch.ones(
(bs, 1, num_images, img_tiles), dtype=torch.int64
).numpy()

for k, v in inputs.items():
if k not in updated_inputs.keys():
updated_inputs[k] = v
return updated_inputs
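
The decode-stage update above reduces to greedy selection from the latest logits plus advancing position_ids by one; a minimal NumPy sketch with dummy values:

import numpy as np

# Dummy logits for batch 1, one position, a 5-token vocabulary (illustrative only).
logits = np.array([[[0.1, 0.2, 3.0, 0.05, 0.4]]])
next_input_ids = logits.argmax(-1)  # array([[2]]) -> id of the next token

position_ids = np.array([[0, 1, 2, 3]])
next_position_ids = np.max(position_ids, axis=1, keepdims=True) + 1  # array([[4]])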


class InputHandlerInternVL(InputHandlerVLM):
def __init__(self, batch_size, config, image, processor, prompt, prompt_len, ctx_len, max_gen_len, n_layer):
self.ctx_len = ctx_len
self.prompt_len = prompt_len
self.max_gen_len = max_gen_len
self.config = config
self.image = image
self.prompt = prompt
self.batch_size = batch_size
self.n_layer = n_layer
self.processor = processor

def prepare_pytorch_inputs(self):
question = "<image>\n" + self.prompt
pixel_values = self.processor.load_image(self.image, max_num=12)
# Chat Template information for prompt preprocessing
messages: List[List[str]] = []
roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
prompt = self.processor(pixel_values, question, messages, roles)
inputs = self.processor.tokenizer(prompt, return_tensors="pt")
inputs["pixel_values"] = pixel_values.clone()

if hasattr(self.config, "text_config"):
txt_cfg = self.config.text_config
else:
txt_cfg = self.config.llm_config

num_hidden_layers = txt_cfg.num_hidden_layers
num_key_value_heads = txt_cfg.num_key_value_heads
head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads

inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
inputs["past_key_values"] = []
for i in range(num_hidden_layers):
inputs["past_key_values"].append(
(
torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim),
)
)

return inputs

def prepare_vlm_ort_inputs(self):
if hasattr(self.config, "text_config"):
txt_cfg = self.config.text_config
else:
txt_cfg = self.config.llm_config
num_hidden_layers = txt_cfg.num_hidden_layers
num_key_value_heads = txt_cfg.num_key_value_heads
head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads

question = "<image>\n" + self.prompt
pixel_values = self.processor.load_image(self.image, max_num=12)
# Chat Template information for prompt preprocessing
messages: List[List[str]] = []
roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
prompt = self.processor(pixel_values, question, messages, roles)
inputs = self.processor.tokenizer(prompt, return_tensors="np")
inputs["pixel_values"] = pixel_values.numpy()

if "attention_mask" in inputs.keys():
inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1
inputs["past_key_values"] = []

vision_inputs = {
k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"}
}

for i in range(num_hidden_layers):
inputs["past_key." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
)
inputs["past_value." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
)
lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
return vision_inputs, lang_inputs