Commit f4857cb

tests
1 parent 7f29c66 commit f4857cb

2 files changed: +4 −128 lines changed

tests/torchtune/training/checkpointing/test_checkpointer.py (−124 lines)
@@ -754,68 +754,6 @@ def test_load_save_checkpoint_single_file(
 
         assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) - 1
 
-    '''
-    def test_load_save_checkpoint_single_file_with_dcp(
-        self,
-        single_file_checkpointer: FullModelHFCheckpointer,
-        mistral_reward_model_hf_checkpoint: Path,
-    ):
-        """
-        Test ``load_checkpoint`` and ``save_checkpoint`` method within the
-        FullModelHFCheckpointer for a single checkpoint file for a mistral reward model
-        with DCP.
-
-        We test:
-        * ``load_checkpoint`` loads the right sets of keys
-        * Internal state of the checkpointer is correctly updated
-        * Converted checkpoint can be loaded into the `mistral_classifier` torchtune implementation
-        * Saved checkpoint keys match the original checkpoint
-        """
-        single_file_checkpointer._enable_dcp = True
-        # Read the state dict directly from file using torch.load. This will be the state
-        # dict we test against
-        checkpoint_file = mistral_reward_model_hf_checkpoint
-        orig_state_dict = safe_torch_load(checkpoint_file)
-
-        # Converted state dict from the checkpointer
-        state_dict = single_file_checkpointer.load_checkpoint()
-        # Check that we've loaded all the keys minus the output bias
-        assert len(state_dict["model"].keys()) == len(orig_state_dict.keys()) - 1
-
-        # the keys in original state dict should match up with the keys in the weight_map
-        for key in orig_state_dict.keys():
-            if "inv_freq" in key or "output.bias" in key:
-                continue
-            assert key in single_file_checkpointer._weight_map
-
-        # loading the state dict into the model implementation should work correctly
-        model = mistral.mistral_classifier(
-            num_classes=1,
-            vocab_size=_VOCAB_SIZE,
-            num_layers=1,
-            num_heads=_NUM_HEADS,
-            num_kv_heads=_NUM_KV_HEADS,
-            embed_dim=_DIM,
-            intermediate_dim=_HIDDEN_DIM,
-            max_seq_len=128,
-        )
-        model.load_state_dict(state_dict["model"])
-
-        single_file_checkpointer.save_checkpoint(state_dict, epoch=1)
-
-        # Reload the output checkpoint file and compare to the original checkpoint. This
-        # assumes we know what the name of the file is. This is fine, breaking this logic
-        # should be something we capture through this test
-        output_file = Path.joinpath(
-            checkpoint_file.parent.parent / "output_dir",
-            "epoch_1",
-            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
-        ).with_suffix(".safetensors")
-        output_state_dict = safe_torch_load(output_file)
-
-        assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) - 1
-    '''
-
 
 class TestHFGemmaFullModelCheckpointer:
     @pytest.fixture
@@ -983,65 +921,3 @@ def test_load_save_checkpoint_single_file(
         output_state_dict = safe_torch_load(output_file)
 
         assert len(output_state_dict.keys()) == len(orig_state_dict.keys())
-
-    '''
-    def test_load_save_checkpoint_single_file_with_dcp(
-        self,
-        single_file_checkpointer: FullModelHFCheckpointer,
-        gemma_hf_checkpoint: Path,
-    ):
-        """
-        Test ``load_checkpoint`` and ``save_checkpoint`` method within the
-        FullModelHFCheckpointer for a single checkpoint file for Gemma with DCP enabled.
-
-        We test:
-        * ``load_checkpoint`` loads the right sets of keys
-        * Internal state of the checkpointer is correctly updated
-        * Converted checkpoint can be loaded into the `gemma` TorchTune implementation
-        * lm_head weights are tied to the embed_tokens weights during saving
-        * lmhead weights are popped during loading
-        """
-        single_file_checkpointer._enable_dcp = True
-        # Read the state dict directly from file using torch.load. This will be the state
-        # dict we test against
-        checkpoint_file = gemma_hf_checkpoint
-        orig_state_dict = safe_torch_load(checkpoint_file)
-
-        # Converted state dict from the checkpointer
-
-        state_dict = single_file_checkpointer.load_checkpoint()
-        assert len(state_dict["model"].keys()) == len(orig_state_dict.keys())
-
-        # the keys in original state dict should match up with the keys in the weight_map
-        for key in orig_state_dict.keys():
-            if "inv_freq" in key:
-                continue
-            assert key in single_file_checkpointer._weight_map
-
-        # loading the state dict into the model implementation should work correctly
-        model = gemma.gemma(
-            vocab_size=_VOCAB_SIZE,
-            num_layers=1,
-            num_heads=_NUM_HEADS,
-            head_dim=_HEAD_DIM,
-            num_kv_heads=1,
-            embed_dim=_DIM,
-            intermediate_dim=_HIDDEN_DIM,
-            max_seq_len=128,
-        )
-        model.load_state_dict(state_dict["model"])
-
-        single_file_checkpointer.save_checkpoint(state_dict, epoch=1)
-
-        # Reload the output checkpoint file and compare to the original checkpoint. This
-        # assumes we know what the name of the file is. This is fine, breaking this logic
-        # should be something we capture through this test
-        output_file = Path.joinpath(
-            checkpoint_file.parent.parent / "output_dir",
-            "epoch_1",
-            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
-        ).with_suffix(".safetensors")
-        output_state_dict = safe_torch_load(output_file)
-
-        assert len(output_state_dict.keys()) == len(orig_state_dict.keys())
-    '''
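
Both removed tests computed the expected output shard filename by hand before comparing state dicts. For reference, here is a minimal sketch of that filename logic; the SHARD_FNAME template shown is an assumption modeled on the Hugging Face sharding convention, so check torchtune's checkpointing constants for the real value.

```python
# Minimal sketch of the removed tests' output-path logic.
# SHARD_FNAME here is an assumed template (Hugging Face style);
# torchtune defines the real constant.
from pathlib import Path

SHARD_FNAME = "model-{cpt_idx}-of-{num_shards}"  # assumption, not verified

shard_name = SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5))
output_file = Path("output_dir", "epoch_1", shard_name).with_suffix(".safetensors")
print(output_file)  # output_dir/epoch_1/model-00001-of-00001.safetensors
```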

torchtune/training/checkpointing/_checkpointer.py (+4 −4 lines)
@@ -442,7 +442,7 @@ def __init__(
                 f"Got {self._fs} and {output_fs} instead."
             )
 
-        self._fs.mkdir(output_dir, exist_ok=True)
+        self._fs.mkdirs(output_dir, exist_ok=True)
 
         # weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
         # parition the state dict into output checkpoint files. This is updated during checkpoint
@@ -813,7 +813,7 @@ def save_checkpoint(
                 output_path = os.path.join(
                     self._output_dir, f"epoch_{epoch}", shard_name
                 )
-                self._fs.mkdir(os.path.dirname(output_path), exist_ok=True)
+                self._fs.mkdirs(os.path.dirname(output_path), exist_ok=True)
                 if not self._safe_serialization:
                     output_path = output_path + ".bin"
                     torch.save(model_state_dict, output_path)
@@ -865,7 +865,7 @@ def save_checkpoint(
                 os.path.join(self._output_dir, f"epoch_{epoch}", ADAPTER_MODEL_FNAME)
                 + ".pt"
             )
-            self._fs.mkdir(os.path.dirname(output_path), exist_ok=True)
+            self._fs.mkdirs(os.path.dirname(output_path), exist_ok=True)
             torch.save(state_dict[training.ADAPTER_KEY], output_path)
             logger.info(
                 "Adapter checkpoint of size "
@@ -894,7 +894,7 @@ def save_checkpoint(
             output_path = os.path.join(
                 self._output_dir, f"epoch_{epoch}", ADAPTER_MODEL_FNAME
             )
-            self._fs.mkdir(os.path.dirname(output_path), exist_ok=True)
+            self._fs.mkdirs(os.path.dirname(output_path), exist_ok=True)
            if not self._safe_serialization:
                 output_path = output_path + ".bin"
                 torch.save(state_dict[training.ADAPTER_KEY], output_path)
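
All four functional changes in this file swap self._fs.mkdir for self._fs.mkdirs. A minimal sketch of the difference, assuming self._fs is an fsspec AbstractFileSystem: mkdirs is an alias for makedirs, whose signature accepts exist_ok, so re-creating an existing epoch directory is a no-op; mkdir's signature is mkdir(path, create_parents=True, **kwargs), and on some backends (such as the local filesystem) it raises FileExistsError when the directory already exists.

```python
# Minimal sketch, assuming self._fs is an fsspec filesystem (the local
# backend is used here purely for illustration).
import os
import tempfile

import fsspec

fs = fsspec.filesystem("file")
out_dir = os.path.join(tempfile.mkdtemp(), "epoch_1")

fs.mkdirs(out_dir, exist_ok=True)  # creates missing parents as needed
fs.mkdirs(out_dir, exist_ok=True)  # idempotent: no error on repeat calls

# fs.mkdir(out_dir, exist_ok=True) would forward exist_ok into **kwargs;
# on the local backend mkdir raises FileExistsError for an existing path,
# which is why repeated saves into the same epoch directory need mkdirs.
```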

Comments (0)