@@ -444,16 +444,14 @@ def load_weights_fused_qkv_linear(self, module: Linear,
444444
445445 copy_weight (module .weight_scale , max (weight_scale ))
446446
447- q_weight = q_weight .to (module .dtype ) * weight_scale [0 ]
448- k_weight = k_weight .to (module .dtype ) * weight_scale [1 ]
449- v_weight = v_weight .to (module .dtype ) * weight_scale [2 ]
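447+ # Use in-place mul_/div_ to avoid extra memory allocation. NOTE(review): Tensor.to() returns self when the dtype already matches, so mul_ would then modify the loaded q/k/v weight itself - confirm these weights are never already module.dtype.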
448+ q_weight = q_weight .to (module .dtype ).mul_ (weight_scale [0 ])
449+ k_weight = k_weight .to (module .dtype ).mul_ (weight_scale [1 ])
450+ v_weight = v_weight .to (module .dtype ).mul_ (weight_scale [2 ])
450451
451452 fused_weight = torch .cat ((q_weight , k_weight , v_weight ))
452- if module .weight_scale .device != fused_weight .device :
453- module .weight_scale = Parameter (
454- module .weight_scale .data .to (fused_weight .device ))
455- fused_weight = (fused_weight / module .weight_scale ).to (
456- torch .float8_e4m3fn )
453+ fused_weight = fused_weight .div_ (
454+ module .weight_scale .to (fused_weight .device )).to (torch .float8_e4m3fn )
457455 copy_weight (module .weight , fused_weight )
458456
459457 # Load k and v scales, used for NVFP4 KV cache
@@ -486,14 +484,12 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
486484 gate_weight , up_weight = load_weights_fused_gate_up_helper (
487485 module , weights )
488486
489- gate_weight = gate_weight .to (module .dtype ) * weight_scale [0 ]
490- up_weight = up_weight .to (module .dtype ) * weight_scale [1 ]
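487+ # Use in-place mul_/div_ to avoid extra memory allocation. NOTE(review): Tensor.to() returns self when the dtype already matches, so mul_ would then modify the loaded gate/up weight itself - confirm these weights are never already module.dtype.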
488+ gate_weight = gate_weight .to (module .dtype ).mul_ (weight_scale [0 ])
489+ up_weight = up_weight .to (module .dtype ).mul_ (weight_scale [1 ])
491490 fused_weight = torch .cat ((gate_weight , up_weight ))
492- if module .weight_scale .device != fused_weight .device :
493- module .weight_scale = Parameter (
494- module .weight_scale .data .to (fused_weight .device ))
495- fused_weight = (fused_weight / module .weight_scale ).to (
496- torch .float8_e4m3fn )
491+ fused_weight = fused_weight .div_ (
492+ module .weight_scale .to (fused_weight .device )).to (torch .float8_e4m3fn )
497493 copy_weight (module .weight , fused_weight )
498494
499495
0 commit comments