diff --git a/QEfficient/base/common.py b/QEfficient/base/common.py
index 1c393c2fe..fb35fcbb6 100644
--- a/QEfficient/base/common.py
+++ b/QEfficient/base/common.py
@@ -16,10 +16,9 @@
 from typing import Any
 
 from transformers import AutoConfig
-from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING
 from QEfficient.utils import login_and_download_hf_lm
@@ -44,8 +43,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
         architecture = config.architectures[0] if config.architectures else None
 
-        if architecture in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
-            model_class = QEFFAutoModelForCausalLM
+        class_name = MODEL_CLASS_MAPPING.get(architecture)
+        if class_name:
+            module = __import__("QEfficient.transformers.models.modeling_auto")
+            model_class = getattr(module, class_name)
         else:
             raise NotImplementedError(
                 f"Unknown architecture={architecture}, either use specific auto model class for loading the model or raise an issue for support!"
diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py
index 28eaa4d52..68be72fa8 100644
--- a/QEfficient/cloud/infer.py
+++ b/QEfficient/cloud/infer.py
@@ -10,11 +10,86 @@
 import sys
 from typing import List, Optional
 
+import requests
+from PIL import Image
+from transformers import PreTrainedModel, TextStreamer
+from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+
 from QEfficient.base.common import QEFFCommonLoader
-from QEfficient.utils import check_and_assign_cache_dir, load_hf_tokenizer
+from QEfficient.utils import check_and_assign_cache_dir, load_hf_processor, load_hf_tokenizer
 from QEfficient.utils.logging_utils import logger
 
 
+# TODO: Remove after adding support for VLM's compile and execute
+def execute_vlm_model(
+    qeff_model: PreTrainedModel,
+    model_name: str,
+    image_url: str,
+    image_path: str,
+    prompt: Optional[str] = None,  # type: ignore
+    device_group: Optional[List[int]] = None,
+    local_model_dir: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    hf_token: Optional[str] = None,
+    generation_len: Optional[int] = None,
+):
+    """
+    This method generates output by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
+    ``Mandatory`` Args:
+        :qeff_model (PreTrainedModel): QEfficient model object.
+        :model_name (str): Hugging Face Model Card name, Example: ``llava-hf/llava-1.5-7b-hf``
+        :image_url (str): Image URL to be used for inference. ``Defaults to None.``
+        :image_path (str): Image path to be used for inference. ``Defaults to None.``
+    ``Optional`` Args:
+        :prompt (str): Sample prompt for the model text generation. ``Defaults to None.``
+        :device_group (List[int]): Device Ids to be used for compilation. If ``len(device_group) > 1``, multiple Card setup is enabled. ``Defaults to None.``
+        :local_model_dir (str): Path to custom model weights and config files. ``Defaults to None.``
+        :cache_dir (str): Cache dir where downloaded HuggingFace files are stored. ``Defaults to None.``
+        :hf_token (str): HuggingFace login token to access private repos. ``Defaults to None.``
+        :generation_len (int): Number of tokens to be generated. ``Defaults to None.``
+    Returns:
+        :dict: Output from the ``AI_100`` runtime.
+ """ + if not (image_url or image_path): + raise ValueError('Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"') + raw_image = Image.open(requests.get(image_url, stream=True).raw) if image_url else Image.open(image_path) + + processor = load_hf_processor( + pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name), + cache_dir=cache_dir, + hf_token=hf_token, + ) + + # Added for QEff version 1.20 supported VLM models (mllama and llava) + conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": prompt[0]}, + ], + } + ] + + # Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids. + input_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) + + split_inputs = processor( + text=input_text, + images=raw_image, + return_tensors="pt", + add_special_tokens=False, + ) + streamer = TextStreamer(processor.tokenizer) + output = qeff_model.generate( + inputs=split_inputs, + streamer=streamer, + device_ids=device_group, + generation_len=generation_len, + ) + return output + + def main( model_name: str, num_cores: int, @@ -65,6 +140,9 @@ def main( :allow_mxint8_mdp_io (bool): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.`` :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.`` :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.`` + :kwargs: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below: + -allocator_dealloc_delay=1 -> -allocator-dealloc-delay=1 + -qpc_crc=True -> -qpc-crc .. code-block:: bash @@ -72,11 +150,6 @@ def main( """ cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir) - tokenizer = load_hf_tokenizer( - pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name), - cache_dir=cache_dir, - hf_token=hf_token, - ) if "--mxfp6" in sys.argv: if args.mxfp6: @@ -93,6 +166,17 @@ def main( local_model_dir=local_model_dir, ) + image_path = kwargs.pop("image_path", None) + image_url = kwargs.pop("image_url", None) + + config = qeff_model.model.config + architecture = config.architectures[0] if config.architectures else None + + if architecture not in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() and ( + kwargs.pop("img_size", None) or image_path or image_url + ): + logger.warning(f"Skipping image arguments as they are not valid for {architecture}") + ######### # Compile ######### @@ -116,14 +200,34 @@ def main( ######### # Execute ######### - _ = qeff_model.generate( - tokenizer, - prompts=prompt, - device_id=device_group, - prompt=prompt, - prompts_txt_file_path=prompts_txt_file_path, - generation_len=generation_len, - ) + if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values(): + exec_info = execute_vlm_model( + qeff_model=qeff_model, + model_name=model_name, + prompt=prompt, + image_url=image_url, + image_path=image_path, + device_group=device_group, + local_model_dir=local_model_dir, + cache_dir=cache_dir, + hf_token=hf_token, + generation_len=generation_len, + ) + print(exec_info) + else: + tokenizer = load_hf_tokenizer( + pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name), + cache_dir=cache_dir, + hf_token=hf_token, + ) + _ = qeff_model.generate( + tokenizer, + prompts=prompt, + device_id=device_group, + prompt=prompt, + prompts_txt_file_path=prompts_txt_file_path, + 
+            generation_len=generation_len,
+        )
 
 
 if __name__ == "__main__":
@@ -219,23 +323,25 @@ def main(
     parser.add_argument(
         "--enable_qnn",
         "--enable-qnn",
-        action="store_true",
+        nargs="?",
+        const=True,
+        type=str,
         default=False,
         help="Enables QNN. Optionally, a configuration file can be provided with [--enable_qnn CONFIG_FILE].\
             If not provided, the default configuration will be used.\
             Sample Config: QEfficient/compile/qnn_config.json",
     )
-    parser.add_argument(
-        "qnn_config",
-        nargs="?",
-        type=str,
-    )
 
     args, compiler_options = parser.parse_known_args()
+
+    if isinstance(args.enable_qnn, str):
+        args.qnn_config = args.enable_qnn
+        args.enable_qnn = True
+
     compiler_options_dict = {}
     for i in range(0, len(compiler_options)):
         if compiler_options[i].startswith("--"):
-            key = compiler_options[i].lstrip("-")
+            key = compiler_options[i].lstrip("-").replace("-", "_")
             value = (
                 compiler_options[i + 1]
                 if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-")
diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py
index 548d8ef80..c9d16d397 100644
--- a/QEfficient/transformers/modeling_utils.py
+++ b/QEfficient/transformers/modeling_utils.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
+import transformers.models.auto.modeling_auto as mapping
 from transformers import AutoModelForCausalLM
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -284,6 +285,15 @@
 }
 
 
+MODEL_CLASS_MAPPING = {
+    **{architecture: "QEFFAutoModelForCausalLM" for architecture in mapping.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()},
+    **{
+        architecture: "QEFFAutoModelForImageTextToText"
+        for architecture in mapping.MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
+    },
+}
+
+
 def _prepare_cross_attention_mask(
     cross_attention_mask: torch.Tensor,
     num_vision_tokens: int,
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 7faaff590..0182c4ef1 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1257,6 +1257,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
         if kwargs.get("low_cpu_mem_usage", None):
             logger.warning("Updating low_cpu_mem_usage=False")
 
+        if kwargs.pop("continuous_batching", None):
+            raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(model, kv_offload=kv_offload, **kwargs)
diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py
index a7f17e6bc..f6aa3296d 100755
--- a/QEfficient/utils/__init__.py
+++ b/QEfficient/utils/__init__.py
@@ -17,6 +17,7 @@
     get_padding_shape_from_config,
     get_qpc_dir_path,
     hf_download,
+    load_hf_processor,
     load_hf_tokenizer,
     login_and_download_hf_lm,
     onnx_exists,
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index cc1353d20..dd4ddd0cf 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -510,7 +510,7 @@ def create_and_dump_qconfigs(
     # Extract QNN SDK details from YAML file if the environment variable is set
     qnn_sdk_details = None
     qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME)
-    if qnn_sdk_path:
+    if enable_qnn and qnn_sdk_path:
         qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML)
         with open(qnn_sdk_yaml_path, "r") as file:
             qnn_sdk_details = yaml.safe_load(file)
diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile
index f3f2fc2d8..24113f9c8 100644
--- a/scripts/Jenkinsfile
+++ b/scripts/Jenkinsfile
@@ -69,7 +69,7 @@ pipeline {
             }
             stage('CLI Tests') {
                 steps {
-                    timeout(time: 15, unit: 'MINUTES') {
+                    timeout(time: 60, unit: 'MINUTES') {
                         sh '''
                         sudo docker exec ${BUILD_TAG} bash -c "
                         source /qnn_sdk/bin/envsetup.sh &&
diff --git a/tests/cloud/test_infer_vlm.py b/tests/cloud/test_infer_vlm.py
new file mode 100644
index 000000000..d06e09946
--- /dev/null
+++ b/tests/cloud/test_infer_vlm.py
@@ -0,0 +1,41 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import pytest
+
+from QEfficient.cloud.infer import main as infer
+
+
+@pytest.mark.on_qaic
+@pytest.mark.cli
+@pytest.mark.multimodal
+@pytest.mark.usefixtures("clean_up_after_test")
+def test_vlm_cli(setup, mocker):
+    ms = setup
+    # Taking some values from the setup fixture and assigning others based on the model's requirements.
+    # For example, mxint8 is not required for VLM models, so it is set to False.
+    infer(
+        model_name="llava-hf/llava-1.5-7b-hf",
+        num_cores=ms.num_cores,
+        prompt="Describe the image.",
+        prompts_txt_file_path=None,
+        aic_enable_depth_first=ms.aic_enable_depth_first,
+        mos=ms.mos,
+        batch_size=1,
+        full_batch_size=None,
+        prompt_len=1024,
+        ctx_len=2048,
+        generation_len=20,
+        mxfp6=False,
+        mxint8=False,
+        local_model_dir=None,
+        cache_dir=None,
+        hf_token=ms.hf_token,
+        enable_qnn=False,
+        qnn_config=None,
+        image_url="https://i.etsystatic.com/8155076/r/il/0825c2/1594869823/il_fullxfull.1594869823_5x0w.jpg",
+    )
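
Reviewer note (not part of the diff): the snippet below is a minimal sketch of the architecture-based dispatch that QEFFCommonLoader.from_pretrained now performs. The example checkpoint name and the resulting architecture string are illustrative assumptions; MODEL_CLASS_MAPPING and the __import__/getattr lookup mirror the common.py and modeling_utils.py hunks above, which assume the top-level QEfficient package re-exports both QEFF auto classes.

# Sketch only: assumes Hub access for the example checkpoint and that the
# QEfficient package exposes QEFFAutoModelForCausalLM / QEFFAutoModelForImageTextToText
# at its top level, as the dispatch in common.py relies on.
from transformers import AutoConfig

from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example checkpoint
architecture = config.architectures[0] if config.architectures else None
# For llava this is "LlavaForConditionalGeneration", which MODEL_CLASS_MAPPING maps to
# "QEFFAutoModelForImageTextToText"; causal-LM architectures map to "QEFFAutoModelForCausalLM".
class_name = MODEL_CLASS_MAPPING.get(architecture)
if class_name:
    # __import__ on a dotted path returns the top-level package (QEfficient here),
    # so the class name is looked up on that package.
    module = __import__("QEfficient.transformers.models.modeling_auto")
    model_class = getattr(module, class_name)
else:
    raise NotImplementedError(f"Unknown architecture={architecture}")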