@@ -17,16 +17,17 @@
 from instructlab.training.config import (
     DataProcessArgs,
     DistributedBackend,
+    LoraOptions,
     TorchrunArgs,
     TrainingArgs,
 )
 from instructlab.training.main_ds import run_training

 MINIMAL_TRAINING_ARGS = {
     "max_seq_len": 140,  # this config fits nicely on 4xL40s and may need modification for other setups
-    "max_batch_len": 15000,
+    "max_batch_len": 5000,
     "num_epochs": 1,
-    "effective_batch_size": 3840,
+    "effective_batch_size": 128,
     "save_samples": 0,
     "learning_rate": 1e-4,
     "warmup_steps": 1,
@@ -51,7 +52,7 @@
 RUNNER_CPUS_EXPECTED = 4

 # Number of samples to randomly sample from the processed dataset for faster training
-NUM_SAMPLES_TO_KEEP = 5000
+NUM_SAMPLES_TO_KEEP = 2500


 @pytest.fixture(scope="module")
@@ -232,25 +233,36 @@ def cached_training_data(
 @pytest.mark.parametrize(
     "dist_backend", [DistributedBackend.FSDP, DistributedBackend.DEEPSPEED]
 )
-@pytest.mark.parametrize("cpu_offload", [True, False])
+@pytest.mark.parametrize("cpu_offload", [False, True])
+@pytest.mark.parametrize("lora_rank", [0])
+@pytest.mark.parametrize("use_liger", [False, True])
 def test_training_feature_matrix(
     cached_test_model: pathlib.Path,
     cached_training_data: pathlib.Path,
     checkpoint_dir: pathlib.Path,
     prepared_data_dir: pathlib.Path,
+    use_liger: bool,
+    lora_rank: int,
     cpu_offload: bool,
     dist_backend: DistributedBackend,
 ) -> None:
+    torch_args = TorchrunArgs(**DEFAULT_TORCHRUN_ARGS)
     train_args = TrainingArgs(
         model_path=str(cached_test_model),
         data_path=str(cached_training_data),
         data_output_dir=str(prepared_data_dir),
         ckpt_output_dir=str(checkpoint_dir),
+        lora=LoraOptions(rank=lora_rank),
+        use_liger=use_liger,
         **MINIMAL_TRAINING_ARGS,
     )

     train_args.distributed_backend = dist_backend

+    if lora_rank > 0:
+        # LoRA doesn't support full state saving.
+        train_args.accelerate_full_state_at_epoch = False
+
     if dist_backend == DistributedBackend.FSDP:
         train_args.fsdp_options.cpu_offload_params = cpu_offload
     else:
@@ -259,6 +271,4 @@ def test_training_feature_matrix(
             pytest.xfail("DeepSpeed CPU Adam isn't currently building correctly")
         train_args.deepspeed_options.cpu_offload_optimizer = cpu_offload

-    torch_args = TorchrunArgs(**DEFAULT_TORCHRUN_ARGS)
-
     run_training(torch_args=torch_args, train_args=train_args)