 from QEfficient.utils.logging_utils import logger
 
 
+class QEffInternEncoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, pixel_values):
+        vit_embeds = self.model.extract_feature(pixel_values)
+        return vit_embeds
+
+
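+# Decoder wrapper: stitches the precomputed vision embeddings into the text
+# embedding stream at the image-context token positions, then runs the LM.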
+class QEffInternDecoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.config = self.model.language_model.config
+
+    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
+        # TODO: Check if hardcoding this is okay, i.e. check if this value is common for all Intern models
+        IMG_CONTEXT_TOKEN = 151667
+
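+        # Replace every IMG_CONTEXT token embedding with the matching vision
+        # embedding: cumsum over the boolean mask gives each image-token
+        # position the index of the next unused row of vit_embeds, so one
+        # gather lines the vision features up with their text positions.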
+        input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
+        B, N, C = input_embeds.shape
+        image_input_embeds = input_embeds.reshape(B * N, C)
+        image_input_ids = input_ids.reshape(B * N)
+        selected = image_input_ids == IMG_CONTEXT_TOKEN
+        indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
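+        # Decode steps (seq_len == 1) carry no image tokens, so fall back to
+        # the plain text embeddings and skip the merge.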
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+        outputs = self.model.language_model(
+            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+        )
+        return outputs.logits, vit_embeds, outputs.past_key_values
+
+
 class QEffInternVLModel(nn.Module):
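+    # Wrapper factories for the two-QPC (kv_offload) path: the vision encoder
+    # and the language decoder are exported and compiled as separate programs.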
+    def get_qeff_vision_encoder(self):
+        return QEffInternEncoderWrapper(self)
+
+    def get_qeff_language_decoder(self):
+        return QEffInternDecoderWrapper(self)
+
     def get_specializations(
-        self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options
+        self,
+        batch_size: int,
+        prefill_seq_len: int,
+        ctx_len: int,
+        img_size: int,
+        kv_offload: bool = False,
+        **compiler_options,
     ):
         # TODO: check if this should be named num_patches or something else
         num_patches = compiler_options.pop("num_patches", None)
@@ -33,8 +81,18 @@ def get_specializations(
         elif img_size is None:
             img_size = 448
             logger.warning("Setting img_size to be 448, as it was neither passed nor found in vision_config")
-
-        specializations = [
+        if img_size != 448 and kv_offload:
+            raise NotImplementedError("Image size other than 448 is not supported for Intern models yet.")
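+        # With kv_offload the vision encoder and language decoder get their own
+        # specializations; otherwise a single QPC uses only the language
+        # specializations (which carry the image dimensions as well).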
+        vision = [
+            {
+                "batch_size": batch_size,
+                "num_patches": num_patches,
+                "img_size": img_size,
+                "seq_len": prefill_seq_len,
+                "ctx_len": ctx_len,
+            }
+        ]
+        lang = [
             {
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
@@ -50,61 +108,92 @@ def get_specializations(
50108 "img_size" : img_size ,
51109 },
52110 ]
53- return specializations , compiler_options
54111
55- def get_onnx_dynamic_axes (
56- self ,
57- ):
112+ specializations = {}
113+
114+ if kv_offload :
115+ specializations ["vision" ] = vision
116+ specializations ["lang" ] = lang
117+ return specializations , compiler_options
118+ else :
119+ return lang , compiler_options
120+
+    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         # Define dynamic axes
-        dynamic_axes = {}
-        dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
-        dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
+        vision_dynamic_axes = {}
+        lang_dynamic_axes = {}
+        lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
+        lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
+        vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
 
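+        # Past-KV tensors are dynamic in batch size and context length, one
+        # key/value pair per decoder layer.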
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+                lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
 
+        dynamic_axes = {}
+        if kv_offload:
+            dynamic_axes["vision"] = vision_dynamic_axes
+            dynamic_axes["lang"] = lang_dynamic_axes
+        else:
+            dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
         return dynamic_axes
 
-    def get_output_names(
-        self,
-    ):
-        output_names = ["logits", "pixel_values_RetainedState"]
+    def get_output_names(self, kv_offload: bool = False):
+        vision_output_names = ["vit_embeds"]
+        lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                output_names.append(f"past_{kv}.{i}_RetainedState")
+                lang_output_names.append(f"past_{kv}.{i}_RetainedState")
+
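+        # Single-QPC retains pixel_values across iterations; dual-QPC retains
+        # vit_embeds instead, since the vision pass runs in a separate program.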
+        output_names = {}
+        if kv_offload:
+            lang_output_names.insert(1, "vit_embeds_RetainedState")
+            output_names["vision"] = vision_output_names
+            output_names["lang"] = lang_output_names
+        else:
+            lang_output_names.insert(1, "pixel_values_RetainedState")
+            return lang_output_names
         return output_names
 
     def get_dummy_inputs(self, kv_offload: bool = False):
-        if kv_offload:
-            raise ValueError("kv_offload method not supported for InternVL yet!")
         num_patches = 13
         C = 3
         if vis_cfg := getattr(self.config, "vision_config", None):
             img_size = getattr(vis_cfg, "image_size", 448)
         else:
             img_size = 448
+        if img_size != 448 and kv_offload:
+            raise NotImplementedError("Image size other than 448 is not supported for Intern models yet.")
+
+        # Taken from the modeling files of OpenGVLab/InternVL2_5-1B
+        feature_size = int((((self.config.vision_config.hidden_size**0.5) * self.config.downsample_ratio) ** 2))
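+        # For the default InternVL2_5-1B config (hidden_size=1024,
+        # downsample_ratio=0.5) this evaluates to 256 visual tokens per tile.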
 
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+        inputs_shapes["vit_embeds"] = (
+            num_patches,
+            feature_size,
+            self.language_model.config.hidden_size,
+        )
         inputs_shapes["position_ids"] = (
             constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
         inputs_shapes["pixel_values"] = (num_patches, C, img_size, img_size)
 
         # Define inputs
-        inputs = {}
-        inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
-        inputs["position_ids"] = (
+        vision_inputs = {}
+        lang_inputs = {}
+        vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
+        lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
+        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
+        lang_inputs["position_ids"] = (
             torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
-        inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
 
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
@@ -113,10 +202,18 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
 
-        inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
+        lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+
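+        # Dual-QPC keeps vision and language inputs separate; single-QPC merges
+        # them and drops vit_embeds, which is computed in-graph from pixel_values.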
+        inputs = {}
+        if kv_offload:
+            inputs["vision"] = vision_inputs
+            inputs["lang"] = lang_inputs
+        else:
+            lang_inputs.pop("vit_embeds")
+            inputs = {**vision_inputs, **lang_inputs}
 
         return inputs
 