Commit bd465a2
Onboarding Granite Vision (#359)

Adding LlavaNext to the models supported by QEfficient transformers. Output comparison across backends:

- HF (original model) PyTorch output: "The highest scoring model on ChartQA is Granite Vision with a score of 0.87."
- QEff model PyTorch output: "The highest scoring model on ChartQA is Granite Vision with a score of 0.87."
- QEff model ORT output: "The highest scoring model on ChartQA is Granite Vision with a score of 0.87."
- QEff model AIC output: "The highest scoring model on ChartQA is Granite Vision with a score of 0.87."

Signed-off-by: Dipankar Sarkar <[email protected]>
1 parent c99e105 commit bd465a2

File tree

7 files changed, +524 -5 lines changed
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
QEfficient/transformers/models/llava_next/modeling_llava_next.py

Lines changed: 347 additions & 0 deletions

@@ -0,0 +1,347 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


import numpy as np
import torch
import torch.nn as nn
from transformers.models.llava_next.modeling_llava_next import (
    LlavaNextForConditionalGeneration,
    get_anyres_image_grid_shape,
)

from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo
from QEfficient.utils.logging_utils import logger


class QEffLlavaNextEncoderWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.model.vision_model = self.model.vision_tower

    def forward(self, pixel_values, image_sizes):
        if pixel_values.dim() == constants.GRANITEVISION_PIXEL_VALUE_DIM:
            pixel_values_new = pixel_values.squeeze(0)

        image_feature = self.model.vision_tower(pixel_values_new, output_hidden_states=True)
        if isinstance(self.model.config.vision_feature_layer, int):
            selected_image_feature = image_feature.hidden_states[self.model.config.vision_feature_layer]
        else:
            hs_pool = [image_feature.hidden_states[layer_idx] for layer_idx in self.model.config.vision_feature_layer]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        vision_feature_select_strategy = self.model.config.vision_feature_select_strategy
        if vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(f"Unexpected select feature strategy: {self.model.config.vision_feature_select_strategy}")
        image_features = self.model.multi_modal_projector(selected_image_feature)
        image_features = torch.split(image_features, [image_features.shape[0]], dim=0)
        new_image_features = []

        # Image feature
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = (
                    self.model.config.vision_config.image_size // self.model.config.vision_config.patch_size
                )
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.model.config.image_grid_pinpoints,
                    self.model.config.vision_config.image_size,
                )

                if (
                    np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    logger.warning_once(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS."
                    )

                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)

                if not isinstance(image_sizes[image_idx], (list, tuple)):
                    if not isinstance(image_sizes[image_idx], (torch.Tensor, np.ndarray)):
                        raise TypeError(
                            f"image_size invalid type: {type(image_sizes[image_idx])} not valid, should be either list, tuple, np.ndarray or tensor"
                        )
                    original_size = image_sizes[image_idx].tolist()
                original_height, original_width = original_size
                current_height, current_width = image_feature.shape[1:]

                if torch.is_tensor(current_height):
                    current_height = current_height.item()
                    current_width = current_width.item()

                scale_factor = current_width / original_width
                new_height = int(round(original_height * scale_factor, 7))
                padding = (current_height - new_height) // 2
                image_feature = image_feature[:, padding : current_height - padding, :]
                if self.model.image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            self.model.image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if self.model.image_newline is not None:
                    image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
        image_features = torch.cat(new_image_features, dim=0)
        return image_features


class QEffLlavaNextDecoderWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = self.model.config
        self.language_model = self.model.language_model

    def forward(self, input_ids, image_features, position_ids, past_key_values):
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        mask = input_ids == self.config.image_token_index
        indices1 = mask.to(torch.int64).cumsum(1) - 1
        image_features_expanded = image_features[indices1]
        image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
        # Skip the image features during decode: once seq_len == 1 there are no image tokens,
        # so fall back to the plain text embeddings.
        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )
        return outputs.logits, image_features, outputs.past_key_values


class QEffLlavaNextForConditionalGeneration(LlavaNextForConditionalGeneration):
    def get_qeff_vision_encoder(self):
        return QEffLlavaNextEncoderWrapper(self)

    def get_qeff_language_decoder(self):
        return QEffLlavaNextDecoderWrapper(self)

    def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
        num_layers = self.config.text_config.num_hidden_layers
        num_key_value_heads = self.config.text_config.num_key_value_heads
        head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads
        if vis_cfg := getattr(self.config, "vision_config", None):
            img_size = getattr(vis_cfg, "image_size", constants.GRANITEVISION_IMG_SIZE)
        else:
            img_size = constants.GRANITEVISION_IMG_SIZE
        if img_size != constants.GRANITEVISION_IMG_SIZE and kv_offload:
            raise NotImplementedError("Image Size other than 384 is not supported for LlavaNext models yet.")
        vision_inputs = {
            "pixel_values": torch.zeros(
                (
                    constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
                    constants.GRANITEVISION_NUM_PATCHES,
                    constants.GRANITEVISION_NUM_CHANNELS,
                    constants.GRANITEVISION_IMG_SIZE,
                    constants.GRANITEVISION_IMG_SIZE,
                ),
                dtype=torch.float32,
            ),
            "image_sizes": torch.tensor(
                [[constants.GRANITEVISION_IMG_SIZE_HEIGHT, constants.GRANITEVISION_IMG_SIZE_WIDTH]], dtype=torch.int64
            ),
        }
        lang_inputs = {
            "input_ids": torch.ones(
                (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.GRANITEVISION_SEQ_LEN), dtype=torch.int64
            ),
            "attention_mask": torch.ones(
                (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.GRANITEVISION_SEQ_LEN), dtype=torch.int64
            ),
            "image_features": torch.ones(
                (constants.GRANITEVISION_FEATURE_SIZE, self.language_model.config.hidden_size), dtype=torch.float32
            ),
        }
        lang_inputs["position_ids"] = lang_inputs.pop("attention_mask").cumsum(1)
        lang_inputs["past_key_values"] = []
        for i in range(num_layers):
            lang_inputs["past_key_values"].append(
                (
                    torch.zeros(
                        constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
                        num_key_value_heads,
                        constants.GRANITEVISION_CTX_LEN,
                        head_dim,
                    ),
                    torch.zeros(
                        constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
                        num_key_value_heads,
                        constants.GRANITEVISION_CTX_LEN,
                        head_dim,
                    ),
                )
            )
        lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, constants.GRANITEVISION_CTX_LEN - 1)
        inputs = {}
        if kv_offload:
            inputs["vision"] = vision_inputs
            inputs["lang"] = lang_inputs
        else:
            lang_inputs.pop("image_features")
            inputs = {**vision_inputs, **lang_inputs}
        return inputs

    def get_specializations(
        self,
        batch_size: int,
        prefill_seq_len: int,
        ctx_len: int,
        img_size: int,
        kv_offload: bool = False,
        **compiler_options,
    ):
        max_num_images = compiler_options.pop("max_num_images", 1)
        num_patches = compiler_options.pop("num_patches", None)
        image_size_height = compiler_options.pop("image_size_height", None)
        image_size_width = compiler_options.pop("image_size_width", None)

        if num_patches is None:
            num_patches = constants.GRANITEVISION_NUM_PATCHES
        if image_size_height is None:
            image_size_height = constants.GRANITEVISION_IMG_SIZE_HEIGHT
        if image_size_width is None:
            image_size_width = constants.GRANITEVISION_IMG_SIZE_WIDTH

        if num_patches != constants.GRANITEVISION_NUM_PATCHES:
            logger.warning("Image num patches should be set to 10.")
            num_patches = constants.GRANITEVISION_NUM_PATCHES

        if image_size_height != constants.GRANITEVISION_IMG_SIZE_HEIGHT:
            logger.warning(
                "Image size height should be fixed to 1109. Please reshape the image to (w x h) (1610 x 1109)."
            )
            image_size_height = constants.GRANITEVISION_IMG_SIZE_HEIGHT

        if image_size_width != constants.GRANITEVISION_IMG_SIZE_WIDTH:
            logger.warning(
                "Image size width should be fixed to 1610. Please reshape the image to (w x h) (1610 x 1109)."
            )
            image_size_width = constants.GRANITEVISION_IMG_SIZE_WIDTH

        prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.GRANITEVISION_SEQ_LEN
        ctx_len = ctx_len if ctx_len else constants.GRANITEVISION_CTX_LEN
        if not kv_offload:
            raise NotImplementedError(
                "We currently support only dual QPC for this model; please set kv_offload to True."
            )
        if img_size is None and hasattr(self.config.vision_config, "image_size"):
            img_size = getattr(self.config.vision_config, "image_size")
        elif img_size is None:
            img_size = constants.GRANITEVISION_IMG_SIZE
            logger.warning("Setting img_size to 384, as it was neither passed nor found in vision_config.")
        if img_size != constants.GRANITEVISION_IMG_SIZE and kv_offload:
            logger.warning("Image size other than 384 is not supported for LlavaNext models yet.")
        vision = [
            {
                "batch_size": batch_size,
                "seq_len": prefill_seq_len,
                "ctx_len": ctx_len,
                "image_size_height": image_size_height,
                "image_size_width": image_size_width,
                "num_patches": num_patches,
                "max_num_images": max_num_images,
                "img_size": img_size,
            }
        ]
        lang = [
            {
                "batch_size": batch_size,
                "seq_len": prefill_seq_len,
                "ctx_len": ctx_len,
                "image_size_height": image_size_height,
                "image_size_width": image_size_width,
                "num_patches": num_patches,
                "max_num_images": max_num_images,
                "img_size": img_size,
            },
            {
                "batch_size": batch_size,
                "seq_len": "1",
                "ctx_len": ctx_len,
                "image_size_height": image_size_height,
                "image_size_width": image_size_width,
                "num_patches": num_patches,
                "max_num_images": max_num_images,
                "img_size": img_size,
            },
        ]
        specializations = {}
        if kv_offload:
            specializations["vision"] = vision
            specializations["lang"] = lang
            return specializations, compiler_options
        else:
            return lang, compiler_options

    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
        # Define dynamic axes
        num_layers = self.config.text_config.num_hidden_layers
        vision_dynamic_axes = {
            "pixel_values": {0: "batch_size", 1: "num_patches", 3: "img_size", 4: "img_size"},
            "image_sizes": {0: "image_size_height", 1: "image_size_width"},
        }
        lang_dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "position_ids": {0: "batch_size", 1: "seq_len"},
        }
        for i in range(num_layers):
            lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
            lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}
        dynamic_axes = {}
        if kv_offload:
            dynamic_axes["vision"] = vision_dynamic_axes
            dynamic_axes["lang"] = lang_dynamic_axes
        else:
            dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
        return dynamic_axes

    def get_output_names(self, kv_offload: bool = False):
        vision_output_names = ["image_features"]
        lang_output_names = ["logits"]
        for i in range(self.language_model.config.num_hidden_layers):
            for kv in ["key", "value"]:
                lang_output_names.append(f"past_{kv}.{i}_RetainedState")

        output_names = {}
        if kv_offload:
            lang_output_names.insert(1, "image_features_RetainedState")
            output_names["vision"] = vision_output_names
            output_names["lang"] = lang_output_names
        else:
            lang_output_names.insert(1, "pixel_values_RetainedState")
            return lang_output_names
        return output_names

    def get_inputs_info(self):
        return [
            IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
            IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 10, 3, "img_size", "img_size")),
            IOInfo(name="image_sizes", datatype=torch.int64, shape=(1109, 1610)),
        ]
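
The wrappers and export helpers above are consumed by QEfficient's export and compile path. As a rough orientation only (not part of the commit), the sketch below shows how they could be exercised directly; the checkpoint name is a placeholder, and the dual-QPC path (kv_offload=True) is assumed, since get_specializations raises NotImplementedError otherwise.

# Hedged sketch, not part of this commit. "<granite-vision-checkpoint>" is a placeholder;
# substitute the Granite Vision / LLaVA-Next checkpoint you actually intend to export.
from QEfficient.transformers.models.llava_next.modeling_llava_next import (
    QEffLlavaNextForConditionalGeneration,
)

model = QEffLlavaNextForConditionalGeneration.from_pretrained("<granite-vision-checkpoint>")

# Dual-QPC split: vision encoder and language decoder become separate graphs.
vision_encoder = model.get_qeff_vision_encoder()      # QEffLlavaNextEncoderWrapper
language_decoder = model.get_qeff_language_decoder()  # QEffLlavaNextDecoderWrapper

# Export metadata as defined above; each returns {"vision": ..., "lang": ...} when kv_offload=True.
dummy_inputs = model.get_dummy_inputs(kv_offload=True)
dynamic_axes = model.get_onnx_dynamic_axes(kv_offload=True)
output_names = model.get_output_names(kv_offload=True)

# Passing None falls back to the GRANITEVISION_* defaults from QEfficient.utils.constants.
specializations, compiler_options = model.get_specializations(
    batch_size=1, prefill_seq_len=None, ctx_len=None, img_size=None, kv_offload=True
)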

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 15 additions & 3 deletions

@@ -1,6 +1,6 @@
 # -----------------------------------------------------------------------------
 #
-# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
@@ -66,7 +66,12 @@
     LlamaModel,
     LlamaRMSNorm,
 )
-from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration
+from transformers.models.llava.modeling_llava import (
+    LlavaForConditionalGeneration,
+)
+from transformers.models.llava_next.modeling_llava_next import (
+    LlavaNextForConditionalGeneration,
+)
 from transformers.models.mistral.modeling_mistral import (
     MistralAttention,
     MistralDecoderLayer,
@@ -191,7 +196,12 @@
     QEffLlamaForCausalLM,
     QEffLlamaModel,
 )
-from QEfficient.transformers.models.llava.modeling_llava import QEffLlavaForConditionalGeneration
+from QEfficient.transformers.models.llava.modeling_llava import (
+    QEffLlavaForConditionalGeneration,
+)
+from QEfficient.transformers.models.llava_next.modeling_llava_next import (
+    QEffLlavaNextForConditionalGeneration,
+)
 from QEfficient.transformers.models.mistral.modeling_mistral import (
     QEffMistralAttention,
     QEffMistralDecoderLayer,
@@ -303,6 +313,8 @@ class KVCacheTransform(ModuleMappingTransform):
         LlamaForCausalLM: QEffLlamaForCausalLM,
         # Llava
         LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration,
+        # Llava Next
+        LlavaNextForConditionalGeneration: QEffLlavaNextForConditionalGeneration,
         # Gemma
         GemmaAttention: QEffGemmaAttention,
         GemmaDecoderLayer: QEffGemmaDecoderLayer,
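
For context, KVCacheTransform is a ModuleMappingTransform, so the new mapping entry is what routes an HF LlavaNextForConditionalGeneration onto the QEff implementation added in this commit. A minimal illustration (not part of the commit) of what that mapping amounts to, assuming the module-mapping transforms work by swapping the matched module's class:

# Hedged illustration, not part of this commit. The manual __class__ swap below mimics
# (as an assumption) what the registered mapping does during model transformation.
from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
from QEfficient.transformers.models.llava_next.modeling_llava_next import (
    QEffLlavaNextForConditionalGeneration,
)

hf_model = LlavaNextForConditionalGeneration.from_pretrained("<granite-vision-checkpoint>")  # placeholder
hf_model.__class__ = QEffLlavaNextForConditionalGeneration  # class swap, as the mapping implies
dummy = hf_model.get_dummy_inputs(kv_offload=True)  # QEff export helpers now available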
