Rough draft for integration with DCP HF Storage Reader / Writer #2435

Draft: wants to merge 3 commits into main

Changes from all commits
14 changes: 3 additions & 11 deletions recipes/configs/llama3_2/3B_full_single_device.yaml
@@ -1,10 +1,6 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Llama3.2 3B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-3.2-3B-Instruct --output-dir /tmp/Llama-3.2-3B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
@@ -25,7 +21,7 @@ output_dir: /tmp/torchtune/llama3_2_3B/full_single_device # /tmp may be deleted
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model
path: hf://meta-llama/Llama-3.2-3B-Instruct/original/tokenizer.model
max_seq_len: null

# Dataset
@@ -41,12 +37,8 @@ model:

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-3.2-3B-Instruct/
checkpoint_files: [
model-00001-of-00002.safetensors,
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
checkpoint_dir: hf://meta-llama/Llama-3.2-3B-Instruct
checkpoint_files: null
output_dir: ${output_dir}
model_type: LLAMA3_2
resume_from_checkpoint: False
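Note on the hf:// paths above: instead of a locally downloaded /tmp directory, the config now points the tokenizer and checkpointer at the Hugging Face Hub. As an illustration only (the helper below is hypothetical and not part of this PR, which delegates resolution to the checkpointer and the DCP storage layer), a single file under such a URI could be materialized locally with huggingface_hub:

from huggingface_hub import hf_hub_download

def resolve_hf_uri(uri: str) -> str:
    # Hypothetical helper: split "hf://<org>/<repo>/<path/in/repo>" into a repo id
    # and a filename, then download (or reuse from cache) that single file.
    org, repo, *file_parts = uri.removeprefix("hf://").split("/")
    return hf_hub_download(repo_id=f"{org}/{repo}", filename="/".join(file_parts))

# e.g. resolve_hf_uri("hf://meta-llama/Llama-3.2-3B-Instruct/original/tokenizer.model")
# returns the local cache path of tokenizer.model (a Hub token may be required for gated repos).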
206 changes: 120 additions & 86 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -16,6 +16,7 @@
import torch
import torch.distributed as dist
from safetensors.torch import save_file

from torch.distributed.checkpoint import (
async_save,
FileSystemReader,
@@ -24,6 +25,13 @@
save,
)

# Replace this with something that actually works
if version("torch") > (2, 7):
from torch.distributed.checkpoint import (
_HuggingFaceStorageReader,
_HuggingFaceStorageWriter,
)
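The guard above is flagged by the author as a placeholder: importlib.metadata's version() returns a string, so comparing it against the tuple (2, 7) will not behave as intended. One possible replacement, assuming the prototype reader/writer first ship in torch 2.7+ (an assumption, not something the PR confirms), is a parsed-version check:

from importlib.metadata import version

from packaging.version import Version  # assumes packaging is available in the environment

# Compare the parsed release tuple instead of the raw version string.
_HF_DCP_AVAILABLE = Version(version("torch")).release >= (2, 7)

if _HF_DCP_AVAILABLE:
    from torch.distributed.checkpoint import (
        _HuggingFaceStorageReader,
        _HuggingFaceStorageWriter,
    )
else:
    _HuggingFaceStorageReader = _HuggingFaceStorageWriter = None  # feature unavailable on older torch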

from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import (
@@ -430,6 +438,14 @@ def __init__(
)
self._output_dir.mkdir(parents=True, exist_ok=True)

# Use the DCP HuggingFace storage reader/writer when the checkpoint dir uses the hf:// scheme
if checkpoint_dir.startswith("hf://"):
self._storage_reader = _HuggingFaceStorageReader
self._storage_writer = _HuggingFaceStorageWriter
assert checkpoint_files is None
else:
self._storage_reader, self._storage_writer = None, None
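One thing worth noting about the branch above: it stores the reader/writer classes rather than instances, so the hf:// directory still has to be turned into a repository path when they are constructed at load/save time. A hypothetical parsing helper (illustrative only, not part of the PR) could look like:

def _parse_hf_uri(checkpoint_dir: str) -> str:
    # Hypothetical helper: "hf://meta-llama/Llama-3.2-3B-Instruct" -> "meta-llama/Llama-3.2-3B-Instruct"
    if not checkpoint_dir.startswith("hf://"):
        raise ValueError(f"Expected an hf:// URI, got: {checkpoint_dir}")
    return checkpoint_dir[len("hf://"):].rstrip("/")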

# weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
# parition the state dict into output checkpoint files. This is updated during checkpoint
# load
@@ -484,6 +500,25 @@ def __init__(
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"
)

def _manually_merge_sharded_state_dicts(self):
# merged state_dict contains keys and weights from all the checkpoint files
merged_state_dict: Dict[str, torch.Tensor] = {}
# _checkpoint_paths are already sorted so simply enumerate to generate the right id
for cpt_idx, cpt_path in enumerate(self._checkpoint_paths):
state_dict = safe_torch_load(cpt_path)
for key, value in state_dict.items():
# Ensure that the state dict is a flat dict of keys and tensors. Breaking this assumption
# will break recipe code
if not isinstance(value, torch.Tensor):
raise ValueError(
f"Expected all values in the state dict to be torch.Tensor. "
f"Found {type(value)} instead."
)
# idx is written in the 4 digit format (eg: 0001, 0002, etc.)
self._weight_map[key] = f"{cpt_idx + 1:04}"
merged_state_dict.update(state_dict)
return merged_state_dict

def load_checkpoint(self) -> Dict[str, Any]:
"""
Load HF checkpoint from file.
@@ -504,32 +539,17 @@ def load_checkpoint(self) -> Dict[str, Any]:

self._weight_map = {}

# converted_state_dict is the final state_dict passed to the recipe after the
# keys are converted into the torchtune format. This optionally also contains
# the recipe state and adapter weights
converted_state_dict: Dict[str, Dict[str, torch.Tensor]] = {}

if self._storage_reader is not None:
# DCP load using the storage reader
merged_state_dict = load()
else:
merged_state_dict = self._manually_merge_sharded_state_dicts()

# delete the state_dict to free up memory; TODO check if this del is needed
del state_dict
gc.collect()
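The bare load() call above is a placeholder. A minimal sketch of how the DCP branch could look, assuming the prototype reader takes the repository path as its constructor argument and that the destination state dict has already been allocated (DCP's load fills the dict it is given in place rather than returning a new one); these are assumptions about the prototype API, not confirmed behavior:

from typing import Dict

import torch
from torch.distributed.checkpoint import load
from torch.distributed.checkpoint import _HuggingFaceStorageReader  # prototype, torch >= 2.7 assumed

def dcp_load_from_hf(repo_path: str, dest_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Sketch only: point the prototype reader at the Hugging Face repo and let
    # DCP fill the pre-allocated destination state dict in place.
    reader = _HuggingFaceStorageReader(repo_path)
    load(dest_state_dict, storage_reader=reader)
    return dest_state_dict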
if self._model_type == ModelType.PHI3_MINI:
log_rank_zero(
logger=logger,
@@ -634,6 +654,25 @@ def load_checkpoint(self) -> Dict[str, Any]:

return converted_state_dict

def _generate_splits_for_sharded_state_dict(self, state_dict: Dict[str, torch.Tensor]):
# split the state_dict into separate dicts, one for each output checkpoint file
# e.g. split_state_dicts= {
# "0001": {"key1": tensor1, "key2": tensor2},
# "0002": {"key3": tensor3}
# }
split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {}
total_size = 0
for key, weight in state_dict[training.MODEL_KEY].items():
cpt_idx = self._weight_map[key]

# initialize dict
if cpt_idx not in split_state_dicts:
split_state_dicts[cpt_idx] = {}

split_state_dicts[cpt_idx].update({key: weight})
total_size += weight.numel() * weight.element_size()
return split_state_dicts, total_size

def save_checkpoint(
self,
state_dict: Dict[str, Any],
@@ -729,76 +768,71 @@ def save_checkpoint(
head_dim=self._config.get("head_dim", None),
)

if self._storage_writer is not None:
# DCP save using the storage writer
save()
else:
(
split_state_dicts,
total_size,
) = self._generate_splits_for_sharded_state_dict(state_dict)

# write the partitioned state dicts to the right checkpoint file
# e.g. model-00001-of-00004.safetensors, model-00002-of-00004.safetensors, etc
num_shards = len(split_state_dicts)
map_original_name_to_new_name = {}
for cpt_idx, model_state_dict in split_state_dicts.items():
# TODO: We should probably use the original shard name and just add a prefix
# however, having the SHARD_FNAME standardizes our checkpoints
shard_name = SHARD_FNAME.format(
cpt_idx=f"{cpt_idx}".zfill(5),
num_shards=f"{num_shards}".zfill(5),
)
map_original_name_to_new_name[cpt_idx] = shard_name
output_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", shard_name
)
output_path.parent.mkdir(parents=True, exist_ok=True)
if not self._safe_serialization:
output_path = output_path.with_suffix(".bin")
torch.save(model_state_dict, output_path)
else:
output_path = output_path.with_suffix(".safetensors")
save_file(
model_state_dict, output_path, metadata={"format": "pt"}
)

logger.info(
    "Model checkpoint of size "
    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
    f"saved to {output_path}"
)

# Save the appropriate index file based on serialization format
# e.g. {metadata: {total_size: 1234}, weight_map: {"key1": "model_0001.safetensors", "key2": "model_0002.safetensors"}}
if self._safe_serialization:
    weight_map = {
        k: map_original_name_to_new_name[cpt_idx] + ".safetensors"
        for k, cpt_idx in self._weight_map.items()
    }
    index_file_name = SAFETENSOR_INDEX_FNAME
else:
    weight_map = {
        k: map_original_name_to_new_name[cpt_idx] + ".bin"
        for k, cpt_idx in self._weight_map.items()
    }
    index_file_name = TORCH_INDEX_FNAME

index_path = Path.joinpath(
    self._output_dir, f"epoch_{epoch}", index_file_name
)

index_data = {
    "metadata": {"total_size": total_size},
    "weight_map": weight_map,
}
with open(index_path, "w") as f:
    json.dump(index_data, f, indent=2)
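For completeness, the save() placeholder in the DCP branch above would need the analogous writer. A hypothetical sketch under the same assumptions as the load example (the constructor takes the output location, and the writer is expected to handle sharding and index-file generation itself, which is why the manual splitting and index code sits only in the else branch):

from typing import Dict

import torch
from torch.distributed.checkpoint import save
from torch.distributed.checkpoint import _HuggingFaceStorageWriter  # prototype, torch >= 2.7 assumed

def dcp_save_hf_layout(output_path: str, model_state_dict: Dict[str, torch.Tensor]) -> None:
    # Sketch only: write the flat model state dict in the Hugging Face
    # safetensors layout via the prototype storage writer.
    writer = _HuggingFaceStorageWriter(output_path)
    save(model_state_dict, storage_writer=writer)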

if training.ADAPTER_KEY in state_dict:
