@@ -105,6 +105,7 @@ class HvConfig(BaseConfig):
105105 skip_load_from_peers : bool = False
106106 world_rank : int
107107 galaxy_size : int
108+ warmup_outerstep : int = 10
108109
109110 @model_validator (mode = "before" )
110111 def cast_str_to_list (cls , values : dict [str , Any ]) -> dict [str , Any ]:
@@ -190,8 +191,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
190191
191192
192193def _get_cosine_schedule_with_warmup_lr_lambda (
193- current_step : int , * , num_warmup_steps : int , num_training_steps : int , num_cycles : float , min_lr_rate : float = 0.0
194+ current_step : int ,
195+ * ,
196+ num_warmup_steps : int ,
197+ num_training_steps : int ,
198+ num_inner_steps : int ,
199+ warmup_outerstep : int | None ,
200+ num_cycles : float ,
201+ min_lr_rate : float = 0.0 ,
194202):
203+ if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep :
204+ return 0
205+
195206 if current_step < num_warmup_steps :
196207 return float (current_step ) / float (max (1 , num_warmup_steps ))
197208 progress = float (current_step - num_warmup_steps ) / float (max (1 , num_training_steps - num_warmup_steps ))
@@ -200,11 +211,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
200211 return max (0 , factor )
201212
202213
def get_cosine_schedule_with_warmup(optimizer, config: Config):
    """Build a ``LambdaLR`` cosine schedule with warmup from ``config``.

    The per-step multiplier is computed by
    ``_get_cosine_schedule_with_warmup_lr_lambda`` with ``num_cycles`` fixed
    to 0.5.

    Args:
        optimizer: Optimizer handed to ``LambdaLR``.
        config: Supplies ``warmup_steps`` and ``total_steps``; when
            ``config.hv`` is set it also supplies ``hv.local_steps`` and
            ``hv.warmup_outerstep``.

    Returns:
        A ``LambdaLR`` constructed with ``last_epoch=-1``.
    """
    hv = config.hv
    if hv is None:
        # ``config.hv`` is optional (callers test ``config.hv is not None``).
        # Passing ``warmup_outerstep=None`` makes the lambda's per-cycle
        # zero-LR guard short-circuit before it touches ``num_inner_steps``,
        # so the value 1 below is only a placeholder.
        num_inner_steps = 1
        warmup_outerstep = None
    else:
        num_inner_steps = hv.local_steps
        warmup_outerstep = hv.warmup_outerstep

    lambda_lr = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps,
        num_inner_steps=num_inner_steps,
        warmup_outerstep=warmup_outerstep,
        num_cycles=0.5,
    )
    return LambdaLR(optimizer, lambda_lr, last_epoch=-1)
@@ -301,9 +314,7 @@ def train(config: Config):
def scheduler_fn(opt):
    """Return the cosine-with-warmup LR scheduler for ``opt``.

    Uses the ``config`` object from the enclosing scope.
    """
    scheduler = get_cosine_schedule_with_warmup(opt, config=config)
    return scheduler
308319
309320 if config .hv is not None :
0 commit comments