@@ -449,11 +449,17 @@ def load_weights_fused_qkv_linear(self, module: Linear,
449449 v_weight = v_weight .to (module .dtype ) * weight_scale [2 ]
450450
451451 fused_weight = torch .cat ((q_weight , k_weight , v_weight ))
452+ original_device = module .weight_scale .device
453+ # The division below requires module.weight_scale and fused_weight to be on the same device,
454+ # so module.weight_scale is moved there if needed and restored to its original device afterwards.
452455 if module .weight_scale .device != fused_weight .device :
453456 module .weight_scale = Parameter (
454457 module .weight_scale .data .to (fused_weight .device ))
455458 fused_weight = (fused_weight / module .weight_scale ).to (
456459 torch .float8_e4m3fn )
460+ if original_device != module .weight_scale .device :
461+ module .weight_scale = Parameter (
462+ module .weight_scale .data .to (original_device ))
457463 copy_weight (module .weight , fused_weight )
458464
459465 # Load k and v scales, used for NVFP4 KV cache
@@ -489,11 +495,17 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
489495 gate_weight = gate_weight .to (module .dtype ) * weight_scale [0 ]
490496 up_weight = up_weight .to (module .dtype ) * weight_scale [1 ]
491497 fused_weight = torch .cat ((gate_weight , up_weight ))
498+ original_device = module .weight_scale .device
499+ # The division below requires module.weight_scale and fused_weight to be on the same device,
500+ # so module.weight_scale is moved there if needed and restored to its original device afterwards.
492501 if module .weight_scale .device != fused_weight .device :
493502 module .weight_scale = Parameter (
494503 module .weight_scale .data .to (fused_weight .device ))
495504 fused_weight = (fused_weight / module .weight_scale ).to (
496505 torch .float8_e4m3fn )
506+ if original_device != module .weight_scale .device :
507+ module .weight_scale = Parameter (
508+ module .weight_scale .data .to (original_device ))
497509 copy_weight (module .weight , fused_weight )
498510
499511
0 commit comments