@@ -754,68 +754,6 @@ def test_load_save_checkpoint_single_file(
         assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) - 1

-    '''
-    def test_load_save_checkpoint_single_file_with_dcp(
-        self,
-        single_file_checkpointer: FullModelHFCheckpointer,
-        mistral_reward_model_hf_checkpoint: Path,
-    ):
-        """
-        Test ``load_checkpoint`` and ``save_checkpoint`` method within the
-        FullModelHFCheckpointer for a single checkpoint file for a mistral reward model
-        with DCP.
-
-        We test:
-        * ``load_checkpoint`` loads the right sets of keys
-        * Internal state of the checkpointer is correctly updated
-        * Converted checkpoint can be loaded into the `mistral_classifier` torchtune implementation
-        * Saved checkpoint keys match the original checkpoint
-        """
-        single_file_checkpointer._enable_dcp = True
-        # Read the state dict directly from file using torch.load. This will be the state
-        # dict we test against
-        checkpoint_file = mistral_reward_model_hf_checkpoint
-        orig_state_dict = safe_torch_load(checkpoint_file)
-
-        # Converted state dict from the checkpointer
-        state_dict = single_file_checkpointer.load_checkpoint()
-        # Check that we've loaded all the keys minus the output bias
-        assert len(state_dict["model"].keys()) == len(orig_state_dict.keys()) - 1
-
-        # the keys in original state dict should match up with the keys in the weight_map
-        for key in orig_state_dict.keys():
-            if "inv_freq" in key or "output.bias" in key:
-                continue
-            assert key in single_file_checkpointer._weight_map
-
-        # loading the state dict into the model implementation should work correctly
-        model = mistral.mistral_classifier(
-            num_classes=1,
-            vocab_size=_VOCAB_SIZE,
-            num_layers=1,
-            num_heads=_NUM_HEADS,
-            num_kv_heads=_NUM_KV_HEADS,
-            embed_dim=_DIM,
-            intermediate_dim=_HIDDEN_DIM,
-            max_seq_len=128,
-        )
-        model.load_state_dict(state_dict["model"])
-
-        single_file_checkpointer.save_checkpoint(state_dict, epoch=1)
-
-        # Reload the output checkpoint file and compare to the original checkpoint. This
-        # assumes we know what the name of the file is. This is fine, breaking this logic
-        # should be something we capture through this test
-        output_file = Path.joinpath(
-            checkpoint_file.parent.parent / "output_dir",
-            "epoch_1",
-            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
-        ).with_suffix(".safetensors")
-        output_state_dict = safe_torch_load(output_file)
-
-        assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) - 1
-    '''
-

 class TestHFGemmaFullModelCheckpointer:
     @pytest.fixture
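Both removed tests reload the file that `save_checkpoint` writes by rebuilding its path by hand. Below is a minimal sketch of that path construction, assuming `SHARD_FNAME` follows torchtune's `model-{cpt_idx}-of-{num_shards}` template (an assumption here; the tests import the real constant), with a hypothetical input path for illustration.

```python
from pathlib import Path

# Assumed shard-name template; the removed tests import SHARD_FNAME from torchtune instead.
SHARD_FNAME = "model-{cpt_idx}-of-{num_shards}"


def expected_output_file(checkpoint_file: Path, epoch: int) -> Path:
    """Path the tests expect save_checkpoint to write for a single-shard checkpoint."""
    return Path.joinpath(
        checkpoint_file.parent.parent / "output_dir",
        f"epoch_{epoch}",
        SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
    ).with_suffix(".safetensors")


# Hypothetical usage:
# /tmp/ckpt/model.safetensors -> /tmp/output_dir/epoch_1/model-00001-of-00001.safetensors
print(expected_output_file(Path("/tmp/ckpt/model.safetensors"), epoch=1))
```

As the inline comments in both tests note, the reload step deliberately hard-codes this naming convention so that a change to the output layout is caught by the test.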
@@ -983,65 +921,3 @@ def test_load_save_checkpoint_single_file(
         output_state_dict = safe_torch_load(output_file)

         assert len(output_state_dict.keys()) == len(orig_state_dict.keys())
-
-    '''
-    def test_load_save_checkpoint_single_file_with_dcp(
-        self,
-        single_file_checkpointer: FullModelHFCheckpointer,
-        gemma_hf_checkpoint: Path,
-    ):
-        """
-        Test ``load_checkpoint`` and ``save_checkpoint`` method within the
-        FullModelHFCheckpointer for a single checkpoint file for Gemma with DCP enabled.
-
-        We test:
-        * ``load_checkpoint`` loads the right sets of keys
-        * Internal state of the checkpointer is correctly updated
-        * Converted checkpoint can be loaded into the `gemma` TorchTune implementation
-        * lm_head weights are tied to the embed_tokens weights during saving
-        * lm_head weights are popped during loading
-        """
-        single_file_checkpointer._enable_dcp = True
-        # Read the state dict directly from file using torch.load. This will be the state
-        # dict we test against
-        checkpoint_file = gemma_hf_checkpoint
-        orig_state_dict = safe_torch_load(checkpoint_file)
-
-        # Converted state dict from the checkpointer
-
-        state_dict = single_file_checkpointer.load_checkpoint()
-        assert len(state_dict["model"].keys()) == len(orig_state_dict.keys())
-
-        # the keys in original state dict should match up with the keys in the weight_map
-        for key in orig_state_dict.keys():
-            if "inv_freq" in key:
-                continue
-            assert key in single_file_checkpointer._weight_map
-
-        # loading the state dict into the model implementation should work correctly
-        model = gemma.gemma(
-            vocab_size=_VOCAB_SIZE,
-            num_layers=1,
-            num_heads=_NUM_HEADS,
-            head_dim=_HEAD_DIM,
-            num_kv_heads=1,
-            embed_dim=_DIM,
-            intermediate_dim=_HIDDEN_DIM,
-            max_seq_len=128,
-        )
-        model.load_state_dict(state_dict["model"])
-
-        single_file_checkpointer.save_checkpoint(state_dict, epoch=1)
-
-        # Reload the output checkpoint file and compare to the original checkpoint. This
-        # assumes we know what the name of the file is. This is fine, breaking this logic
-        # should be something we capture through this test
-        output_file = Path.joinpath(
-            checkpoint_file.parent.parent / "output_dir",
-            "epoch_1",
-            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)),
-        ).with_suffix(".safetensors")
-        output_state_dict = safe_torch_load(output_file)
-
-        assert len(output_state_dict.keys()) == len(orig_state_dict.keys())
-    '''
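The Gemma docstring above also calls out weight tying: the saved checkpoint keeps the same key count as the original because `lm_head` reuses the token-embedding weights. Below is a rough sketch of how that property could be checked, assuming the standard HF Gemma key names `model.embed_tokens.weight` and `lm_head.weight` (an assumption, not taken from this diff).

```python
import torch


def lm_head_is_tied(saved_state_dict: dict) -> bool:
    """True if the checkpoint omits lm_head or stores it with the same values
    as the token embedding, i.e. the weights are effectively tied."""
    embed = saved_state_dict["model.embed_tokens.weight"]
    # Some exports drop lm_head entirely and rely on tying at load time.
    lm_head = saved_state_dict.get("lm_head.weight", embed)
    return torch.equal(embed, lm_head)


# Hypothetical usage with a toy state dict where the weights are tied.
w = torch.randn(8, 4)
assert lm_head_is_tied({"model.embed_tokens.weight": w, "lm_head.weight": w})
```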