Skip to content

Commit b7e926e

Browse files
quic-dhirajkuSwati Allabadi
authored andcommitted
BUG FIX - get_model_configs() fix for InternVL and Llava (#312)
made few changes in the modeling files of both models for this method to now work appropriately. Signed-off-by: quic-dhirajku <[email protected]>
1 parent 5885037 commit b7e926e

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

QEfficient/transformers/models/internvl/modeling_internvl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(self, model):
2929
super().__init__()
3030
self.model = model
3131
self.config = self.model.language_model.config
32+
self.language_model = self.model.language_model
3233

3334
def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
3435
# TODO: Check if Hardcoding this is okay, i.e. check if this value is common for all intern models

QEfficient/transformers/models/llava/modeling_llava.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class QEFFLlavaEncoderWrapper(nn.Module):
2525
def __init__(self, model):
2626
super().__init__()
2727
self.model = model
28+
self.model.vision_model = self.model.vision_tower
2829

2930
def forward(self, pixel_values):
3031
# Image features
@@ -47,6 +48,7 @@ def __init__(self, model):
4748
super().__init__()
4849
self.model = model
4950
self.config = self.model.config
51+
self.language_model = self.model.language_model
5052

5153
def forward(self, input_ids, image_features, position_ids, past_key_values):
5254
inputs_embeds = self.model.get_input_embeddings()(input_ids)

0 commit comments

Comments
 (0)