Skip to content

Commit

Permalink
[PyTorch] Add weights_only=False for torch.load (#1374)
Browse files Browse the repository at this point in the history
add weights_only=False for torch.load

Signed-off-by: Charlene Yang <[email protected]>
  • Loading branch information
cyanguwa authored Dec 18, 2024
1 parent 7f5c784 commit 83dac8c
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tests/pytorch/test_float8tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def test_serialization(
del x_fp8, byte_stream

# Deserialize tensor
x_fp8 = torch.load(io.BytesIO(x_bytes))
x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False)
del x_bytes

# Check results
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ def get_model(dtype, config):

del block
block = get_model(dtype, config)
block.load_state_dict(torch.load(path))
block.load_state_dict(torch.load(path, weights_only=False))
torch.set_rng_state(_cpu_rng_state_new)
torch.cuda.set_rng_state(_cuda_rng_state_new)

Expand Down
6 changes: 3 additions & 3 deletions tests/pytorch/test_torch_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def forward(self, inp, weight):
torch.save(model_in.state_dict(), tmp_filename)

model_out = Test_TE_Export(precision, True)
model_out.load_state_dict(torch.load(tmp_filename))
model_out.load_state_dict(torch.load(tmp_filename, weights_only=False))
model_out.eval()

# scaling fwd
Expand Down Expand Up @@ -263,7 +263,7 @@ def test_fp8_model_checkpoint(
# to load the fp8 metadata before loading tensors.
#
# Load checkpoint
model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False))
del model_bytes

# Check that loaded model matches saved model
Expand Down Expand Up @@ -450,7 +450,7 @@ def train_step(
torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)

# Load checkpoint
model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False))
del model_bytes

# Check that new model's FP8 metadata matches saved model
Expand Down

0 comments on commit 83dac8c

Please sign in to comment.