Skip to content

Commit cea5523

Browse files
committed
changes
1 parent e2aab90 commit cea5523

File tree

3 files changed

+48
-10
lines changed

3 files changed

+48
-10
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,10 @@ def test_safetensors(self, config, act_pre_scale=False):
7474

7575
save_file(tensors_data_dict, f.name, metadata=metadata)
7676
tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
77-
reconstructed_dict = unflatten_tensor_state_dict(
77+
leftover_tensor_data_dict, reconstructed_dict = unflatten_tensor_state_dict(
7878
tensors_data_dict, metadata
7979
)
80+
assert not leftover_tensor_data_dict
8081

8182
model = torch.nn.Sequential(
8283
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def unflatten_tensor_state_dict(
3434
'_data': {
3535
'block_size': [1,32],
3636
...
37-
}
37+
},
38+
'_tensor_data_names': ['qdata', 'scale']
3839
}
3940
'0.bias': {
4041
'_type': 'torch.Tensor',
@@ -66,33 +67,52 @@ def unflatten_tensor_state_dict(
6667

6768
tensor_names = json.loads(metadata["tensor_names"])
6869
result = {}
69-
70+
leftover_state_dict = tensors_data_dict.copy()
7071
for tensor_name in tensor_names:
72+
processed_tensors = []
73+
7174
module_fqn, weight_name = tensor_name.rsplit(".", 1)
7275

7376
prefix = f"{module_fqn}._{weight_name}_"
7477
tensor_tensors = {}
78+
7579
for key, value in combined_data.items():
7680
if key.startswith(prefix):
7781
# Remove the prefix
7882
tensor_tensors[key[len(prefix) :]] = value
83+
full_tensor_name_in_state_dict = key
84+
processed_tensors.append(
85+
full_tensor_name_in_state_dict
86+
) # for tensor subclass
7987

8088
tensor_metadata = json.loads(metadata.get(tensor_name))
8189
tensor_type = tensor_metadata.get("_type")
90+
complete_tensor_data = tensor_metadata.get("_tensor_data_names")
8291

8392
if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
84-
if not tensor_tensors:
85-
# we allow the option of loading in state_dict info for a single tensor
86-
# if tensor state dict info is not loaded in yet, we wait for it to be provided
87-
# in a future call
93+
# if not all tensor data is present (ie missing qdata) we wait for it
94+
# to be loaded in from a future call
95+
if not len(tensor_tensors) is len(complete_tensor_data):
8896
continue
8997
tensor_metadata["_data"].update(tensor_tensors)
9098
result[tensor_name] = object_from_dict(tensor_metadata)
9199
elif tensor_type == torch.Tensor.__name__:
100+
# we allow the option of loading in state_dict info for a single tensor
101+
# if tensor state dict info is not loaded in yet, we wait for it to be provided
102+
# in a future call
103+
if tensor_name not in tensors_data_dict.keys():
104+
continue
92105
result[tensor_name] = tensors_data_dict[tensor_name]
106+
processed_tensors.append(
107+
tensor_name
108+
) # add here because key for torch.Tensor has no prefix
93109
else:
94110
raise ValueError(f"Unsupported tensor type: {tensor_type}")
95-
return result
111+
112+
for tensor_name in processed_tensors:
113+
del leftover_state_dict[tensor_name]
114+
115+
return leftover_state_dict, result
96116

97117

98118
def flatten_tensor_state_dict(
@@ -125,7 +145,8 @@ def flatten_tensor_state_dict(
125145
'_data': {
126146
'block_size': [1,32],
127147
...
128-
}
148+
},
149+
'_tensor_data_names': ['qdata', 'scale']
129150
}
130151
'0.bias': {
131152
'_type': 'torch.Tensor',

torchao/prototype/safetensors/safetensors_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,23 @@ def default(self, o):
6060
encoded_attribute = self.encode_value(attribute)
6161
tensor_attr_dict[tensor_attribute_name] = encoded_attribute
6262

63-
return {"_type": o.__class__.__name__, "_data": tensor_attr_dict}
63+
optional_tensor_data = (
64+
o.optional_tensor_data_names
65+
if hasattr(o, "optional_tensor_data_names")
66+
else []
67+
)
68+
all_tensor_data = optional_tensor_data + o.tensor_data_names
69+
70+
_tensor_data_names = []
71+
for tensor_data_name in all_tensor_data:
72+
if getattr(o, tensor_data_name) is not None:
73+
_tensor_data_names.append(tensor_data_name)
74+
75+
return {
76+
"_type": o.__class__.__name__,
77+
"_data": tensor_attr_dict,
78+
"_tensor_data_names": _tensor_data_names,
79+
}
6480

6581
if hasattr(o, "_fields") and hasattr(
6682
o, "_asdict"

0 commit comments

Comments
 (0)