
Commit ca55d42

Addressed Comments
Signed-off-by: Asmita Goswami <[email protected]>
1 parent 1608804 commit ca55d42

File tree

1 file changed: +44 -26


Diff for: QEfficient/cloud/infer.py

+44-26
@@ -12,13 +12,48 @@

 import requests
 from PIL import Image
-from transformers import AutoProcessor, TextStreamer
+from transformers import AutoProcessor, TextStreamer, PreTrainedModel
 from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

 from QEfficient.base.common import QEFFCommonLoader
 from QEfficient.utils import check_and_assign_cache_dir, constants, load_hf_tokenizer
 from QEfficient.utils.logging_utils import logger

+def execute_vlm_model(
+    qeff_model: PreTrainedModel,
+    model_name: str,
+    image_url: str,
+    image_path: str,
+    prompt: Optional[str] = None,  # type: ignore
+    device_group: Optional[List[int]] = None,
+    generation_len: Optional[int] = None,
+):
+    if not (image_url or image_path):
+        raise ValueError('Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"')
+    raw_image = Image.open(requests.get(image_url, stream=True).raw) if image_url else Image.open(image_path)
+
+    processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
+
+    conversation = constants.Constants.conversation
+    conversation[0]["content"][1].update({"text": prompt[0]})  # Currently accepting only 1 prompt
+
+    # Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids.
+    input_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+
+    split_inputs = processor(
+        text=input_text,
+        images=raw_image,
+        return_tensors="pt",
+        add_special_tokens=False,
+    )
+    streamer = TextStreamer(processor.tokenizer)
+    output = qeff_model.generate(
+        inputs=split_inputs,
+        streamer=streamer,
+        device_ids=device_group,
+        generation_len=generation_len,
+    )
+    return output

 def main(
     model_name: str,

@@ -130,32 +165,16 @@ def main(
     # Execute
     #########
     if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
-        processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
-
-        if not (image_url or image_path):
-            raise ValueError('Neither Image URL nor Image Path is found, either provide "image_url" or "image_path"')
-        raw_image = Image.open(requests.get(image_url, stream=True).raw) if image_url else Image.open(image_path)
-
-        conversation = constants.Constants.conversation
-        conversation[0]["content"][1].update({"text": prompt[0]})  # Currently accepting only 1 prompt
-
-        # Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids.
-        input_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
-
-        split_inputs = processor(
-            text=input_text,
-            images=raw_image,
-            return_tensors="pt",
-            add_special_tokens=False,
-        )
-        streamer = TextStreamer(processor.tokenizer)
-        output = qeff_model.generate(
-            inputs=split_inputs,
-            streamer=streamer,
-            device_ids=device_group,
+        exec_info = execute_vlm_model(
+            qeff_model=qeff_model,
+            model_name=model_name,
+            prompt=prompt,
+            image_url=image_url,
+            image_path=image_path,
+            device_group=device_group,
             generation_len=generation_len,
         )
-        print(output)
+        print(exec_info)
     else:
         tokenizer = load_hf_tokenizer(
             pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),

@@ -265,7 +284,6 @@ def main(
 parser.add_argument(
     "--enable_qnn",
     "--enable-qnn",
-    action="store_true",
     nargs="?",
     const=True,
     type=str,
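
Note on the chat-template step inside the new execute_vlm_model helper: constants.Constants.conversation is defined elsewhere in QEfficient and not shown in this diff, so the structure below is only an assumption inferred from the conversation[0]["content"][1] access pattern. A minimal sketch of how such a conversation renders through a Hugging Face processor (the model id is chosen purely for illustration; infer.py uses whatever --model_name was passed):

    from transformers import AutoProcessor

    # Illustrative model id, mirroring the AutoProcessor call in the diff.
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=False)

    # Hypothetical shape of constants.Constants.conversation, inferred from the
    # conversation[0]["content"][1] access in the diff (not taken from QEfficient source):
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},             # expanded to image tokens by the template
                {"type": "text", "text": ""},  # filled in below, as the diff does
            ],
        }
    ]
    conversation[0]["content"][1].update({"text": "Describe this image."})

    # tokenize=False returns the rendered prompt string rather than token ids;
    # add_generation_prompt=True opens the assistant turn for generation.
    input_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    print(input_text)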
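
A minimal sketch of calling the extracted helper directly, assuming the model has already been exported and compiled the way main() does it. The loader call mirrors the imports at the top of the diff, but the compile step and its flags are elided (they depend on the target hardware), and the model id, image URL, device ids, and generation length are placeholders:

    from QEfficient.base.common import QEFFCommonLoader
    from QEfficient.cloud.infer import execute_vlm_model

    model_name = "llava-hf/llava-1.5-7b-hf"  # illustrative model id
    qeff_model = QEFFCommonLoader.from_pretrained(pretrained_model_name_or_path=model_name)
    # qeff_model.compile(...) must run before generation; see main() for the full flag set.

    exec_info = execute_vlm_model(
        qeff_model=qeff_model,
        model_name=model_name,
        image_url="https://example.com/cat.png",  # placeholder; or pass image_path for a local file
        image_path=None,
        prompt=["What is shown in this image?"],  # the helper reads prompt[0]
        device_group=[0],
        generation_len=128,
    )
    print(exec_info)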
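
Finally, the last hunk removes action="store_true" from --enable_qnn. argparse's store_true action accepts no nargs, const, or type keywords, so the old combination raised TypeError as soon as the parser was built; the surviving nargs="?" form gives the three behaviors sketched in this standalone example:

    import argparse

    parser = argparse.ArgumentParser()
    # Same pattern as the fixed argument in infer.py: a bare flag stores the
    # const (True), a flag with a value stores that string, and omitting the
    # flag leaves the default (None). Adding action="store_true" on top of
    # this raises TypeError, because store_true accepts no nargs/const/type.
    parser.add_argument("--enable_qnn", "--enable-qnn", nargs="?", const=True, type=str)

    print(parser.parse_args([]))                            # Namespace(enable_qnn=None)
    print(parser.parse_args(["--enable_qnn"]))              # Namespace(enable_qnn=True)
    print(parser.parse_args(["--enable_qnn", "cfg.json"]))  # Namespace(enable_qnn='cfg.json')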
