Bug description
When the model is initialized inside fabric.init_module(True) (i.e. empty_init=True) and the checkpoint is loaded after model = fabric.setup(model), the training loss is normal:
with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
model = fabric.setup(model)
load_checkpoint(fabric, model, checkpoint_path)
step = 1 | loss train: 0.8448048233985901
step = 2 | loss train: 1.3229767084121704
step = 3 | loss train: 1.2647839784622192
step = 4 | loss train: 1.287076711654663
step = 5 | loss train: 1.0357563495635986
But when the checkpoint is loaded before model = fabric.setup(model), the loss is much larger:
with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
load_checkpoint(fabric, model, checkpoint_path)
model = fabric.setup(model)
step = 1 | loss train: 12.027938842773438
step = 2 | loss train: 12.051375389099121
step = 3 | loss train: 12.112957954406738
step = 4 | loss train: 12.08558177947998
step = 5 | loss train: 12.089488983154297
Another observation: if fabric.init_module() is not used at all, I get a normal loss even when the checkpoint is loaded before fabric.setup(model):
# with fabric.init_module(empty_init=(fabric.world_size > 1)):
if True:
    model = GPT(config)
load_checkpoint(fabric, model, checkpoint_path)
model = fabric.setup(model)
step = 1 | loss train: 0.8447667956352234
step = 2 | loss train: 1.3229438066482544
step = 3 | loss train: 1.2663335800170898
step = 4 | loss train: 1.2902932167053223
step = 5 | loss train: 1.035811185836792
So it seems that, with FSDP and fabric.init_module(), the checkpoint has to be loaded after fabric.setup(model). What is the correct way to load HF models converted by litgpt.scripts.convert_hf_checkpoint?
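For reference, this is the ordering that gives me a normal loss; below is a minimal sketch of it (the checkpoint directory is a placeholder from my setup, converted beforehand with litgpt.scripts.convert_hf_checkpoint), but I am not sure whether this is the intended usage:

checkpoint_dir = Path("./Qwen2.5-1.5B/")  # placeholder: directory produced by convert_hf_checkpoint
checkpoint_path = checkpoint_dir / "lit_model.pth"
config = Config.from_file(checkpoint_dir / "model_config.yaml")
with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
model = fabric.setup(model)                      # set up / shard the model first
load_checkpoint(fabric, model, checkpoint_path)  # then load the converted weights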
What version are you seeing the problem on?
v2.4
How to reproduce the bug
from pathlib import Path

import torch
import lightning as L
from lightning.fabric.strategies import FSDPStrategy

from litgpt.args import TrainArgs
from litgpt.config import Config
from litgpt.model import GPT, Block
from litgpt.data import Alpaca2k
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
    chunked_cross_entropy,
    load_checkpoint,
    num_parameters,
    get_default_supported_precision,
)


def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
    # linear warmup followed by cosine annealing
    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


def main(
    checkpoint_dir: Path,
    devices: int = 8,
    num_nodes: int = 1,
    precision: str = "bf16-true",
    seed: int = 1337,
) -> None:
    torch.set_float32_matmul_precision("high")

    train_args = TrainArgs(
        save_interval=1000,
        log_interval=1,
        global_batch_size=64,
        micro_batch_size=4,
        lr_warmup_steps=1000,
        epochs=10,
        max_steps=10000,
    )

    strategy = FSDPStrategy(
        auto_wrap_policy={Block},
        activation_checkpointing_policy={Block},
        state_dict_type="full",
        limit_all_gathers=True,
        cpu_offload=False,
    )
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        num_nodes=num_nodes,
        strategy=strategy,
        precision=precision,
    )
    fabric.launch()
    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)

    dataset = Alpaca2k()
    tokenizer = Tokenizer(str(checkpoint_dir))
    dataset.connect(tokenizer, batch_size=train_args.micro_batch_size, max_seq_length=512)
    with fabric.rank_zero_first():
        dataset.prepare_data()
    dataset.setup()
    dataloader = dataset.train_dataloader()
    dataloader = fabric.setup_dataloaders(dataloader)

    checkpoint_path = str(checkpoint_dir / "lit_model.pth")
    config = Config.from_file(checkpoint_dir / "model_config.yaml")
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    # load_checkpoint(fabric, model, checkpoint_path)
    model = fabric.setup(model)
    load_checkpoint(fabric, model, checkpoint_path)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))
    optimizer = fabric.setup_optimizers(optimizer)
    scheduler = get_lr_scheduler(optimizer, warmup_steps=train_args.lr_warmup_steps, max_steps=train_args.max_steps)

    model.train()
    for epoch in range(train_args.epochs):
        for step, batch in enumerate(dataloader, 1):
            input, target = batch["input_ids"], batch["labels"]
            logits = model(input)
            loss = chunked_cross_entropy(logits[..., :-1, :], target[..., 1:])
            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            fabric.print(f"{step = } | loss train: {loss.detach().item()}")


if __name__ == "__main__":
    checkpoint_dir = Path("./Qwen2.5-1.5B/")
    main(checkpoint_dir)
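To reproduce the failing case, only the ordering around fabric.setup needs to change; everything else in the script above stays the same:

# Failing variant: load the checkpoint before fabric.setup (swap the commented line back in)
with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
load_checkpoint(fabric, model, checkpoint_path)  # loading before setup
model = fabric.setup(model)
# with this ordering the training loss stays around 12 instead of ~0.84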
Error messages and logs
Environment
- PyTorch Lightning Version (e.g., 2.4.0):
- PyTorch Version (e.g., 2.4.1):
- Python version (e.g., 3.10):
- OS (e.g., Linux):
- CUDA/cuDNN version: 12.1
- GPU models and configuration:
- How you installed Lightning (`conda`, `pip`, source): pip
More info
No response