Use DCP in HFCheckpointer to read/write directly to HuggingFace #2494

Open
wants to merge 24 commits into main
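For context, here is a minimal sketch of the generic DCP (`torch.distributed.checkpoint`) save/load pattern this change builds on. It is an illustration only, using the stock filesystem backend rather than the HuggingFace safetensors wiring added in this PR; the toy model and `/tmp` path are placeholders.

```python
# Minimal DCP sketch (assumption: generic torch.distributed.checkpoint usage,
# not the HFCheckpointer implementation from this PR).
import torch
import torch.distributed.checkpoint as dcp

# Save: DCP writes the state dict without first gathering it on a single rank.
model = torch.nn.Linear(8, 8)
dcp.save({"model": model.state_dict()}, checkpoint_id="/tmp/dcp_demo")

# Load: DCP loads in place into a pre-allocated state dict of the same shape.
new_model = torch.nn.Linear(8, 8)
to_load = {"model": new_model.state_dict()}
dcp.load(to_load, checkpoint_id="/tmp/dcp_demo")
new_model.load_state_dict(to_load["model"])
```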
86 changes: 86 additions & 0 deletions tests/torchtune/training/checkpointing/test_checkpointer.py
@@ -11,6 +11,7 @@

import pytest

import safetensors
import torch
from torch import randn

@@ -160,6 +161,12 @@ def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2):

torch.save(state_dict_1, checkpoint_file_1)
torch.save(state_dict_2, checkpoint_file_2)
safetensors.torch.save_file(
state_dict_1, checkpoint_dir / "model-00001-of-00002.safetensors"
)
safetensors.torch.save_file(
state_dict_2, checkpoint_dir / "model-00002-of-00002.safetensors"
)

config = {
"hidden_size": 64,
@@ -169,6 +176,14 @@ def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2):
config_file = Path.joinpath(checkpoint_dir, "config.json")
with config_file.open("w") as f:
json.dump(config, f)
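# Write a minimal model.safetensors.index.json; its "weight_map" maps each
# parameter name to the safetensors shard that stores it, mirroring the
# sharded HuggingFace checkpoint layout that the DCP path reads.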
metadata_file = Path.joinpath(checkpoint_dir, "model.safetensors.index.json")
metadata = {"weight_map": {}}
for key in state_dict_1.keys():
metadata["weight_map"][key] = "model-00001-of-00002.safetensors"
for key in state_dict_2.keys():
metadata["weight_map"][key] = "model-00002-of-00002.safetensors"
with metadata_file.open("w") as f:
json.dump(metadata, f)

return (checkpoint_file_1, checkpoint_file_2)

@@ -504,6 +519,77 @@ def test_save_checkpoint_in_peft_format(
actual_adapter_state_dict[k], expected_adapter_state_dict[new_k]
)

def test_save_load_checkpoint_multiple_file_with_dcp(
self,
multi_file_checkpointer: FullModelHFCheckpointer,
llama2_hf_checkpoints: Tuple[Path, Path],
):
"""
Test the ``load_checkpoint`` and ``save_checkpoint`` methods of the
FullModelHFCheckpointer for multiple checkpoint files with DCP enabled.

We test:
* ``load_checkpoint`` loads the right set of keys
* Internal state of the checkpointer is correctly updated
* Converted checkpoint can be loaded into the torchtune llama2 implementation
* ``save_checkpoint`` writes safetensors shards with the expected keys
"""
multi_file_checkpointer._enable_dcp = True
# Read the state dict directly from files
checkpoint_file_1, checkpoint_file_2 = llama2_hf_checkpoints
orig_state_dict_1 = safe_torch_load(checkpoint_file_1)
orig_state_dict_2 = safe_torch_load(checkpoint_file_2)

# merged state dict from checkpointer
state_dict = multi_file_checkpointer.load_checkpoint()

# We ignore inv_freq as is standard practice
assert len(state_dict["model"].keys()) + 2 == len(
orig_state_dict_1.keys()
) + len(orig_state_dict_2.keys())

# the keys in the original state dicts should be present in the checkpointer's weight_map
for key in orig_state_dict_1.keys():
if "inv_freq" in key:
continue
assert key in multi_file_checkpointer._weight_map

for key in orig_state_dict_2.keys():
if "inv_freq" in key:
continue
assert key in multi_file_checkpointer._weight_map

# finally loading into the model should work
model = llama2.llama2(
vocab_size=_VOCAB_SIZE,
num_layers=2,
num_heads=_NUM_HEADS,
num_kv_heads=_NUM_KV_HEADS,
embed_dim=_DIM,
max_seq_len=128,
)
model.load_state_dict(state_dict["model"])

multi_file_checkpointer.save_checkpoint(state_dict, epoch=3)

# Reload the output checkpoint files and compare them to the original checkpoint.
# This assumes we know what the output files are named; that is fine, since a
# change to the naming logic is something this test should catch.
output_file_1 = Path.joinpath(
checkpoint_file_1.parent.parent / "output_dir",
"epoch_3",
SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="2".zfill(5)),
).with_suffix(".safetensors")
output_file_2 = Path.joinpath(
checkpoint_file_2.parent.parent / "output_dir",
"epoch_3",
SHARD_FNAME.format(cpt_idx="2".zfill(5), num_shards="2".zfill(5)),
).with_suffix(".safetensors")
output_state_dict_1 = safe_torch_load(output_file_1)
output_state_dict_2 = safe_torch_load(output_file_2)

assert len(output_state_dict_1.keys()) + 1 == len(orig_state_dict_1.keys())
assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys())


class TestHFMistralRewardModelFullModelCheckpointer:
@pytest.fixture
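For reviewers trying the change locally, a hypothetical usage sketch that mirrors the test above. The import path, constructor arguments, file paths, and `model_type` value are assumptions and not part of this diff; only the private `_enable_dcp` flag and the `load_checkpoint`/`save_checkpoint` calls come from the test.

```python
# Hypothetical usage sketch: constructor arguments and import path are assumed;
# only _enable_dcp, load_checkpoint, and save_checkpoint come from the test.
from torchtune.training import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/llama2_hf",
    checkpoint_files=[
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ],
    model_type="LLAMA2",
    output_dir="/tmp/llama2_hf/output_dir",
)
checkpointer._enable_dcp = True  # opt into the DCP read/write path under test

state_dict = checkpointer.load_checkpoint()  # merged, torchtune-format keys
# ... fine-tune ...
checkpointer.save_checkpoint(state_dict, epoch=0)  # sharded .safetensors output
```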