|
7 | 7 | """ |
8 | 8 |
|
9 | 9 | from functools import partial |
| 10 | +import math |
10 | 11 | import os |
11 | 12 | import time |
12 | 13 | from contextlib import nullcontext |
|
28 | 29 | DataCollatorForLanguageModeling, |
29 | 30 | LlamaConfig, |
30 | 31 | LlamaForCausalLM, |
31 | | - get_cosine_schedule_with_warmup, |
32 | 32 | ) |
33 | 33 | from torch.distributed.fsdp import ( |
34 | 34 | FullyShardedDataParallel as FSDP, |
|
39 | 39 | from torch.distributed import broadcast_object_list |
40 | 40 | from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint |
41 | 41 | from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer |
42 | | - |
| 42 | +from torch.optim.lr_scheduler import LambdaLR |
43 | 43 |
|
44 | 44 | from hivemind.dht.dht import DHT |
45 | 45 | from hivemind.utils.networking import log_visible_maddrs |
@@ -189,6 +189,27 @@ def get_model(config: Config) -> LlamaForCausalLM:
     return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
 
 
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
+    lambda_lr = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
@@ -282,6 +303,7 @@ def scheduler_fn(opt):
             opt,
             num_warmup_steps=config.warmup_steps,
             num_training_steps=config.total_steps,
+            num_inner_steps=config.hv.local_steps,
         )
 
     if config.hv is not None:
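
For reference, here is a minimal, self-contained sketch of how the scheduler introduced above behaves in isolation. It copies the lambda from this diff (under a shortened name, _cosine_with_warmup_lr_lambda) and wires it to a dummy SGD optimizer the same way the new get_cosine_schedule_with_warmup wrapper does, via functools.partial and LambdaLR. The base LR of 4e-4 and the 10/100 warmup/total step counts are illustrative values only, not taken from this PR's config; note also that, as written in the diff, num_inner_steps is accepted by the wrapper but not forwarded to the lambda.

import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def _cosine_with_warmup_lr_lambda(
    current_step, *, num_warmup_steps, num_training_steps, num_cycles, min_lr_rate=0.0
):
    # Linear warmup from 0 to the base LR, then (for num_cycles=0.5) a cosine decay
    # from the base LR down toward min_lr_rate * base LR.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    return max(0.0, factor * (1 - min_lr_rate) + min_lr_rate)


# Dummy parameter/optimizer; the hyperparameters below are illustrative, not the PR's config.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=4e-4)
scheduler = LambdaLR(
    opt,
    partial(_cosine_with_warmup_lr_lambda, num_warmup_steps=10, num_training_steps=100, num_cycles=0.5),
    -1,  # last_epoch, matching the wrapper in the diff
)

for step in range(100):
    opt.step()
    scheduler.step()
    if step in (0, 9, 49, 99):
        print(f"step {step:3d}  lr {scheduler.get_last_lr()[0]:.6f}")

The formula matches the cosine-with-warmup schedule that transformers' get_cosine_schedule_with_warmup implements, which is presumably why the transformers import could be dropped in favor of this local copy that additionally accepts num_inner_steps.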
|