11
11
12
12
import pytest
13
13
14
+ import safetensors
14
15
import torch
15
16
from torch import randn
16
17
@@ -160,6 +161,12 @@ def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2):
160
161
161
162
torch .save (state_dict_1 , checkpoint_file_1 )
162
163
torch .save (state_dict_2 , checkpoint_file_2 )
164
+ safetensors .torch .save_file (
165
+ state_dict_1 , checkpoint_dir / "model-00001-of-00002.safetensors"
166
+ )
167
+ safetensors .torch .save_file (
168
+ state_dict_2 , checkpoint_dir / "model-00002-of-00002.safetensors"
169
+ )
163
170
164
171
config = {
165
172
"hidden_size" : 64 ,
@@ -169,6 +176,14 @@ def llama2_hf_checkpoints(self, tmp_path, state_dict_1, state_dict_2):
169
176
config_file = Path .joinpath (checkpoint_dir , "config.json" )
170
177
with config_file .open ("w" ) as f :
171
178
json .dump (config , f )
179
+ metadata_file = Path .joinpath (checkpoint_dir , "model.safetensors.index.json" )
180
+ metadata = {"weight_map" : {}}
181
+ for key in state_dict_1 .keys ():
182
+ metadata ["weight_map" ][key ] = "model-00001-of-00002.safetensors"
183
+ for key in state_dict_2 .keys ():
184
+ metadata ["weight_map" ][key ] = "model-00002-of-00002.safetensors"
185
+ with metadata_file .open ("w" ) as f :
186
+ json .dump (metadata , f )
172
187
173
188
return (checkpoint_file_1 , checkpoint_file_2 )
174
189
@@ -328,136 +343,6 @@ def test_save_load_checkpoint_multiple_file(
328
343
assert len (output_state_dict_1 .keys ()) + 1 == len (orig_state_dict_1 .keys ())
329
344
assert len (output_state_dict_2 .keys ()) + 1 == len (orig_state_dict_2 .keys ())
330
345
331
- def test_load_save_checkpoint_single_file_with_dcp (
332
- self ,
333
- single_file_checkpointer : FullModelHFCheckpointer ,
334
- llama2_hf_checkpoints : Tuple [Path , Path ],
335
- ):
336
- """
337
- Test ``load_checkpoint`` and ``save_checkpoint`` method within the
338
- FullModelHFCheckpointer for a single checkpoint file.
339
-
340
- We test:
341
- * ``load_checkpoint`` loads the right sets of keys
342
- * Internal state of the checkpointer is correctly updated
343
- * Converted checkpoint can be loaded into the llama2 torchtune implementation
344
- * Saved checkpoint keys match the original checkpoint
345
- """
346
- single_file_checkpointer ._enable_dcp = True
347
- # Read the state dict directly from file using torch.load. This will be the state
348
- # dict we test against
349
- checkpoint_file , _ = llama2_hf_checkpoints
350
- orig_state_dict = safe_torch_load (checkpoint_file )
351
-
352
- # Converted state dict from the checkpointer
353
- state_dict = single_file_checkpointer .load_checkpoint ()
354
-
355
- # Check that we've loaded all the keys; We ignore inv_freq as is standard practice
356
- assert len (state_dict ["model" ].keys ()) + 1 == len (orig_state_dict .keys ())
357
-
358
- # the keys in original state dict should match up with the keys in the weight_map
359
- for key in orig_state_dict .keys ():
360
- if "inv_freq" in key :
361
- continue
362
- assert key in single_file_checkpointer ._weight_map
363
-
364
- # loading the state dict into the model implementation should work correctly
365
- model = llama2 .llama2 (
366
- vocab_size = _VOCAB_SIZE ,
367
- num_layers = 1 ,
368
- num_heads = _NUM_HEADS ,
369
- num_kv_heads = _NUM_KV_HEADS ,
370
- embed_dim = _DIM ,
371
- max_seq_len = 128 ,
372
- )
373
- model .load_state_dict (state_dict ["model" ])
374
-
375
- single_file_checkpointer .save_checkpoint (state_dict , epoch = 1 )
376
-
377
- # Reload the output checkpoint file and compare to the original checkpoint. This
378
- # assumes we know what the name of the file is. This is fine, breaking this logic
379
- # should be something we capture through this test
380
- output_file = Path .joinpath (
381
- checkpoint_file .parent .parent / "output_dir" ,
382
- "epoch_1" ,
383
- SHARD_FNAME .format (cpt_idx = "1" .zfill (5 ), num_shards = "1" .zfill (5 )),
384
- ).with_suffix (".safetensors" )
385
- output_state_dict = safe_torch_load (output_file )
386
-
387
- # We ignore inv_freq as is standard practice and so output dict will have one less key
388
- assert len (output_state_dict .keys ()) + 1 == len (orig_state_dict .keys ())
389
-
390
- def test_save_load_checkpoint_multiple_file_with_dcp (
391
- self ,
392
- multi_file_checkpointer : FullModelHFCheckpointer ,
393
- llama2_hf_checkpoints : Tuple [Path , Path ],
394
- ):
395
- """
396
- Test ``load_checkpoint`` method within the FullModelCheckpointer for multiple
397
- checkpoint file.
398
-
399
- We test:
400
- * ``load_checkpoint`` loads the right sets of keys
401
- * Internal state of the checkpointer is correctly updated
402
- * Converted checkpoint can be loaded into the llama2 torchtune implementation
403
- """
404
- multi_file_checkpointer ._enable_dcp = True
405
- # Read the state dict directly from files
406
- checkpoint_file_1 , checkpoint_file_2 = llama2_hf_checkpoints
407
- orig_state_dict_1 = safe_torch_load (checkpoint_file_1 )
408
- orig_state_dict_2 = safe_torch_load (checkpoint_file_2 )
409
-
410
- # merged state dict from checkpointer
411
- state_dict = multi_file_checkpointer .load_checkpoint ()
412
-
413
- # We ignore inv_freq as is standard practice
414
- assert len (state_dict ["model" ].keys ()) + 2 == len (
415
- orig_state_dict_1 .keys ()
416
- ) + len (orig_state_dict_2 .keys ())
417
-
418
- # the keys in the weight_map should match up with the keys in the weight_map
419
- for key in orig_state_dict_1 .keys ():
420
- if "inv_freq" in key :
421
- continue
422
- assert key in multi_file_checkpointer ._weight_map
423
-
424
- for key in orig_state_dict_2 .keys ():
425
- if "inv_freq" in key :
426
- continue
427
- assert key in multi_file_checkpointer ._weight_map
428
-
429
- # finally loading into the model should work
430
- model = llama2 .llama2 (
431
- vocab_size = _VOCAB_SIZE ,
432
- num_layers = 2 ,
433
- num_heads = _NUM_HEADS ,
434
- num_kv_heads = _NUM_KV_HEADS ,
435
- embed_dim = _DIM ,
436
- max_seq_len = 128 ,
437
- )
438
- model .load_state_dict (state_dict ["model" ])
439
-
440
- multi_file_checkpointer .save_checkpoint (state_dict , epoch = 1 )
441
-
442
- # Reload the output checkpoint file and compare to the original checkpoint. This
443
- # assumes we know what the name of the file is. This is fine, breaking this logic
444
- # should be something we capture through this test
445
- output_file_1 = Path .joinpath (
446
- checkpoint_file_1 .parent .parent / "output_dir" ,
447
- "epoch_1" ,
448
- SHARD_FNAME .format (cpt_idx = "1" .zfill (5 ), num_shards = "2" .zfill (5 )),
449
- ).with_suffix (".safetensors" )
450
- output_file_2 = Path .joinpath (
451
- checkpoint_file_2 .parent .parent / "output_dir" ,
452
- "epoch_1" ,
453
- SHARD_FNAME .format (cpt_idx = "2" .zfill (5 ), num_shards = "2" .zfill (5 )),
454
- ).with_suffix (".safetensors" )
455
- output_state_dict_1 = safe_torch_load (output_file_1 )
456
- output_state_dict_2 = safe_torch_load (output_file_2 )
457
-
458
- assert len (output_state_dict_1 .keys ()) + 1 == len (orig_state_dict_1 .keys ())
459
- assert len (output_state_dict_2 .keys ()) + 1 == len (orig_state_dict_2 .keys ())
460
-
461
346
def test_load_save_adapter_only (
462
347
self , tmp_path , single_file_checkpointer , llama2_hf_checkpoints
463
348
):
@@ -634,6 +519,77 @@ def test_save_checkpoint_in_peft_format(
634
519
actual_adapter_state_dict [k ], expected_adapter_state_dict [new_k ]
635
520
)
636
521
522
+ def test_save_load_checkpoint_multiple_file_with_dcp (
523
+ self ,
524
+ multi_file_checkpointer : FullModelHFCheckpointer ,
525
+ llama2_hf_checkpoints : Tuple [Path , Path ],
526
+ ):
527
+ """
528
+ Test ``load_checkpoint`` method within the FullModelCheckpointer for multiple
529
+ checkpoint file.
530
+
531
+ We test:
532
+ * ``load_checkpoint`` loads the right sets of keys
533
+ * Internal state of the checkpointer is correctly updated
534
+ * Converted checkpoint can be loaded into the llama2 torchtune implementation
535
+ """
536
+ multi_file_checkpointer ._enable_dcp = True
537
+ # Read the state dict directly from files
538
+ checkpoint_file_1 , checkpoint_file_2 = llama2_hf_checkpoints
539
+ orig_state_dict_1 = safe_torch_load (checkpoint_file_1 )
540
+ orig_state_dict_2 = safe_torch_load (checkpoint_file_2 )
541
+
542
+ # merged state dict from checkpointer
543
+ state_dict = multi_file_checkpointer .load_checkpoint ()
544
+
545
+ # We ignore inv_freq as is standard practice
546
+ assert len (state_dict ["model" ].keys ()) + 2 == len (
547
+ orig_state_dict_1 .keys ()
548
+ ) + len (orig_state_dict_2 .keys ())
549
+
550
+ # the keys in the weight_map should match up with the keys in the weight_map
551
+ for key in orig_state_dict_1 .keys ():
552
+ if "inv_freq" in key :
553
+ continue
554
+ assert key in multi_file_checkpointer ._weight_map
555
+
556
+ for key in orig_state_dict_2 .keys ():
557
+ if "inv_freq" in key :
558
+ continue
559
+ assert key in multi_file_checkpointer ._weight_map
560
+
561
+ # finally loading into the model should work
562
+ model = llama2 .llama2 (
563
+ vocab_size = _VOCAB_SIZE ,
564
+ num_layers = 2 ,
565
+ num_heads = _NUM_HEADS ,
566
+ num_kv_heads = _NUM_KV_HEADS ,
567
+ embed_dim = _DIM ,
568
+ max_seq_len = 128 ,
569
+ )
570
+ model .load_state_dict (state_dict ["model" ])
571
+
572
+ multi_file_checkpointer .save_checkpoint (state_dict , epoch = 3 )
573
+
574
+ # Reload the output checkpoint file and compare to the original checkpoint. This
575
+ # assumes we know what the name of the file is. This is fine, breaking this logic
576
+ # should be something we capture through this test
577
+ output_file_1 = Path .joinpath (
578
+ checkpoint_file_1 .parent .parent / "output_dir" ,
579
+ "epoch_3" ,
580
+ SHARD_FNAME .format (cpt_idx = "1" .zfill (5 ), num_shards = "2" .zfill (5 )),
581
+ ).with_suffix (".safetensors" )
582
+ output_file_2 = Path .joinpath (
583
+ checkpoint_file_2 .parent .parent / "output_dir" ,
584
+ "epoch_3" ,
585
+ SHARD_FNAME .format (cpt_idx = "2" .zfill (5 ), num_shards = "2" .zfill (5 )),
586
+ ).with_suffix (".safetensors" )
587
+ output_state_dict_1 = safe_torch_load (output_file_1 )
588
+ output_state_dict_2 = safe_torch_load (output_file_2 )
589
+
590
+ assert len (output_state_dict_1 .keys ()) + 1 == len (orig_state_dict_1 .keys ())
591
+ assert len (output_state_dict_2 .keys ()) + 1 == len (orig_state_dict_2 .keys ())
592
+
637
593
638
594
class TestHFMistralRewardModelFullModelCheckpointer :
639
595
@pytest .fixture
@@ -717,6 +673,12 @@ def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict):
717
673
config_file = Path .joinpath (checkpoint_dir , "config.json" )
718
674
with config_file .open ("w" ) as f :
719
675
json .dump (config , f )
676
+ metadata_file = Path .joinpath (checkpoint_dir , "model.safetensors.index.json" )
677
+ metadata = {"weight_map" : {}}
678
+ for key in state_dict .keys ():
679
+ metadata ["weight_map" ][key ] = key
680
+ with metadata_file .open ("w" ) as f :
681
+ json .dump (metadata , f )
720
682
721
683
return checkpoint_file
722
684
@@ -792,6 +754,7 @@ def test_load_save_checkpoint_single_file(
792
754
793
755
assert len (output_state_dict .keys ()) == len (orig_state_dict .keys ()) - 1
794
756
757
+ '''
795
758
def test_load_save_checkpoint_single_file_with_dcp(
796
759
self,
797
760
single_file_checkpointer: FullModelHFCheckpointer,
@@ -851,6 +814,7 @@ def test_load_save_checkpoint_single_file_with_dcp(
851
814
output_state_dict = safe_torch_load(output_file)
852
815
853
816
assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) - 1
817
+ '''
854
818
855
819
856
820
class TestHFGemmaFullModelCheckpointer :
@@ -922,6 +886,9 @@ def gemma_hf_checkpoint(self, tmp_path, state_dict):
922
886
checkpoint_file = checkpoint_dir / "gemma_hf_checkpoint.pt"
923
887
924
888
torch .save (state_dict , checkpoint_file )
889
+ safetensors .torch .save_file (
890
+ state_dict , checkpoint_dir / "model-00001-of-00001.safetensors"
891
+ )
925
892
926
893
config = {
927
894
"hidden_size" : _DIM ,
@@ -934,15 +901,23 @@ def gemma_hf_checkpoint(self, tmp_path, state_dict):
934
901
with config_file .open ("w" ) as f :
935
902
json .dump (config , f )
936
903
904
+ # metadata file for dcp
905
+ metadata_file = Path .joinpath (checkpoint_dir , "model.safetensors.index.json" )
906
+ metadata = {"weight_map" : {}}
907
+ for key in state_dict .keys ():
908
+ metadata ["weight_map" ][key ] = "model-00001-of-00001.safetensors"
909
+ with metadata_file .open ("w" ) as f :
910
+ json .dump (metadata , f )
911
+
937
912
return checkpoint_file
938
913
939
914
@pytest .fixture
940
915
def single_file_checkpointer (
941
916
self , gemma_hf_checkpoint , tmp_path
942
917
) -> FullModelHFCheckpointer :
943
918
checkpoint_file = gemma_hf_checkpoint
944
- checkpoint_dir = str ( Path .joinpath (tmp_path , "checkpoint_dir" ) )
945
- output_dir = str ( Path .joinpath (tmp_path , "output_dir" ) )
919
+ checkpoint_dir = Path .joinpath (tmp_path , "checkpoint_dir" )
920
+ output_dir = Path .joinpath (tmp_path , "output_dir" )
946
921
return FullModelHFCheckpointer (
947
922
checkpoint_dir = checkpoint_dir ,
948
923
checkpoint_files = [checkpoint_file ],
@@ -1009,7 +984,8 @@ def test_load_save_checkpoint_single_file(
1009
984
1010
985
assert len (output_state_dict .keys ()) == len (orig_state_dict .keys ())
1011
986
1012
- def test_load_save_checkpoint_single_file_dcp (
987
+ '''
988
+ def test_load_save_checkpoint_single_file_with_dcp(
1013
989
self,
1014
990
single_file_checkpointer: FullModelHFCheckpointer,
1015
991
gemma_hf_checkpoint: Path,
@@ -1068,3 +1044,4 @@ def test_load_save_checkpoint_single_file_dcp(
1068
1044
output_state_dict = safe_torch_load(output_file)
1069
1045
1070
1046
assert len(output_state_dict.keys()) == len(orig_state_dict.keys())
1047
+ '''
0 commit comments