
Loading checkpoint before fabric.setup(model) gets abnormal loss when using fabric.init_module() #20490

@kobenaxie


Bug description

If I initialize the model with fabric.init_module(empty_init=True) and load the checkpoint after model = fabric.setup(model), the training loss is normal:

with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
model = fabric.setup(model)
load_checkpoint(fabric, model, checkpoint_path)

step = 1 | loss train: 0.8448048233985901
step = 2 | loss train: 1.3229767084121704
step = 3 | loss train: 1.2647839784622192
step = 4 | loss train: 1.287076711654663
step = 5 | loss train: 1.0357563495635986

but when the checkpoint is loaded before model = fabric.setup(model), the loss is much larger:

with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
load_checkpoint(fabric, model, checkpoint_path)
model = fabric.setup(model)

step = 1 | loss train: 12.027938842773438
step = 2 | loss train: 12.051375389099121
step = 3 | loss train: 12.112957954406738
step = 4 | loss train: 12.08558177947998
step = 5 | loss train: 12.089488983154297

Another observation: if I do not use fabric.init_module() at all, the loss is normal even when the checkpoint is loaded before fabric.setup(model):

# with fabric.init_module(empty_init=(fabric.world_size > 1)):
if True:
    model = GPT(config)
load_checkpoint(fabric, model, checkpoint_path)
model = fabric.setup(model)

step = 1 | loss train: 0.8447667956352234
step = 2 | loss train: 1.3229438066482544
step = 3 | loss train: 1.2663335800170898
step = 4 | loss train: 1.2902932167053223
step = 5 | loss train: 1.035811185836792
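
My guess is that with empty_init=True the parameters created inside fabric.init_module() have no initialized storage yet, so whatever load_checkpoint writes before fabric.setup(model) gets discarded when FSDP materializes the model during setup. A diagnostic along these lines should confirm or rule that out (my own sketch, reusing fabric, config, checkpoint_path and the imports from the repro script below; describe_params is a hypothetical helper, not a Fabric API):

def describe_params(model, tag):
    # Cheap fingerprint of the first parameter, just to see whether load_checkpoint
    # actually changed the values; after fabric.setup the FSDP-sharded value is not
    # comparable in absolute terms, only "changed vs. unchanged" matters here.
    p = next(model.parameters())
    if p.device.type == "meta":
        fabric.print(f"{tag}: parameters are on the meta device (no storage yet)")
    else:
        with torch.no_grad():
            fabric.print(f"{tag}: device={p.device}, checksum={p.float().sum().item():.4f}")

with fabric.init_module(empty_init=(fabric.world_size > 1)):
    model = GPT(config)
describe_params(model, "after init_module")
load_checkpoint(fabric, model, checkpoint_path)
describe_params(model, "after load_checkpoint")
model = fabric.setup(model)
describe_params(model, "after fabric.setup")

If the checksum changes after load_checkpoint but the post-setup run still starts at a loss around 12, the weights loaded before setup are evidently being re-initialized or thrown away during setup.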

So what is the correct way to load HF models converted by litgpt.scripts.convert_hf_checkpoint?
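
For reference, a plain-PyTorch load of the same files (no Fabric, no FSDP, single process) would look roughly like this; I include it only as a baseline sketch, assuming lit_model.pth is the flat state dict written by convert_hf_checkpoint:

from pathlib import Path

import torch
from litgpt.config import Config
from litgpt.model import GPT

checkpoint_dir = Path("./Qwen2.5-1.5B/")
config = Config.from_file(checkpoint_dir / "model_config.yaml")
model = GPT(config)
state_dict = torch.load(checkpoint_dir / "lit_model.pth", map_location="cpu")
# strict=True so any missing or unexpected keys from the conversion surface immediately
model.load_state_dict(state_dict, strict=True)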

What version are you seeing the problem on?

v2.4

How to reproduce the bug

from pathlib import Path

import torch
import lightning as L
from lightning.fabric.strategies import FSDPStrategy

from litgpt.args import TrainArgs
from litgpt.config import Config
from litgpt.model import GPT, Block
from litgpt.data import Alpaca2k
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
    chunked_cross_entropy,
    load_checkpoint,
    num_parameters,
    get_default_supported_precision,
)


def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
    # linear warmup followed by cosine annealing
    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)
    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))
    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


def main(
    checkpoint_dir: Path,
    devices: int = 8,
    num_nodes: int = 1,
    precision: str = "bf16-true",
    seed: int = 1337,
) -> None:
    torch.set_float32_matmul_precision("high")

    train_args = TrainArgs(
        save_interval=1000,
        log_interval=1,
        global_batch_size=64,
        micro_batch_size=4,
        lr_warmup_steps=1000,
        epochs=10,
        max_steps=10000,
    )

    strategy = FSDPStrategy(
        auto_wrap_policy={Block},
        activation_checkpointing_policy={Block},
        state_dict_type="full",
        limit_all_gathers=True,
        cpu_offload=False,
    )
    
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        num_nodes=num_nodes,
        strategy=strategy,
        precision=precision,
    )
    fabric.launch()
    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)
    
    dataset = Alpaca2k()
    tokenizer = Tokenizer(str(checkpoint_dir))
    dataset.connect(tokenizer, batch_size=train_args.micro_batch_size, max_seq_length=512)
    with fabric.rank_zero_first():
        dataset.prepare_data()
    dataset.setup()
    dataloader = dataset.train_dataloader()
    dataloader = fabric.setup_dataloaders(dataloader)

    checkpoint_path = str(checkpoint_dir / "lit_model.pth")
    config = Config.from_file(checkpoint_dir / "model_config.yaml")
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    # Loading the checkpoint here (before fabric.setup) reproduces the abnormal loss:
    # load_checkpoint(fabric, model, checkpoint_path)
    model = fabric.setup(model)
    # Loading it after fabric.setup gives the expected loss:
    load_checkpoint(fabric, model, checkpoint_path)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))
    optimizer = fabric.setup_optimizers(optimizer)
    scheduler = get_lr_scheduler(optimizer, warmup_steps=train_args.lr_warmup_steps, max_steps=train_args.max_steps)

    model.train()
    for epoch in range(train_args.epochs):
        for step, batch in enumerate(dataloader, 1):
            input_ids, target = batch["input_ids"], batch["labels"]
            logits = model(input_ids)
            loss = chunked_cross_entropy(logits[..., :-1, :], target[..., 1:])
            fabric.backward(loss)

            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            fabric.print(f"{step = } | loss train: {loss.detach().item()}")


if __name__ == "__main__":
    checkpoint_dir = Path("./Qwen2.5-1.5B/")

    main(checkpoint_dir)

Error messages and logs

No response

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4.1):
#- Python version (e.g., 3.10):
#- OS (e.g., Linux):
#- CUDA/cuDNN version: 12.1
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source): pip

More info

No response
