Use DCP in HFCheckpointer to read/write directly to HuggingFace #2494

Open · wants to merge 24 commits into base: main

Changes from 1 commit
amend
ankitageorge committed Mar 11, 2025
commit cfd291efbc261f6b0670ede43fb018795a8b26ac
9 changes: 2 additions & 7 deletions recipes/configs/llama3_1/8B_full.yaml
@@ -39,13 +39,8 @@ model:

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
-  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
-  checkpoint_files: [
-    model-00001-of-00004.safetensors,
-    model-00002-of-00004.safetensors,
-    model-00003-of-00004.safetensors,
-    model-00004-of-00004.safetensors
-  ]
+  checkpoint_dir: hf://meta-llama/Llama-3.1-8B-Instruct/
+  checkpoint_files: null
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
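
With this change the recipe points the checkpointer directly at a Hugging Face Hub repo and lets it discover the shard files itself, instead of listing local safetensors paths. A rough usage sketch, not taken from this PR (the output_dir value is a placeholder; keyword names mirror the YAML keys above):

# Hypothetical Python equivalent of the YAML config above.
from torchtune.training import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="hf://meta-llama/Llama-3.1-8B-Instruct/",
    checkpoint_files=None,   # discovered from the Hub repo, per this PR
    recipe_checkpoint=None,
    output_dir="/tmp/torchtune/llama3_1_8B/full",  # placeholder path
    model_type="LLAMA3",
)
state_dict = checkpointer.load_checkpoint()
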
63 changes: 37 additions & 26 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gc
import json
import os
import re
@@ -13,25 +12,22 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Protocol, Union

import fsspec

import torch
import torch.distributed as dist
from safetensors.torch import save_file

from torch.distributed.checkpoint import (
_HuggingFaceStorageReader,
_HuggingFaceStorageWriter,
async_save,
FileSystemReader,
FileSystemWriter,
load,
save,
)

# Replace this with something that actually works
if version("torch") > (2, 7):
from torch.distributed.checkpoint import (
_HuggingFaceStorageReader,
_HuggingFaceStorageWriter,
)
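
The guarded import above is flagged by its own comment as a placeholder. A minimal sketch of a check that actually compares version numbers, assuming the HF storage classes only ship with torch >= 2.7 (torch.__version__ is a string, so it must be parsed rather than compared to a tuple):

# Sketch only, not part of this commit.
import torch

_major, _minor = (int(p) for p in torch.__version__.split(".")[:2])
if (_major, _minor) >= (2, 7):
    from torch.distributed.checkpoint import (
        _HuggingFaceStorageReader,
        _HuggingFaceStorageWriter,
    )
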

from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import (
@@ -428,40 +424,47 @@ def __init__(
)

self._safe_serialization = safe_serialization
self._checkpoint_dir = Path(checkpoint_dir)
self._checkpoint_dir = checkpoint_dir
self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)
self._output_dir = output_dir
check_outdir_not_in_ckptdir(
ckpt_dir=self._checkpoint_dir, out_dir=self._output_dir
)
self._output_dir.mkdir(parents=True, exist_ok=True)

# Use DCP specs if looking at the fsspec for HF
if checkpoint_dir.startswith("hf://"):
self._straight_hf = True
assert checkpoint_files is None
self._dcp_hf = True
self._fs = fsspec.filesystem(
"hf",
)
checkpoint_files = []
else:
self._dcp_hf = False
self._fs = fsspec.filesystem("file")

self._fs.mkdir(output_dir, parents=True, exist_ok=True)

# weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
# partition the state dict into output checkpoint files. This is updated during checkpoint
# load
self._weight_map: Dict[str, str] = None

# the config.json file contains model params needed for state dict conversion
self._config = json.loads(
Path.joinpath(self._checkpoint_dir, "config.json").read_text()
)
self._config = None
print(self._fs.ls(self._checkpoint_dir))
with self._fs.open(
os.path.join(self._checkpoint_dir, "config.json"), "r"
) as json_file:
self._config = json.loads(json_file.read())

# repo_id is necessary for when saving an adapter config, so it's compatible with HF.
# This json file is produced and saved in the download step.
# contents are {"repo_id": "some_model/some_model_version"}
repo_id_path = Path.joinpath(self._checkpoint_dir, REPO_ID_FNAME).with_suffix(
".json"
)
repo_id_path = os.path.join(self._checkpoint_dir, REPO_ID_FNAME) + ".json"

self.repo_id = None
if repo_id_path.exists():
with open(repo_id_path, "r") as json_file:
if self._fs.exists(repo_id_path):
with self._fs.open(repo_id_path, "r") as json_file:
data = json.load(json_file)
self.repo_id = data.get("repo_id")
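
A note on the fsspec calls above: the "hf" protocol is assumed to be provided by huggingface_hub, which registers HfFileSystem with fsspec so that hf:// repo paths can be listed and opened like a read-only filesystem. A small standalone sketch of that assumption:

# Assumes huggingface_hub is installed (it registers the "hf" protocol with fsspec).
import fsspec

fs = fsspec.filesystem("hf")
print(fs.ls("hf://meta-llama/Llama-3.1-8B-Instruct/", detail=False))
with fs.open("hf://meta-llama/Llama-3.1-8B-Instruct/config.json", "r") as f:
    print(f.read()[:200])
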

@@ -545,13 +548,21 @@ def load_checkpoint(self) -> Dict[str, Any]:
# DCP load using the storage reader
hf_storage_reader = _HuggingFaceStorageReader(path=self._checkpoint_dir)
metadata = hf_storage_reader.read_metadata()
planner = _EmptyStateDictLoadPlanner()
planner.set_up_planner(state_dict={}, metadata=metadata)
merged_state_dict = load(
state_dict={},
state_dict = {}
for key in metadata.state_dict_metadata.keys():
# arbitrary value to ensure that the state_dict is not empty
state_dict[key] = torch.zeros(1)

load(
state_dict=state_dict,
storage_reader=_HuggingFaceStorageReader(path=self._checkpoint_dir),
load_planner=planner,
)

merged_state_dict = state_dict
print("num keys in merged state dict: ", len(merged_state_dict.keys()))
for merged_key, merged_value in merged_state_dict.items():
print("size of ", merged_key, merged_value.size())
break
else:
merged_state_dict = self._manually_merge_sharded_state_dicts()
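
The DCP load path above seeds the state_dict from the reader's metadata before calling load. Pulled out into a standalone form, the pattern looks roughly like this (same private _HuggingFaceStorageReader API as the diff, so it assumes a torch build that ships it):

# Sketch mirroring the load path above; not the PR's exact code.
import torch
from torch.distributed.checkpoint import _HuggingFaceStorageReader, load


def load_hf_checkpoint(checkpoint_dir: str) -> dict:
    reader = _HuggingFaceStorageReader(path=checkpoint_dir)
    metadata = reader.read_metadata()
    # Seed every key with a placeholder tensor so the state_dict is not empty;
    # load() then fills the entries from the safetensors shards.
    state_dict = {key: torch.zeros(1) for key in metadata.state_dict_metadata}
    load(state_dict=state_dict, storage_reader=reader)
    return state_dict
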

16 changes: 9 additions & 7 deletions torchtune/training/checkpointing/_utils.py
@@ -403,10 +403,10 @@ def copy_files(


def get_recipe_checkpoint_path(
output_dir: Path,
output_dir: str,
recipe_checkpoint: Optional[str] = None,
should_load_recipe_state: bool = False,
) -> Optional[Path]:
) -> Optional[str]:
"""
If recipe_checkpoint is None, look for recipe_state.pt in {output_dir}/{RECIPE_STATE_DIRNAME}/recipe_state.pt.
This is to make it easier to resume from a previous run, without having to specify the recipe_checkpoint.
@@ -437,15 +437,15 @@ def get_recipe_checkpoint_path(
"If should_load_recipe_state is True, recipe_checkpoint file must be provided."
)

return Path(recipe_checkpoint_path)
return recipe_checkpoint_path


def get_adapter_checkpoint_path(
output_dir: Path,
output_dir: str,
adapter_checkpoint: Optional[str] = None,
should_load_recipe_state: bool = False,
pattern: str = r"^epoch_(\d+)",
) -> Optional[Path]:
) -> Optional[str]:
r"""
If adapter_checkpoint is None, look for it in {output_dir}/epoch_{latest_epoch}/adapter_model.pt.
This is to make it easier to resume from a previous run, without having to specify the adapter_checkpoint.
@@ -479,7 +479,7 @@ def get_adapter_checkpoint_path(
if os.path.exists(tentative_adapter_checkpoint_path):
adapter_checkpoint_path = tentative_adapter_checkpoint_path

return Path(adapter_checkpoint_path) if adapter_checkpoint_path else None
return adapter_checkpoint_path if adapter_checkpoint_path else None


def get_model_checkpoint_path(
@@ -581,11 +581,13 @@ def validate_checkpoint_files(
return checkpoint_paths


def check_outdir_not_in_ckptdir(ckpt_dir: Path, out_dir: Path) -> bool:
def check_outdir_not_in_ckptdir(ckpt_dir: str, out_dir: str) -> bool:
"""
Checks that the output directory is not equal to or a subdirectory of the checkpoint directory.
This is necessary to avoid making copies of copies when getting config files from ckpt_dir.
"""
if ckpt_dir.startswith("hf://"):
return True

# Resolve the absolute paths to avoid issues with relative paths
_ckpt_dir = ckpt_dir.resolve()
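
Since ckpt_dir and out_dir are plain strings after this change, the resolve() call shown above would also need an os.path-based replacement. A hypothetical sketch of the full check (the exact error handling in torchtune may differ):

# Hypothetical string-path variant; not part of this commit.
import os


def check_outdir_not_in_ckptdir(ckpt_dir: str, out_dir: str) -> bool:
    # Remote Hub paths can never contain the local output directory.
    if ckpt_dir.startswith("hf://"):
        return True
    _ckpt_dir = os.path.realpath(ckpt_dir)
    _out_dir = os.path.realpath(out_dir)
    # commonpath == ckpt_dir means out_dir is the ckpt_dir itself or inside it.
    if os.path.commonpath([_ckpt_dir, _out_dir]) == _ckpt_dir:
        raise ValueError(
            f"output_dir ({out_dir}) cannot be inside checkpoint_dir ({ckpt_dir})."
        )
    return True
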