@@ -32,14 +32,11 @@ def __init__(self, model):
3232 self .language_model = self .model .language_model
3333
3434 def forward (self , input_ids , vit_embeds , position_ids , past_key_values ):
35- # TODO: Check if Hardcoding this is okay, i.e. check if this value is common for all intern models
36- IMG_CONTEXT_TOKEN = 151667
37-
3835 input_embeds = self .model .language_model .get_input_embeddings ()(input_ids )
3936 B , N , C = input_embeds .shape
4037 image_input_embeds = input_embeds .reshape (B * N , C )
4138 image_input_ids = input_ids .reshape (B * N )
42- selected = image_input_ids == IMG_CONTEXT_TOKEN
39+ selected = image_input_ids == constants . INTERN_IMG_CONTEXT_TOKEN
4340 indices1 = selected .unsqueeze (0 ).to (torch .int64 ).cumsum (1 ) - 1
4441 indices0 = torch .arange (selected .unsqueeze (0 ).shape [0 ]).view (- 1 , 1 )
4542 image_features_expanded = vit_embeds .reshape (- 1 , C ).unsqueeze (0 )[indices0 , indices1 ]
@@ -73,16 +70,16 @@ def get_specializations(
7370 logger .warning (
7471 "User should pass `num_patches` to compile API to fix the dynamic axes `pixel_values`, you can get more info by calling get_inputs_info function!, Since its not found setting its value to 13"
7572 )
76- num_patches = 13
73+ num_patches = constants . INTERN_NUM_PATCHES
7774
78- prefill_seq_len = prefill_seq_len if prefill_seq_len else 3840 # 4096-256
79- ctx_len = ctx_len if ctx_len else 4096
75+ prefill_seq_len = prefill_seq_len if prefill_seq_len else constants . INTERN_PREFILL_SEQ_LEN # 4096-256
76+ ctx_len = ctx_len if ctx_len else constants . INTERN_CTX_LEN
8077 if img_size is None and hasattr (self .config .vision_config , "image_size" ):
8178 img_size = getattr (self .config .vision_config , "image_size" )
8279 elif img_size is None :
83- img_size = 448
80+ img_size = constants . INTERN_IMG_SIZE
8481 logger .warning ("Setting img_size to be 448, as it was neither passed nor found in vision_config" )
85- if img_size != 448 and kv_offload :
82+ if img_size != constants . INTERN_IMG_SIZE and kv_offload :
8683 raise NotImplementedError ("Image Size other than 448 is not supported for Intern models yet." )
8784 vision = [
8885 {
@@ -159,31 +156,40 @@ def get_output_names(self, kv_offload: bool = False):
159156 return output_names
160157
161158 def get_dummy_inputs (self , kv_offload : bool = False ):
162- num_patches = 13
163- C = 3
164159 if vis_cfg := getattr (self .config , "vision_config" , None ):
165- img_size = getattr (vis_cfg , "image_size" , 448 )
160+ img_size = getattr (vis_cfg , "image_size" , constants . INTERN_IMG_SIZE )
166161 else :
167- img_size = 448
168- if img_size != 448 and kv_offload :
162+ img_size = constants . INTERN_IMG_SIZE
163+ if img_size != constants . INTERN_IMG_SIZE and kv_offload :
169164 raise NotImplementedError ("Image Size other than 448 is not supported for Intern models yet." )
170165
171- # Taken from the modeling files of OpenGVLab/InternVL2_5-1B
172- feature_size = int ((((self .config .vision_config .hidden_size ** 0.5 ) * self .config .downsample_ratio ) ** 2 ))
166+ patch_size = getattr (self .config .vision_config , "patch_size" , None )
167+ downsample_ratio = getattr (self .config , "downsample_ratio" , None )
168+ if patch_size and downsample_ratio :
169+ computed_feature_size = int (((img_size / patch_size ) * downsample_ratio ) ** 2 )
170+ if computed_feature_size != constants .INTERN_FEATURE_SIZE :
171+ logger .warning (
172+ "Discrepancy detected between estimated and actual feature sizes. Could impact on functionality or accuracy"
173+ )
173174
174175 # Define shapes
175176 inputs_shapes = {}
176177 inputs_shapes ["input_ids" ] = (constants .ONNX_EXPORT_EXAMPLE_BATCH_SIZE , constants .ONNX_EXPORT_EXAMPLE_SEQ_LEN )
177178 inputs_shapes ["vit_embeds" ] = (
178- num_patches ,
179- feature_size ,
179+ constants . INTERN_NUM_PATCHES ,
180+ constants . INTERN_FEATURE_SIZE ,
180181 self .language_model .config .hidden_size ,
181182 )
182183 inputs_shapes ["position_ids" ] = (
183184 constants .ONNX_EXPORT_EXAMPLE_BATCH_SIZE ,
184185 constants .ONNX_EXPORT_EXAMPLE_SEQ_LEN ,
185186 )
186- inputs_shapes ["pixel_values" ] = (num_patches , C , img_size , img_size )
187+ inputs_shapes ["pixel_values" ] = (
188+ constants .INTERN_NUM_PATCHES ,
189+ constants .INTERN_NUM_CHANNELS ,
190+ img_size ,
191+ img_size ,
192+ )
187193
188194 # Define inputs
189195 vision_inputs = {}
@@ -220,15 +226,12 @@ def get_dummy_inputs(self, kv_offload: bool = False):
220226 return inputs
221227
222228 def forward (self , input_ids , pixel_values , position_ids , past_key_values ):
223- # TODO: Check if Hardcoding this is okay, i.e. check if this value is common for all intern models
224- IMG_CONTEXT_TOKEN = 151667
225-
226229 input_embeds = self .language_model .get_input_embeddings ()(input_ids )
227230 vit_embeds = self .extract_feature (pixel_values )
228231 B , N , C = input_embeds .shape
229232 image_input_embeds = input_embeds .reshape (B * N , C )
230233 image_input_ids = input_ids .reshape (B * N )
231- selected = image_input_ids == IMG_CONTEXT_TOKEN
234+ selected = image_input_ids == constants . INTERN_IMG_CONTEXT_TOKEN
232235 indices1 = selected .unsqueeze (0 ).to (torch .int64 ).cumsum (1 ) - 1
233236 indices0 = torch .arange (selected .unsqueeze (0 ).shape [0 ]).view (- 1 , 1 )
234237 image_features_expanded = vit_embeds .reshape (- 1 , C ).unsqueeze (0 )[indices0 , indices1 ]
0 commit comments