Skip to content

Commit 01dffb6

Browse files
authored
Docs string added for the Image class and granite models are added in validation page (#303)
Signed-off-by: Abukhoyer Shaik <[email protected]>
1 parent 7e6c3f0 commit 01dffb6

File tree

3 files changed

+70
-8
lines changed

3 files changed

+70
-8
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1147,9 +1147,69 @@ def get_model_config(self) -> dict:
11471147

11481148
class QEFFAutoModelForImageTextToText:
11491149
"""
1150-
A factory class for creating QEFFAutoModelForImageTextToText instances with for single and Dual QPC approach
1150+
The QEFFAutoModelForImageTextToText class is used to work with multimodal language models from the HuggingFace hub.
1151+
While you can initialize the class directly, it's best to use the ``from_pretrained`` method for this purpose. This class supports both single and dual QPC approaches.
11511152
Attributes:
11521153
_hf_auto_class (class): The Hugging Face AutoModel class for ImageTextToText models.
1154+
1155+
``Mandatory`` Args:
1156+
:pretrained_model_name_or_path (str): Model card name from HuggingFace or local path to model directory.
1157+
1158+
``Optional`` Args:
1159+
:kv_offload (bool): Flag to toggle between single and dual QPC approaches. If set to False, the Single QPC approach will be used; otherwise, the dual QPC approach will be applied. Defaults to True.
1160+
1161+
.. code-block:: python
1162+
import requests
1163+
from PIL import Image
1164+
from transformers import AutoProcessor, TextStreamer
1165+
1166+
from QEfficient import QEFFAutoModelForImageTextToText
1167+
1168+
# Add HuggingFace Token to access the model
1169+
HF_TOKEN = ""
1170+
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
1171+
query = "Describe this image."
1172+
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
1173+
1174+
## STEP - 1 Load the Processor and Model, and kv_offload=True/False for dual and single qpc
1175+
processor = AutoProcessor.from_pretrained(model_name, token=token)
1176+
model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, token=token, attn_implementation="eager", kv_offload=False)
1177+
1178+
## STEP - 2 Export & Compile the Model
1179+
model.compile(
1180+
prefill_seq_len=32,
1181+
ctx_len=512,
1182+
img_size=560,
1183+
num_cores=16,
1184+
num_devices=1,
1185+
mxfp6_matmul=False,
1186+
)
1187+
1188+
## STEP - 3 Load and process the inputs for Inference
1189+
image = Image.open(requests.get(image_url, stream=True).raw)
1190+
messages = [
1191+
{
1192+
"role": "user",
1193+
"content": [
1194+
{"type": "image"},
1195+
{"type": "text", "text": query},
1196+
],
1197+
}
1198+
]
1199+
input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)]
1200+
inputs = processor(
1201+
text=input_text,
1202+
images=image,
1203+
return_tensors="pt",
1204+
add_special_tokens=False,
1205+
padding="max_length",
1206+
max_length=prefill_seq_len,
1207+
)
1208+
1209+
## STEP - 4 Run Inference on the compiled model
1210+
streamer = TextStreamer(processor.tokenizer)
1211+
model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len)
1212+
11531213
"""
11541214

11551215
_hf_auto_class = AutoModelForImageTextToText

docs/source/quick_start.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ Use the qualcomm_efficient_converter API to export the KV transformed Model to O
239239

240240
generated_qpc_path = qeff_model.compile(
241241
num_cores=14,
242-
mxfp6=True,
242+
mxfp6_matmul=True,
243243
)
244244
```
245245

@@ -250,8 +250,8 @@ Benchmark the model on Cloud AI 100, run the infer API to print tokens and tok/s
250250
```Python
251251
# post compilation, we can print the latency stats for the kv models, We provide API to print token and Latency stats on AI 100
252252
# We need the compiled prefill and decode qpc to compute the token generated, This is based on Greedy Sampling Approach
253-
254-
qeff_model.generate(prompts=["My name is"])
253+
tokenizer = AutoTokenizer.from_pretrained(model_name)
254+
qeff_model.generate(prompts=["My name is"],tokenizer=tokenizer)
255255
```
256256
End to End demo examples for various models are available in **notebooks** directory. Please check them out.
257257

docs/source/validate.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@
4141

4242
| Architecture | Model Family | Representative Models |
4343
|--------------|--------------|---------------------------------|
44-
| **BertModel** | BERT-based | [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5)<br> [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)<br>[BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5) <br>[e5-large-v2](https://huggingface.co/intfloat/e5-large-v2) |
44+
| **BertModel** | BERT-based | [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5)<br> [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)<br>[BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5) <br>[e5-large-v2](https://huggingface.co/intfloat/e5-large-v2) |
4545
| **LlamaModel** | Llama-based | [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct) |
46-
| **Qwen2ForCausalLM** | Qwen2 | [stella_en_1.5B_v5](https://huggingface.co/NovaSearch/stella_en_1.5B_v5) |
47-
| **XLMRobertaForSequenceClassification** | XLM-RoBERTa | [bge-reranker-v2-m3bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) |
4846
| **MPNetForMaskedLM** | MPNet | [sentence-transformers/multi-qa-mpnet-base-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1) |
49-
| **NomicBertModel** | NomicBERT | [nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) |
5047
| **MistralModel** | Mistral | [e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct) |
48+
| **NomicBertModel** | NomicBERT | [nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) |
49+
| **Qwen2ForCausalLM** | Qwen2 | [stella_en_1.5B_v5](https://huggingface.co/NovaSearch/stella_en_1.5B_v5) |
50+
| **RobertaModel** | RoBERTa | [ibm-granite/granite-embedding-30m-english](https://huggingface.co/ibm-granite/granite-embedding-30m-english)<br> [ibm-granite/granite-embedding-125m-english](https://huggingface.co/ibm-granite/granite-embedding-125m-english) |
51+
| **XLMRobertaForSequenceClassification** | XLM-RoBERTa | [bge-reranker-v2-m3bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) |
52+
| **XLMRobertaModel** | XLM-RoBERTa |[ibm-granite/granite-embedding-107m-multilingual](https://huggingface.co/ibm-granite/granite-embedding-107m-multilingual)<br> [ibm-granite/granite-embedding-278m-multilingual](https://huggingface.co/ibm-granite/granite-embedding-278m-multilingual) |
5153

5254
## Multimodal Language Models
5355

0 commit comments

Comments
 (0)