@@ -170,23 +170,23 @@ def _convert_state_dict(m, state_dict_pt, prefix=""):
170170 state_dict_ms = {}
171171 while state_dict_pt :
172172 name_pt , data_pt = state_dict_pt .popitem ()
173- # TODO For models contains a lot of paramters, going through state_dict and model at the same time
174- # would cause performance decrease significantly. This part for aligning prefix would need to be optimized.
175- # for name, param in m.parameters_and_names():
176- # name_ms = param.name
177- # length = len(prefix) + 1
178- # if name_pt.startswith(prefix):
179- # # When name_ms and name_pt match and name_pt has prefix, name_pt would be sliced
180- # if name_ms.rsplit(".", 1)[0] == name_pt.rsplit(".", 1)[0][length:] or name_ms == name_pt[length:]:
181- # name_pt = name_pt[length:]
182- # elif not name_pt.startswith(prefix):
183- # # When name_ms and name_pt match and name_ms has prefix, prefix would be added to name_pt
184- # if name_pt.rsplit(".", 1)[0] == name_ms.rsplit(".", 1)[0][length:] or name_pt == name_ms[length:]:
185- # name_pt = ".".join([prefix, name_pt])
186173 name_ms , data_mapping = pt2ms_mappings .get (name_pt , (name_pt , lambda x : x ))
187174 data_ms = data_mapping (data_pt )
188175 if name_ms is not None :
189176 state_dict_ms [name_ms ] = data_ms
177+
178+ length = len (prefix ) + 1
179+ model_ckpt_key = m .state_dict ().keys ()
180+ for key in state_dict_ms .keys ():
181+ # When model name and state dict name match and state dict name has prefix, state dict name would be sliced
182+ if key [length :] in model_ckpt_key :
183+ data_ms = state_dict_ms .pop (key )
184+ state_dict_ms [key [length :]] = data_ms
185+ # When model name and state dict name match and model name has prefix, prefix would be added to state dict name
186+ elif "." .join ([prefix , key ]) in model_ckpt_key :
187+ data_ms = state_dict_ms .pop (key )
188+ state_dict_ms ["." .join ([prefix , key ])] = data_ms
189+
190190 return state_dict_ms
191191
192192
0 commit comments