@@ -34,7 +34,8 @@ def unflatten_tensor_state_dict(
3434 '_data': {
3535 'block_size': [1,32],
3636 ...
37- }
37+ },
38+ '_tensor_data_names': ['qdata', 'scale']
3839 }
3940 '0.bias': {
4041 '_type': 'torch.Tensor',
@@ -66,33 +67,52 @@ def unflatten_tensor_state_dict(
6667
6768 tensor_names = json .loads (metadata ["tensor_names" ])
6869 result = {}
69-
70+ leftover_state_dict = tensors_data_dict . copy ()
7071 for tensor_name in tensor_names :
72+ processed_tensors = []
73+
7174 module_fqn , weight_name = tensor_name .rsplit ("." , 1 )
7275
7376 prefix = f"{ module_fqn } ._{ weight_name } _"
7477 tensor_tensors = {}
78+
7579 for key , value in combined_data .items ():
7680 if key .startswith (prefix ):
7781 # Remove the prefix
7882 tensor_tensors [key [len (prefix ) :]] = value
83+ full_tensor_name_in_state_dict = key
84+ processed_tensors .append (
85+ full_tensor_name_in_state_dict
86+ ) # for tensor subclass
7987
8088 tensor_metadata = json .loads (metadata .get (tensor_name ))
8189 tensor_type = tensor_metadata .get ("_type" )
90+ complete_tensor_data = tensor_metadata .get ("_tensor_data_names" )
8291
8392 if tensor_type in ALLOWED_TENSORS_SUBCLASSES :
84- if not tensor_tensors :
85- # we allow the option of loading in state_dict info for a single tensor
86- # if tensor state dict info is not loaded in yet, we wait for it to be provided
87- # in a future call
93+ # if not all tensor data is present (ie missing qdata) we wait for it
94+ # to be loaded in from a future call
95+ if not len (tensor_tensors ) is len (complete_tensor_data ):
8896 continue
8997 tensor_metadata ["_data" ].update (tensor_tensors )
9098 result [tensor_name ] = object_from_dict (tensor_metadata )
9199 elif tensor_type == torch .Tensor .__name__ :
100+ # we allow the option of loading in state_dict info for a single tensor
101+ # if tensor state dict info is not loaded in yet, we wait for it to be provided
102+ # in a future call
103+ if tensor_name not in tensors_data_dict .keys ():
104+ continue
92105 result [tensor_name ] = tensors_data_dict [tensor_name ]
106+ processed_tensors .append (
107+ tensor_name
108+ ) # add here because key for torch.Tensor has no prefix
93109 else :
94110 raise ValueError (f"Unsupported tensor type: { tensor_type } " )
95- return result
111+
112+ for tensor_name in processed_tensors :
113+ del leftover_state_dict [tensor_name ]
114+
115+ return leftover_state_dict , result
96116
97117
98118def flatten_tensor_state_dict (
@@ -125,7 +145,8 @@ def flatten_tensor_state_dict(
125145 '_data': {
126146 'block_size': [1,32],
127147 ...
128- }
148+ },
149+ '_tensor_data_names': ['qdata', 'scale']
129150 }
130151 '0.bias': {
131152 '_type': 'torch.Tensor',
0 commit comments