@@ -194,13 +194,12 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
194194 json .dump (record , f , indent = 4 )
195195
196196 def do_update_expert_map (self , layer_id , updated_expert_map ):
197- pad_len = self .expert_map_per_layer [layer_id ].shape [0 ] - updated_expert_map .shape [0 ]
198- updated_expert_map_padded = torch .nn .functional .pad (
199- updated_expert_map ,
200- pad = (0 ,pad_len ),
201- mode = 'constant' ,
202- value = - 1
203- )
197+ pad_len = self .expert_map_per_layer [layer_id ].shape [
198+ 0 ] - updated_expert_map .shape [0 ]
199+ updated_expert_map_padded = torch .nn .functional .pad (updated_expert_map ,
200+ pad = (0 , pad_len ),
201+ mode = 'constant' ,
202+ value = - 1 )
204203 self .expert_map_per_layer [layer_id ].copy_ (updated_expert_map_padded )
205204 self .expert_map_per_layer_cpu [layer_id ].copy_ (updated_expert_map )
206205
@@ -214,14 +213,15 @@ def do_update_expert_weight(self, layer_id, local_expert_to_replace,
214213
215214 def do_update_log2phy_map (self , layer_id , updated_log2phy_map ):
216215 if self .log2phy_map_per_layer [layer_id ] is not None :
217- pad_len = self .log2phy_map_per_layer [layer_id ].shape [0 ] - updated_log2phy_map .shape [0 ]
216+ pad_len = self .log2phy_map_per_layer [layer_id ].shape [
217+ 0 ] - updated_log2phy_map .shape [0 ]
218218 updated_log2phy_map_padded = torch .nn .functional .pad (
219- updated_log2phy_map ,
220- pad = (0 ,pad_len ),
221- mode = 'constant' ,
222- value = - 1
223- )
224- self . log2phy_map_per_layer [ layer_id ]. copy_ ( updated_log2phy_map_padded )
219+ updated_log2phy_map ,
220+ pad = (0 , pad_len ),
221+ mode = 'constant' ,
222+ value = - 1 )
223+ self . log2phy_map_per_layer [ layer_id ]. copy_ (
224+ updated_log2phy_map_padded )
225225
226226 def global2local (self , placement : torch .Tensor ,
227227 E_local : int ) -> torch .Tensor :
0 commit comments