@@ -12,13 +12,48 @@

 import requests
 from PIL import Image
-from transformers import AutoProcessor, TextStreamer
+from transformers import AutoProcessor, TextStreamer, PreTrainedModel
 from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

 from QEfficient.base.common import QEFFCommonLoader
 from QEfficient.utils import check_and_assign_cache_dir, constants, load_hf_tokenizer
 from QEfficient.utils.logging_utils import logger

+def execute_vlm_model(
+    qeff_model: PreTrainedModel,
+    model_name: str,
+    image_url: str,
+    image_path: str,
+    prompt: Optional[str] = None,  # type: ignore
+    device_group: Optional[List[int]] = None,
+    generation_len: Optional[int] = None,
+):
+    if not (image_url or image_path):
+        raise ValueError('Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"')
+    raw_image = Image.open(requests.get(image_url, stream=True).raw) if image_url else Image.open(image_path)
+
+    processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
+
+    conversation = constants.Constants.conversation
+    conversation[0]["content"][1].update({"text": prompt[0]})  # Currently accepting only 1 prompt
+
+    # Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids.
+    input_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+
+    split_inputs = processor(
+        text=input_text,
+        images=raw_image,
+        return_tensors="pt",
+        add_special_tokens=False,
+    )
+    streamer = TextStreamer(processor.tokenizer)
+    output = qeff_model.generate(
+        inputs=split_inputs,
+        streamer=streamer,
+        device_ids=device_group,
+        generation_len=generation_len,
+    )
+    return output

 def main(
     model_name: str,
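The new `execute_vlm_model` helper can also be called directly. A minimal usage sketch, not part of this diff: the checkpoint name, image URL, and device ids are illustrative, and it assumes the model returned by `QEFFCommonLoader.from_pretrained` has already been exported and compiled, as `main` does before reaching this code path.

    from QEfficient.base.common import QEFFCommonLoader

    qeff_model = QEFFCommonLoader.from_pretrained("llava-hf/llava-1.5-7b-hf")
    # ... export/compile steps elided; generate() runs on the compiled model ...
    output = execute_vlm_model(
        qeff_model=qeff_model,
        model_name="llava-hf/llava-1.5-7b-hf",
        image_url="https://example.com/sample.jpg",  # illustrative URL
        image_path=None,
        prompt=["Describe the image"],  # helper reads prompt[0], so pass a list despite the Optional[str] hint
        device_group=[0],
        generation_len=32,
    )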
@@ -130,32 +165,16 @@ def main(
     # Execute
     #########
     if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
-        processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
-
-        if not (image_url or image_path):
-            raise ValueError('Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"')
-        raw_image = Image.open(requests.get(image_url, stream=True).raw) if image_url else Image.open(image_path)
-
-        conversation = constants.Constants.conversation
-        conversation[0]["content"][1].update({"text": prompt[0]})  # Currently accepting only 1 prompt
-
-        # Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids.
-        input_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
-
-        split_inputs = processor(
-            text=input_text,
-            images=raw_image,
-            return_tensors="pt",
-            add_special_tokens=False,
-        )
-        streamer = TextStreamer(processor.tokenizer)
-        output = qeff_model.generate(
-            inputs=split_inputs,
-            streamer=streamer,
-            device_ids=device_group,
+        exec_info = execute_vlm_model(
+            qeff_model=qeff_model,
+            model_name=model_name,
+            prompt=prompt,
+            image_url=image_url,
+            image_path=image_path,
+            device_group=device_group,
             generation_len=generation_len,
         )
-        print(output)
+        print(exec_info)
     else:
         tokenizer = load_hf_tokenizer(
             pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
@@ -265,7 +284,6 @@ def main(
     parser.add_argument(
         "--enable_qnn",
         "--enable-qnn",
-        action="store_true",
         nargs="?",
         const=True,
         type=str,