148 changes: 145 additions & 3 deletions QEfficient/generation/embedding_handler.py
@@ -12,13 +12,14 @@
operations, separating them from the main text generation logic.
"""

from typing import Any, Dict, Optional, Tuple
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers import AutoImageProcessor, AutoTokenizer

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils.logging_utils import logger
@@ -37,6 +38,9 @@ def __init__(
qeff_model: Optional[QAICInferenceSession],
vision_session: Optional[QAICInferenceSession],
processor: Optional[AutoImageProcessor],
tokenizer: Optional[AutoTokenizer],
image_height: Optional[int] = None,
image_width: Optional[int] = None,
config: Optional[Dict[str, Any]] = None,
lang_session: Optional[QAICInferenceSession] = None,
):
@@ -46,12 +50,16 @@
Args:
vision_session: QAICInferenceSession for vision model
processor: AutoImageProcessor for image preprocessing
            tokenizer: AutoTokenizer for text tokenization
            image_height: Optional target image height (in pixels) used when resizing input images
            image_width: Optional target image width (in pixels) used when resizing input images
config: Configuration dictionary with vision model parameters
lang_session: Optional language session for coordination (to avoid resource conflicts)
"""
self._qeff_model = qeff_model
self._vision_session = vision_session
self._processor = processor
self._tokenizer = tokenizer
self._image_height = image_height
self._image_width = image_width
self._config = config or {}
self._lang_session = lang_session # Store language session for coordination

@@ -70,6 +78,126 @@ def is_available(self) -> bool:
"""
return self._vision_session is not None and self._processor is not None

    def prepare_internVL_inputs(
        self, img_url: str, query: str
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        """
        Prepare inputs for the InternVL model.

        Args:
            img_url: URL of the image to download
            query: Text query to process with the image

        Returns:
            Tuple of (vision_inputs, lang_inputs) dictionaries of numpy arrays
        """
if not self._tokenizer:
raise ValueError("Tokenizer is required for InternVL input preparation")
prompt = query
pixel_values = []
num_patches_list = []
questions = []
img = requests.get(img_url, stream=True)
image = Image.open(BytesIO(img.content)).convert("RGB")

        if self._image_height and self._image_width:
            # PIL's Image.resize expects (width, height)
            image = image.resize((self._image_width, self._image_height))
        else:
            logger.warning("Image height and width not specified. Using default size (1000, 747), which gives num_patches = 13.")
            image = image.resize((1000, 747))

# preprocess the resized image
pixel_value = self._processor.load_image(image, max_num=12)
num_patches_list.append(pixel_value.shape[0])
pixel_values.append(pixel_value)

question = "<image>\n" + prompt
questions.append(question)

pixel_values = torch.cat(pixel_values, dim=0)

# Chat Template information for prompt preprocessing
messages: List[List[str]] = []
roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list)

inputs = self._tokenizer(prompt, return_tensors="pt")
inputs["pixel_values"] = pixel_values.clone()

# Convert to numpy arrays
vision_inputs = {}
for k, v in inputs.items():
if k in {
"pixel_values",
"image_masks",
"image_input_idx",
"valid_idx",
"aspect_ratio_ids",
"aspect_ratio_mask",
}:
vision_inputs[k] = np.array(v)

# Convert specific inputs to float16
vision_inputs_fp16 = {"pixel_values", "image_masks"}
for k in vision_inputs_fp16:
if k in vision_inputs:
vision_inputs[k] = vision_inputs[k].astype("float16")

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

return vision_inputs, lang_inputs

    def prepare_molmo_inputs(
        self, image_url: str, query: str
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        """
        Download and preprocess an image into Molmo model inputs.

        Args:
            image_url: URL or path to image
            query: Text query to process with image

        Returns:
            Tuple of (vision_inputs, lang_inputs) dictionaries of numpy arrays

        Raises:
            ValueError: If vision handler is not properly initialized
            RuntimeError: If image processing fails
        """
if not self.is_available():
raise ValueError("Vision handler not properly initialized. Need both vision_session and processor.")

try:
# Download image
if image_url.startswith(("http://", "https://")):
image = Image.open(requests.get(image_url, stream=True).raw)
else:
image = Image.open(image_url)
image = image.resize((536, 354))
inputs = self._processor.process(images=[image], text=query)
inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
valid = inputs["image_input_idx"] > 0
valid = valid.reshape(1, -1)
inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0)
inputs["pixel_values"] = inputs.pop("images")

# Convert to numpy arrays
vision_inputs = {}
for k, v in inputs.items():
if k in {
"pixel_values",
"image_masks",
"image_input_idx",
"valid_idx",
"aspect_ratio_ids",
"aspect_ratio_mask",
}:
vision_inputs[k] = np.array(v)

# Convert specific inputs to float16
vision_inputs_fp16 = {"pixel_values", "image_masks"}
for k in vision_inputs_fp16:
if k in vision_inputs:
vision_inputs[k] = vision_inputs[k].astype("float16")

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

return vision_inputs, lang_inputs
except Exception as e:
            raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") from e

def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]:
"""
Download and preprocess image into model inputs
@@ -95,6 +223,9 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
else:
image = Image.open(image_url)

if "mistral3" in self._qeff_model.model.config.model_type:
image = image.resize((1540, 1540))

# Prepare conversation format
conversation = [
{
@@ -323,7 +454,18 @@ def get_processed_inputs(

try:
## Get vlm inputs ##
vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)
            model_type = getattr(self._qeff_model.model.config, "model_type", None)
            if model_type == "internvl_chat":
                vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query)
            elif model_type == "molmo":
                vision_inputs, lang_inputs = self.prepare_molmo_inputs(image_url, query)
            else:
                vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)

# Handle padding for language model
pad_token_id = 1
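
For review context, a minimal sketch of how the two new preparation paths and the existing default path are selected. The wrapper function, the handler object, and the model_type string below are illustrative assumptions; only the three prepare_* method signatures come from this diff.

from typing import Any, Dict, Tuple

import numpy as np


def route_vision_inputs(
    handler: Any, model_type: str, image_url: str, query: str, prefill_seq_len: int
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    """Mirror of the model_type dispatch added to get_processed_inputs (sketch only)."""
    if model_type == "internvl_chat":
        # InternVL path: processor.load_image + chat-template assembly; tokenizer required
        return handler.prepare_internVL_inputs(image_url, query)
    if model_type == "molmo":
        # Molmo path: processor.process, with valid_idx derived from image_input_idx
        return handler.prepare_molmo_inputs(image_url, query)
    # Default path: HF conversation/chat-template preprocessing with prefill_seq_len
    return handler.prepare_vlm_inputs(image_url, query, prefill_seq_len)

# Each branch returns (vision_inputs, lang_inputs): pixel_values/image_masks are float16
# numpy arrays destined for the vision QPC session; the rest feeds the language session.
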
8 changes: 8 additions & 0 deletions QEfficient/generation/vlm_generation.py
@@ -88,6 +88,8 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
full_batch_size: Optional[int] = None,
image_height: Optional[int] = None,
image_width: Optional[int] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
@@ -143,6 +145,9 @@ def __init__(
)
self.qeff_model = qeff_model
self.processor = processor
self.tokenizer = tokenizer
self.image_height = image_height
self.image_width = image_width
self._vision_qpc_path = vision_qpc_path
self.device_id = device_id # Store device_id for vision components
self.enable_debug_logs = enable_debug_logs # Store for vision components
@@ -173,6 +178,9 @@ def _init_vision_components(self):
qeff_model=self.qeff_model,
vision_session=self._vision_session,
processor=self.processor,
tokenizer=self.tokenizer,
image_height=self.image_height,
image_width=self.image_width,
config=vision_config,
lang_session=self._session, # Pass language session for coordination
)
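
A short wiring sketch for the new image_height/image_width constructor parameters. The class name VisionLanguageGeneration, its import path, and the model_card argument are assumptions for illustration, not part of this diff; other constructor arguments visible above (device_id, full_batch_size, etc.) are omitted.

from typing import Any, Optional

from transformers import AutoImageProcessor, AutoTokenizer

# Assumed import path and class name; adjust to wherever the generation class actually lives.
from QEfficient.generation.vlm_generation import VisionLanguageGeneration


def build_generator(qeff_model: Any, model_card: str, vision_qpc_path: str,
                    image_height: Optional[int] = None, image_width: Optional[int] = None):
    """Assemble the VLM generator with the new image-size knobs (sketch, names assumed)."""
    processor = AutoImageProcessor.from_pretrained(model_card, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_card, trust_remote_code=True)
    return VisionLanguageGeneration(
        qeff_model=qeff_model,
        tokenizer=tokenizer,
        processor=processor,
        vision_qpc_path=vision_qpc_path,
        image_height=image_height,  # forwarded to the vision handler via _init_vision_components
        image_width=image_width,    # used by the InternVL path to resize images before preprocessing
    )
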