@@ -444,22 +444,14 @@ def load_weights_fused_qkv_linear(self, module: Linear,
444444
445445 copy_weight (module .weight_scale , max (weight_scale ))
446446
447- q_weight = q_weight .to (module .dtype ) * weight_scale [0 ]
448- k_weight = k_weight .to (module .dtype ) * weight_scale [1 ]
449- v_weight = v_weight .to (module .dtype ) * weight_scale [2 ]
447+ # use in-place multiplication and division to avoid extra memory allocation
448+ q_weight = q_weight .to (module .dtype ).mul_ (weight_scale [0 ])
449+ k_weight = k_weight .to (module .dtype ).mul_ (weight_scale [1 ])
450+ v_weight = v_weight .to (module .dtype ).mul_ (weight_scale [2 ])
450451
451452 fused_weight = torch .cat ((q_weight , k_weight , v_weight ))
452- original_device = module .weight_scale .device
453- # module.weight_scale and fused_weight must be on the same device for division operation
454- # Reset the device of module.weight_scale to the original device if changed
455- if module .weight_scale .device != fused_weight .device :
456- module .weight_scale = Parameter (
457- module .weight_scale .data .to (fused_weight .device ))
458- fused_weight = (fused_weight / module .weight_scale ).to (
459- torch .float8_e4m3fn )
460- if original_device != module .weight_scale .device :
461- module .weight_scale = Parameter (
462- module .weight_scale .data .to (original_device ))
453+ fused_weight = fused_weight .div_ (
454+ module .weight_scale .to (fused_weight .device )).to (torch .float8_e4m3fn )
463455 copy_weight (module .weight , fused_weight )
464456
465457 # Load k and v scales, used for NVFP4 KV cache
@@ -492,20 +484,12 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
492484 gate_weight , up_weight = load_weights_fused_gate_up_helper (
493485 module , weights )
494486
495- gate_weight = gate_weight .to (module .dtype ) * weight_scale [0 ]
496- up_weight = up_weight .to (module .dtype ) * weight_scale [1 ]
487+ # use in-place multiplication and division to avoid extra memory allocation
488+ gate_weight = gate_weight .to (module .dtype ).mul_ (weight_scale [0 ])
489+ up_weight = up_weight .to (module .dtype ).mul_ (weight_scale [1 ])
497490 fused_weight = torch .cat ((gate_weight , up_weight ))
498- original_device = module .weight_scale .device
499- # module.weight_scale and fused_weight must be on the same device for division operation
500- # Reset the device of module.weight_scale to the original device if changed
501- if module .weight_scale .device != fused_weight .device :
502- module .weight_scale = Parameter (
503- module .weight_scale .data .to (fused_weight .device ))
504- fused_weight = (fused_weight / module .weight_scale ).to (
505- torch .float8_e4m3fn )
506- if original_device != module .weight_scale .device :
507- module .weight_scale = Parameter (
508- module .weight_scale .data .to (original_device ))
491+ fused_weight = fused_weight .div_ (
492+ module .weight_scale .to (fused_weight .device )).to (torch .float8_e4m3fn )
509493 copy_weight (module .weight , fused_weight )
510494
511495
0 commit comments