@@ -236,7 +236,10 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
 
 
 def make_eagle_supervised_data_module(
-    tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool
+    tokenizer: transformers.PreTrainedTokenizer,
+    data_args,
+    use_offline_training: bool,
+    max_length=None,
 ) -> dict:
     """Make dataset and collator for supervised fine-tuning.
 
@@ -295,15 +298,15 @@ def make_eagle_supervised_data_module(
         train_dataset = dataset_cls(valid_entries[:num_train], tokenizer=tokenizer)
         eval_dataset = dataset_cls(valid_entries[num_train:], tokenizer=tokenizer)
 
-        data_collator = DataCollatorForOffline()
+        data_collator = DataCollatorForOffline(max_length=max_length)
     else:
        print_rank_0("Loading input conversations...")
         dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
 
         train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
         eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
 
-        data_collator = DataCollatorWithPadding()
+        data_collator = DataCollatorWithPadding(max_length=max_length)
 
     return {
         "train_dataset": train_dataset,
@@ -313,6 +316,9 @@ def make_eagle_supervised_data_module(
 
 
 class DataCollatorWithPadding:
+    def __init__(self, max_length):
+        self.max_length = max_length
+
     def paddingtensor2d(self, intensors, length):
         n, dim = intensors.shape
         padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
@@ -325,19 +331,18 @@ def paddingtensor(self, intensors, length):
         return outtensors
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
-        max_length = max(item["input_ids"].shape[0] for item in features)
         batch_input_ids = torch.stack(
-            [self.paddingtensor(item["input_ids"], max_length) for item in features]
+            [self.paddingtensor(item["input_ids"], self.max_length) for item in features]
         )
         batch_attention_mask = torch.stack(
-            [self.paddingtensor(item["attention_mask"], max_length) for item in features]
+            [self.paddingtensor(item["attention_mask"], self.max_length) for item in features]
         )
         batch_loss_mask = torch.stack(
-            [self.paddingtensor(item["loss_mask"], max_length) for item in features]
+            [self.paddingtensor(item["loss_mask"], self.max_length) for item in features]
         )
 
         batch_labels = torch.stack(
-            [self.paddingtensor(item["labels"], max_length) for item in features]
+            [self.paddingtensor(item["labels"], self.max_length) for item in features]
         )
 
         batch = {
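
One behavioral consequence worth noting: padding now targets a fixed `self.max_length` rather than the per-batch maximum, so any feature longer than `max_length` would drive `torch.zeros(length - n, ...)` to a negative dimension and raise. Below is a hedged sketch of a guard one could add; the full `paddingtensor` body is not shown in this diff, so this reconstruction is an assumption, not the existing implementation.

```python
import torch

# Hypothetical guard (not in this diff): fail fast when a sequence exceeds
# the fixed padding length instead of hitting a negative-dimension error.
def paddingtensor(self, intensors, length):
    (n,) = intensors.shape
    if n > length:
        raise ValueError(f"sequence length {n} exceeds max_length {length}")
    padding_tensor = torch.zeros(length - n, dtype=intensors.dtype)
    return torch.cat((intensors, padding_tensor))
```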
@@ -357,16 +362,15 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
             raise ValueError("No kwargs found in batch features. Offline data required.")
 
         features = [item["kwargs"]["base_model_outputs"] for item in features]
-        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
 
         batch_hidden_states = torch.stack(
             [
-                self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length)
+                self.paddingtensor2d(item["base_model_hidden_states"], self.max_length)
                 for item in features
             ]
         )
         batch_aux_hidden_states = torch.stack(
-            [self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features]
+            [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features]
         )
 
         batch = {
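
Finally, a small smoke test of the online collator under the new fixed-length behavior. The feature keys mirror the ones accessed in `__call__` above, but the structure of the returned `batch` dict is not shown in this diff, so the shape comment is an assumption.

```python
import torch

# Hypothetical smoke test (illustration only): with a fixed max_length, every
# padded tensor should come out at that length regardless of the batch content.
collator = DataCollatorWithPadding(max_length=8)
features = [
    {
        "input_ids": torch.arange(5),
        "attention_mask": torch.ones(5, dtype=torch.long),
        "loss_mask": torch.ones(5, dtype=torch.long),
        "labels": torch.arange(5),
    },
    {
        "input_ids": torch.arange(3),
        "attention_mask": torch.ones(3, dtype=torch.long),
        "loss_mask": torch.ones(3, dtype=torch.long),
        "labels": torch.arange(3),
    },
]
batch = collator(features)  # each stacked tensor is assumed to have shape (2, 8)
```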