diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py
index 6eef81bd3..1789c7475 100644
--- a/QEfficient/transformers/models/mllama/modeling_mllama.py
+++ b/QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -161,8 +161,8 @@ def forward(
         value_states_new = torch.index_put(value_states_old, indices, value_states)
 
         # Select old or new image KV states based on q_len
-        key_states = torch.where(q_len == 1, key_states_old, key_states_new)
-        value_states = torch.where(q_len == 1, value_states_old, value_states_new)
+        key_states = torch.where(torch.tensor(q_len == 1), key_states_old, key_states_new)
+        value_states = torch.where(torch.tensor(q_len == 1), value_states_old, value_states_new)
 
         # Update the image cache
         past_key_value.key_cache[self.layer_idx] = key_states
@@ -924,7 +924,7 @@ def forward(
             return_dict=return_dict,
             cache_position=cache_position,
         )
-
+        outputs["pixel_values"] = pixel_values
         return outputs
 
     def get_dummy_inputs(self, kv_offload: bool = False):
@@ -1092,6 +1092,8 @@ def get_output_names(self, kv_offload: bool = False):
             "logits",
             *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]],
         ]
+        if not kv_offload:
+            lang_output_names.append("pixel_values_RetainedState")
 
         output_names = {}
         if kv_offload:
diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py
index f6aa3296d..55719cea5 100755
--- a/QEfficient/utils/__init__.py
+++ b/QEfficient/utils/__init__.py
@@ -13,8 +13,10 @@
     check_and_assign_cache_dir,
     dump_qconfig,
     get_num_layers_from_config,
+    get_num_layers_vlm,
     get_onnx_dir_name,
     get_padding_shape_from_config,
+    get_padding_shape_vlm,
     get_qpc_dir_path,
     hf_download,
     load_hf_processor,
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index ea09e97d7..05cd63968 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -357,6 +357,52 @@ def get_num_layers_from_config(config):
     return n_layer
 
 
+def get_num_layers_vlm(config):
+    """
+    Gets the number of layers from the model config of a VLM.
+    --------
+
+    :config: AutoConfig from pretrained model.
+
+    Return:
+        Tuple of (number of text layers, number of vision layers).
+    """
+
+    if hasattr(config, "llm_config") and hasattr(config, "vision_config"):  # Intern
+        n_layers_text = config.llm_config.num_hidden_layers
+        n_layers_vision = config.vision_config.num_hidden_layers
+    elif hasattr(config, "text_config") and hasattr(config, "vision_config"):  # Llava, Mllama
+        n_layers_text = config.text_config.num_hidden_layers
+        n_layers_vision = config.vision_config.num_hidden_layers
+
+    return (n_layers_text, n_layers_vision)
+
+
+def get_padding_shape_vlm(config, ctx_len, batch_size=1):
+    """
+    Gets the padding dims for VLM models - number of KV heads and d_head -
+    and returns the padding shape (batch_size, number of KV heads, ctx_len, d_head)
+    required for initialization of past_key_values.
+    --------
+
+    :config: AutoConfig from pretrained model.
+    :batch_size: int. Number of input prompts used to create inputs.
+    :ctx_len: int. Context length the KV cache is padded to.
+ + Return: + List[int, int, int, int] + """ + if hasattr(config, "text_config"): + n_heads = config.text_config.num_key_value_heads + d_head = config.text_config.hidden_size // config.text_config.num_attention_heads + padding_shape = [batch_size, n_heads, ctx_len, d_head] + elif hasattr(config, "llm_config"): + n_heads = config.llm_config.num_key_value_heads + d_head = config.llm_config.hidden_size // config.llm_config.num_attention_heads + padding_shape = [batch_size, n_heads, ctx_len, d_head] + return padding_shape + + def execute_command(process: str, command: str, output_file_path: Optional[str] = None): """ Executes the give command using subprocess. diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index c45cfec41..701b62aca 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -4,11 +4,16 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +from typing import List import numpy as np import torch -from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix +from QEfficient.utils import ( + get_num_layers_from_config, + get_padding_shape_from_config, + padding_check_and_fix, +) class InputHandler: @@ -198,3 +203,249 @@ def update_ort_outputs(self, ort_outputs): outputs["logits"] = ort_outputs["logits"] return outputs + + +class InputHandlerVLM: + def __init__( + self, batch_size, config, image, conversation, processor, prompt, prompt_len, ctx_len, max_gen_len, n_layer + ): + self.ctx_len = ctx_len + self.prompt_len = prompt_len + self.max_gen_len = max_gen_len + self.config = config + self.image = image + self.prompt = prompt + self.batch_size = batch_size + self.n_layer = n_layer + self.processor = processor + self.conversation = conversation + + def prepare_pytorch_inputs(self): + """ + Function responsible for creating Prefill stage tensor inputs for PyTorch model. 
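+        Past key/values are pre-allocated as zero tensors: mllama-style cross-attention layers use the image-token length as the cache length, while all other layers use ctx_len.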
+ + Return: + :Dict: input_ids, position_ids, past_key_values + """ + inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt") + if hasattr(self.config, "text_config"): + txt_cfg = self.config.text_config + else: + txt_cfg = self.config.llm_config + + num_hidden_layers = txt_cfg.num_hidden_layers + num_key_value_heads = txt_cfg.num_key_value_heads + head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads + if hasattr(txt_cfg, "cross_attention_layers"): + cross_attention_layers = txt_cfg.cross_attention_layers + + vis_cfg = self.config.vision_config + num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1 + image_tokens_len = vis_cfg.max_num_tiles * num_patches + + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1 + inputs["past_key_values"] = [] + for i in range(num_hidden_layers): + # Specific to mllama as of now + if hasattr(txt_cfg, "cross_attention_layers") and i in cross_attention_layers: + idx = cross_attention_layers.index(i) + assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" + inputs["past_key_values"].append( + ( + torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim), + torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim), + ) + ) + else: + inputs["past_key_values"].append( + ( + torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim), + torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim), + ) + ) + + return inputs + + def prepare_vlm_ort_inputs(self): + if hasattr(self.config, "text_config"): + txt_cfg = self.config.text_config + else: + txt_cfg = self.config.llm_config + num_hidden_layers = txt_cfg.num_hidden_layers + num_key_value_heads = txt_cfg.num_key_value_heads + head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads + if hasattr(txt_cfg, "cross_attention_layers"): + cross_attention_layers = txt_cfg.cross_attention_layers + vis_cfg = self.config.vision_config + num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1 + image_tokens_len = vis_cfg.max_num_tiles * num_patches + + inputs = self.processor(images=self.image, text=self.prompt, return_tensors="np") + if "attention_mask" in inputs.keys(): + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1 + inputs["past_key_values"] = [] + + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + + for i in range(num_hidden_layers): + if hasattr(txt_cfg, "cross_attention_layers") and i in cross_attention_layers: + idx = cross_attention_layers.index(i) + assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" + inputs["past_key." + str(i)] = np.zeros( + (self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32 + ) + inputs["past_value." + str(i)] = np.zeros( + (self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32 + ) + else: + inputs["past_key." + str(i)] = np.zeros( + (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32 + ) + inputs["past_value." + str(i)] = np.zeros( + (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32 + ) + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + return vision_inputs, lang_inputs + + def update_vlm_ort_outputs(self, ort_outputs): + """ + Function responsible for updating ONNXRT session outputs. 
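+        Collects the per-layer retained key/value states along with the logits and any retained pixel values or image features.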
+ + ``Mandatory`` Args: + :ort_outputs (Dict): Numpy outputs of Onnx model from current iteration + + Return: + updated_outputs (Dict): Updated past_key_values, logits, pixel_values + """ + present_key_values = [] + for i in range(self.n_layer[0]): + if "past_key." + str(i) + "_RetainedState" in ort_outputs: + present_key_values.append(ort_outputs["past_key." + str(i) + "_RetainedState"]) + if "past_value." + str(i) + "_RetainedState" in ort_outputs: + present_key_values.append(ort_outputs["past_value." + str(i) + "_RetainedState"]) + + outputs = {} + outputs["past_key_values"] = present_key_values + outputs["logits"] = ort_outputs["logits"] + outputs["pixel_values_RetainedState"] = ( + ort_outputs["pixel_values_RetainedState"] if "pixel_values_RetainedState" in ort_outputs else None + ) + outputs["image_features_RetainedState"] = ( + ort_outputs["image_features_RetainedState"] if "image_features_RetainedState" in ort_outputs else None + ) + return outputs + + def update_vlm_ort_inputs(self, inputs, ort_outputs): + """ + Function responsible for updating Prefill stage inputs to create inputs for decode stage inputs for ONNX model to be run on ONNXRT. + + ``Mandatory`` Args: + :inputs (Dict): NumPy inputs of Onnx model from previous iteration + :ort_outputs (Dict): Numpy outputs of Onnx model from previous iteration + + Return: + :Dict: Updated input_ids, position_ids, pixel_values and past_key_values + """ + updated_inputs = {} + updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1) + updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1 + for i in range(self.n_layer[0]): + updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2] + updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1] + if "pixel_values_RetainedState" in ort_outputs.keys(): + updated_inputs["pixel_values"] = ort_outputs["pixel_values_RetainedState"] + if "image_features_RetainedState" in ort_outputs.keys(): + updated_inputs["image_features"] = ort_outputs["image_features_RetainedState"] + + if "cross_attention_mask" in inputs.keys(): + bs, _, num_images, img_tiles = inputs["cross_attention_mask"].shape + updated_inputs["cross_attention_mask"] = torch.ones( + (bs, 1, num_images, img_tiles), dtype=torch.int64 + ).numpy() + + for k, v in inputs.items(): + if k not in updated_inputs.keys(): + updated_inputs[k] = v + return updated_inputs + + +class InputHandlerInternVL(InputHandlerVLM): + def __init__(self, batch_size, config, image, processor, prompt, prompt_len, ctx_len, max_gen_len, n_layer): + self.ctx_len = ctx_len + self.prompt_len = prompt_len + self.max_gen_len = max_gen_len + self.config = config + self.image = image + self.prompt = prompt + self.batch_size = batch_size + self.n_layer = n_layer + self.processor = processor + + def prepare_pytorch_inputs(self): + question = "\n" + self.prompt + pixel_values = self.processor.load_image(self.image, max_num=12) + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + prompt = self.processor(pixel_values, question, messages, roles) + inputs = self.processor.tokenizer(prompt, return_tensors="pt") + inputs["pixel_values"] = pixel_values.clone() + + if hasattr(self.config, "text_config"): + txt_cfg = self.config.text_config + else: + txt_cfg = self.config.llm_config + + num_hidden_layers = txt_cfg.num_hidden_layers + num_key_value_heads = txt_cfg.num_key_value_heads + head_dim = 
txt_cfg.hidden_size // txt_cfg.num_attention_heads + + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1 + inputs["past_key_values"] = [] + for i in range(num_hidden_layers): + inputs["past_key_values"].append( + ( + torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim), + torch.zeros(1, num_key_value_heads, self.ctx_len, head_dim), + ) + ) + + return inputs + + def prepare_vlm_ort_inputs(self): + if hasattr(self.config, "text_config"): + txt_cfg = self.config.text_config + else: + txt_cfg = self.config.llm_config + num_hidden_layers = txt_cfg.num_hidden_layers + num_key_value_heads = txt_cfg.num_key_value_heads + head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads + + question = "\n" + self.prompt + pixel_values = self.processor.load_image(self.image, max_num=12) + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + prompt = self.processor(pixel_values, question, messages, roles) + inputs = self.processor.tokenizer(prompt, return_tensors="np") + inputs["pixel_values"] = pixel_values.numpy() + + if "attention_mask" in inputs.keys(): + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - 1 + inputs["past_key_values"] = [] + + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + + for i in range(num_hidden_layers): + inputs["past_key." + str(i)] = np.zeros( + (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32 + ) + inputs["past_value." + str(i)] = np.zeros( + (self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32 + ) + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + return vision_inputs, lang_inputs diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index 267b2bb9e..f817f56d0 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -11,9 +11,10 @@ import onnx import onnxruntime import torch +from transformers import TextStreamer from QEfficient.generation.text_generation_inference import TextGeneration -from QEfficient.utils.generate_inputs import InputHandler +from QEfficient.utils.generate_inputs import InputHandler, InputHandlerInternVL, InputHandlerVLM # TODO: Deprecate this class and encourage the use of `QeffAutoModel...` classes @@ -243,3 +244,201 @@ def run_kv_model_on_cloud_ai_100(self, qpc_path, device_group=None): print("Prompt:", repr(self.input_handler.prompt)) print("Completion:", repr(predicted_string)) return execinfo.generated_ids + + +class ApiRunnerVlm: + """ + ApiRunnerVlm class is responsible for running Vision models: + --------- + + 1. HuggingFace ``PyTorch`` model + 2. Transformed KV Pytorch Model + 3. ``ONNX`` model on ONNXRT + 4. 
``ONNX`` model on Cloud AI 100 + """ + + def __init__( + self, batch_size, processor, config, image, conversation, prompt, prompt_len, ctx_len, max_gen_len, n_layer + ): + """ """ + self.input_handler_vlm = InputHandlerVLM( + batch_size=batch_size, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=max_gen_len, + config=config, + image=image, + conversation=conversation, + processor=processor, + n_layer=n_layer, + prompt=prompt, + ) + self.processor = processor + self.ctx_len = ctx_len + self.prompt_len = prompt_len + self.batch_size = batch_size + self.config = config + self.gen_len = max_gen_len + + @torch.no_grad() + def run_vlm_hf_model_on_pytorch(self, model, inputs): + output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False) + offset_output = output[0, inputs["input_ids"].shape[1] :] + py_output = self.processor.tokenizer.decode(offset_output).strip() + print("Original HF Model Outputs (Torch CPU):") + print("Completion:", repr(py_output)) + return offset_output + + @torch.no_grad() + def run_vlm_kv_model_on_pytorch(self, model): + generation_len = self.gen_len + generated_ids = torch.full((self.batch_size, generation_len), self.processor.tokenizer.pad_token_id) + inputs = self.input_handler_vlm.prepare_pytorch_inputs() + + outputs = model(**inputs) + inputs["input_ids"] = outputs[0].argmax(2) + if "cross_attention_mask" in inputs: + bs, _, num_images, img_tiles = inputs["cross_attention_mask"].shape + inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64) + + generated_ids[:, 0] = inputs["input_ids"].squeeze(1) + finished_sequences = inputs["input_ids"] == self.processor.tokenizer.eos_token_id + inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 + + print("QEFF Model Outputs (Torch CPU):") + streamer = TextStreamer(self.processor.tokenizer) + streamer.put(inputs["input_ids"]) + for num_token in range(1, self.gen_len): + outputs = model(**inputs) + inputs["input_ids"] = outputs[0].argmax(2) + inputs["position_ids"] += 1 + streamer.put(inputs["input_ids"]) + generated_ids[:, num_token] = inputs["input_ids"].squeeze(1) + finished_sequences |= inputs["input_ids"] == self.processor.tokenizer.eos_token_id + if finished_sequences.all(): + break + streamer.end() + return generated_ids[0] + + def run_ort_session(self, inputs, session) -> dict: + output_names = [x.name for x in session.get_outputs()] + session_input_names = [x.name for x in session.get_inputs()] + session_inputs = {} + for inp_name in session_input_names: + if inp_name in inputs.keys(): + session_inputs[inp_name] = inputs[inp_name] + outputs_data = session.run(output_names, session_inputs) + ort_outputs = dict(zip(output_names, outputs_data)) + return ort_outputs + + def setup_ort_session(self, model_path): + m = onnx.load(model_path, load_external_data=False) + # NOTE: OrtValue objects should be kept around until the session is run, hence this dict is required + added_initializers = {} + for node in m.graph.node: + if node.op_type == "Constant": + np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t, os.path.dirname(model_path)) + if len(np_tensor.shape) == 0 and np_tensor.item() == 2147483647: + added_initializers[node.output[0]] = onnxruntime.OrtValue.ortvalue_from_numpy( + np.array(0, np_tensor.dtype) + ) + session_options = onnxruntime.SessionOptions() + for name, value in added_initializers.items(): + session_options.add_initializer(name, value) + session = onnxruntime.InferenceSession(model_path, session_options) + 
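+        # Return the initializers together with the session so the caller can keep the OrtValues alive until inference has run.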
+ return added_initializers, session + + def run_vlm_kv_model_on_ort(self, model_path): + vision_inputs, lang_inputs = self.input_handler_vlm.prepare_vlm_ort_inputs() + # TODO: Make a DAG based parser to compile and run N ONNX files with dependencies + ### If kv_offload was `True` + if isinstance(model_path, list): + encoder_path = model_path[0] + decoder_path = model_path[1] + + added_initializers, encoder_session = self.setup_ort_session(encoder_path) + + encoder_ort_outputs = self.run_ort_session(vision_inputs, session=encoder_session) + lang_inputs.update(encoder_ort_outputs) + del added_initializers + ### TEXT COMPONENT RUNNING + + added_initializers, decoder_session = self.setup_ort_session(decoder_path) + generated_ids = [] + + ort_outputs = self.run_ort_session(lang_inputs, session=decoder_session) + ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs) + for _ in range(1, self.gen_len): + generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1)) + lang_inputs = self.input_handler_vlm.update_vlm_ort_inputs(lang_inputs, ort_outputs) + ort_outputs = self.run_ort_session(lang_inputs, decoder_session) + ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs) + generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1)) + generated_ids = np.concatenate(generated_ids, axis=1) + predicted_string = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + print("ORT KV_OFFLOAD Session Outputs:") + print("Completion:", repr(predicted_string)) + del added_initializers + + ### IF MODELPATH IS A SINGLE POSIXPATH + else: + added_initializers, session = self.setup_ort_session(model_path) + generated_ids = [] + inputs = {**vision_inputs, **lang_inputs} + ort_outputs = self.run_ort_session(inputs, session=session) + ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs) + for _ in range(1, self.gen_len): + generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1)) + inputs = self.input_handler_vlm.update_vlm_ort_inputs(inputs, ort_outputs) + ort_outputs = self.run_ort_session(inputs, session) + ort_outputs = self.input_handler_vlm.update_vlm_ort_outputs(ort_outputs) + generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1)) + generated_ids = np.concatenate(generated_ids, axis=1) + predicted_string = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + print("ORT Session Outputs:") + print("Completion:", repr(predicted_string)) + del added_initializers + return generated_ids + + +class ApiRunnerInternVL(ApiRunnerVlm): + """ + ApiRunner for InternVL Vision models: + --------- + + 1. HuggingFace ``PyTorch`` model + 2. Transformed KV Pytorch Model + 3. ``ONNX`` model on ONNXRT + 4. 
``ONNX`` model on Cloud AI 100 + """ + + def __init__(self, batch_size, processor, config, image, prompt, prompt_len, ctx_len, max_gen_len, n_layer): + """ """ + self.input_handler_vlm = InputHandlerInternVL( + batch_size=batch_size, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=max_gen_len, + config=config, + image=image, + processor=processor, + n_layer=n_layer, + prompt=prompt, + ) + self.processor = processor + self.ctx_len = ctx_len + self.prompt_len = prompt_len + self.batch_size = batch_size + self.config = config + self.gen_len = max_gen_len + + @torch.no_grad() + def run_vlm_hf_model_on_pytorch(self, model, inputs, generation_config): + outputs = model.generate(**inputs, **generation_config) + generated_ids = outputs[0].detach().numpy() + + py_output = self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + print("Original HF Model Outputs (Torch CPU):") + print("Completion:", repr(py_output)) + return generated_ids diff --git a/QEfficient/utils/test_utils.py b/QEfficient/utils/test_utils.py new file mode 100644 index 000000000..1b5d81c04 --- /dev/null +++ b/QEfficient/utils/test_utils.py @@ -0,0 +1,152 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +import torchvision.transforms as T +from torchvision.transforms.functional import InterpolationMode + + +# Processor class for InternVL models +class InternProcessor: + """ + InternVL model only has an AutoTokenizer so this class performs the processing tasks similar to an AutoProcessor. + The methods used here are borrowed from the original InternVL modelling files. 
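+    It wraps image tiling and normalization (load_image) as well as chat-prompt construction (get_prompt / __call__).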
+ "https://huggingface.co/OpenGVLab/InternVL2_5-1B/" + """ + + def __init__(self, model: nn.Module, tokenizer): + self.model = model + image_size = self.model.config.force_image_size or self.model.config.vision_config.image_size + patch_size = self.model.config.vision_config.patch_size + self.template = model.config.template + self.num_image_token = int((image_size // patch_size) ** 2 * (self.model.config.downsample_ratio**2)) + self.tokenizer = tokenizer + self.IMAGENET_MEAN = (0.485, 0.456, 0.406) + self.IMAGENET_STD = (0.229, 0.224, 0.225) + + def build_transform(self, input_size): + MEAN, STD = self.IMAGENET_MEAN, self.IMAGENET_STD + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) + return transform + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + # Process the input messages to generate prompt for the model. 
+ def get_prompt(self, messages) -> str: + """Get the prompt for generation.""" + ## Chat template used for InternVL + system_prompt = "<|im_start|>system\n你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。" + sep = "<|im_end|>\n" + + ret = system_prompt + sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + sep + else: + ret += role + return ret + + def load_image(self, image, input_size=448, max_num=12): + transform = self.build_transform(input_size=input_size) + images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + def __call__( + self, + pixel_values, + question, + messages, + roles, + history=None, + num_patches_list=None, + IMG_START_TOKEN="", + IMG_END_TOKEN="", + IMG_CONTEXT_TOKEN="", + verbose=False, + ) -> str: + if history is None and pixel_values is not None and "" not in question: + question = "\n" + question + if num_patches_list is None: + num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] + assert pixel_values is None or len(pixel_values) == sum(num_patches_list) + img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.model.img_context_token_id = img_context_token_id + + messages.append([roles[0], question]) + messages.append([roles[1], None]) + query = self.get_prompt(messages) + if verbose and pixel_values is not None: + image_bs = pixel_values.shape[0] + print(f"dynamic ViT batch size: {image_bs}") + for num_patches in num_patches_list: + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN + query = query.replace("", image_tokens, 1) + return query diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 26278c359..fcd2fece5 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -24,6 +24,7 @@ pipeline { pip install .[test] && pip install junitparser pytest-xdist && pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing + pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.19.1+cpu einops==0.8.1 && #packages to load VLMs rm -rf QEfficient" ''' } @@ -67,6 +68,23 @@ pipeline { } } } + stage('QAIC MultiModal Tests') { + steps { + timeout(time: 60, unit: 'MINUTES') { + sh ''' + sudo docker exec ${BUILD_TAG} bash -c " + cd /efficient-transformers && + . preflight_qeff/bin/activate && + mkdir -p $PWD/Non_cli_qaic_multimodal && + export TOKENIZERS_PARALLELISM=false && + export QEFF_HOME=$PWD/Non_cli_qaic_multimodal && + pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log6.xml && + junitparser merge tests/tests_log6.xml tests/tests_log.xml && + deactivate" + ''' + } + } + } stage('CLI Tests') { steps { timeout(time: 60, unit: 'MINUTES') { diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py new file mode 100644 index 000000000..199ac0160 --- /dev/null +++ b/tests/transformers/models/test_image_text_to_text_models.py @@ -0,0 +1,361 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from io import BytesIO +from typing import List + +import pytest +import requests +from PIL import Image +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoProcessor, + AutoTokenizer, + TextStreamer, +) + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText +from QEfficient.utils import hf_download +from QEfficient.utils._utils import get_num_layers_vlm +from QEfficient.utils.device_utils import get_available_device_id +from QEfficient.utils.run_utils import ApiRunnerInternVL, ApiRunnerVlm +from QEfficient.utils.test_utils import InternProcessor + +HF_TOKEN = "" +NEW_GENERATION_TOKENS = 10 +test_models_config = [ + # CONFIG PARAMS NEEDED FOR A MODEL TO BE TESTED + # ( + # model_name, + # kv_offload, + # batch_size, + # prompt_len, + # ctx_len, + # img_size, + # img_url", + # text_prompt, + # number of layers of the model, + # ), + ( + "llava-hf/llava-1.5-7b-hf", + True, + 1, + 784, + 1024, + 336, + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg", + "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud", + 1, + ), + ( + "llava-hf/llava-1.5-7b-hf", + False, + 1, + 784, + 1024, + 336, + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg", + "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud", + 1, + ), + # ( + # "meta-llama/Llama-3.2-11B-Vision-Instruct", + # True, + # 1, + # 32, + # 512, + # 560, + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + # "Explain this image", + # 7, + # ), +] + +intern_model_config = [ + ( + "OpenGVLab/InternVL2_5-1B", + True, + 1, + 384, + 512, + "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg", + "Please describe the image in detail.", + 2, + ), + ( + "OpenGVLab/InternVL2_5-1B", + False, + 1, + 384, + 512, + "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg", + "Please describe the image in detail.", + 2, + ), +] + + +def load_image_text_to_text_model(model_config): + model_path = hf_download( + repo_id=model_config._name_or_path, + hf_token=HF_TOKEN, + ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], + ) + try: + model_hf = AutoModelForImageTextToText.from_pretrained( + model_path, + low_cpu_mem_usage=False, + token=HF_TOKEN, + config=model_config, + ) + except ValueError: + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=False, + token=HF_TOKEN, + trust_remote_code=True, + config=model_config, + ) + params = sum(p.numel() for p in model_hf.parameters()) + model_hf.eval() + return model_hf, params + + +def set_num_layers(config, n_layer=1): + ## -1 indicates use all the layers of the model. 
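+    ## For mllama-style configs, the cross-attention layer indices are also filtered so they stay within the truncated depth.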
+ if n_layer == -1: + return config + elif hasattr(config, "model_type") and "mllama" in config.model_type: + config.text_config.num_hidden_layers = n_layer + config.text_config.cross_attention_layers = [ + x for x in config.text_config.cross_attention_layers if x < n_layer + ] + elif hasattr(config, "text_config"): + config.text_config.num_hidden_layers = n_layer + config.vision_config.num_hidden_layers = n_layer + elif hasattr(config, "llm_config"): + config.llm_config.num_hidden_layers = n_layer + config.vision_config.num_hidden_layers = n_layer + return config + + +def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name: str, + img_size: int, + img_url: str, + query: str, + prompt_len: int, + ctx_len: int, + max_gen_len: int = 20, + batch_size: int = 1, + n_layer: int = 1, + kv_offload: bool = False, + num_devices: int = 1, +): + model_config = {"model_name": model_name} + model_config["img_size"] = img_size + config = AutoConfig.from_pretrained( + model_config["model_name"], token=HF_TOKEN, trust_remote_code=True, padding=True + ) + config = set_num_layers(config, n_layer=n_layer) + model_hf, _ = load_image_text_to_text_model(config) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, padding=True) + + n_layer = get_num_layers_vlm(config) + image = Image.open(requests.get(img_url, stream=True).raw) + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + }, + ] + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + api_runner = ApiRunnerVlm( + batch_size, + processor, + config, + image, + conversation, + prompt, + prompt_len, + ctx_len, + max_gen_len, + n_layer, + ) + + inputs = processor(images=image, text=prompt, return_tensors="pt") + streamer = TextStreamer(processor.tokenizer) + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs) + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_config["model_name"], + kv_offload=kv_offload, + config=config, + token=HF_TOKEN, + ) + + # pytorch_kv_tokens = api_runner.run_vlm_kv_model_on_pytorch(qeff_model.model) + # assert (pytorch_kv_tokens == pytorch_hf_tokens).all(), ( + # "Tokens don't match for pytorch HF output and pytorch KV output" + # ) + + qeff_model.export() + # onnx_model_path = qeff_model.export() + # ort_tokens = api_runner.run_vlm_kv_model_on_ort(onnx_model_path) + # assert (pytorch_hf_tokens == ort_tokens).all(), "Tokens don't match for pytorch HF output and ORT output" + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + qeff_model.compile( + img_size=model_config["img_size"], + num_devices=num_devices, + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + mxfp6=False, + ) + inputs = processor(images=image, text=prompt, return_tensors="pt") + print("QPC Outputs (QAIC):") + output = qeff_model.generate(inputs=inputs, generation_len=NEW_GENERATION_TOKENS, streamer=streamer) + qpc_tokens = output.generated_ids[:, :-1] + assert (pytorch_hf_tokens == qpc_tokens).all(), "Tokens don't match for pytorch HF output and QPC output" + return + + +def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name: str, + img_url: str, + query: str, + prompt_len: int, + ctx_len: int, + max_gen_len: int = 20, + batch_size: int = 1, + n_layer: int = 1, + kv_offload: bool = False, + num_devices: int = 1, +): + model_config = {"model_name": model_name} + + config = 
AutoConfig.from_pretrained(model_config["model_name"], trust_remote_code=True) + config._attn_implementation = "eager" + config = set_num_layers(config, n_layer=n_layer) + model_hf, _ = load_image_text_to_text_model(config) + n_layer = get_num_layers_vlm(config) + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + processor = InternProcessor(model_hf, tokenizer) + img = requests.get(img_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + image = image.resize((448, 448)) + + api_runner = ApiRunnerInternVL( + batch_size, + processor, + config, + image, + query, + prompt_len, + ctx_len, + max_gen_len, + n_layer, + ) + pixel_values = processor.load_image(image, max_num=12) + question = "\n" + query + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + prompt = processor(pixel_values, question, messages, roles) + + inputs = tokenizer(prompt, return_tensors="pt") + batch_size, prompt_len = inputs["input_ids"].shape + inputs["pixel_values"] = pixel_values.clone() + + generation_config = dict(max_new_tokens=max_gen_len, do_sample=False) + generation_config["eos_token_id"] = tokenizer.convert_tokens_to_ids("<|im_end|>\n".strip()) + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs, generation_config) + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_config["model_name"], + kv_offload=kv_offload, + config=config, + token=HF_TOKEN, + ) + # pytorch_kv_tokens = api_runner.run_vlm_kv_model_on_pytorch(qeff_model.model) + # assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( + # "Tokens don't match for pytorch HF output and QEFF KV Model output" + # ) + + streamer = TextStreamer(processor.tokenizer) + qeff_model.export() + + # onnx_model_path = qeff_model.export() + # ort_tokens = api_runner.run_vlm_kv_model_on_ort(onnx_model_path) + # assert (pytorch_hf_tokens == ort_tokens).all(), "Tokens don't match for pytorch HF output and ORT output" + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + qeff_model.compile( + num_patches=1, + num_devices=num_devices, + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + mxfp6=False, + ) + print("QPC Outputs (QAIC):") + output = qeff_model.generate(inputs=inputs, generation_len=NEW_GENERATION_TOKENS, streamer=streamer) + qpc_tokens = output.generated_ids[:, :-1] + assert (pytorch_hf_tokens == qpc_tokens).all(), "Tokens don't match for pytorch HF output and QPC output" + return + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config +) +def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer +): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching. 
+ ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + img_size=img_size, + img_url=img_url, + query=query, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + ) + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer", intern_model_config +) +def test_image_text_to_text_intern_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer +): + check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + img_url=img_url, + query=query, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + )