
Commit 834a780

[https://nvbugs/5599086][fix] Fix FP8 Linear module for spark (#8707)
Signed-off-by: Simeng Liu <[email protected]>
1 parent 45b36cc commit 834a780

File tree

1 file changed: +11, -15 lines changed


tensorrt_llm/_torch/modules/linear.py

Lines changed: 11 additions & 15 deletions
@@ -444,16 +444,14 @@ def load_weights_fused_qkv_linear(self, module: Linear,
 
         copy_weight(module.weight_scale, max(weight_scale))
 
-        q_weight = q_weight.to(module.dtype) * weight_scale[0]
-        k_weight = k_weight.to(module.dtype) * weight_scale[1]
-        v_weight = v_weight.to(module.dtype) * weight_scale[2]
+        # use in-place multiplication and division to avoid extra memory allocation
+        q_weight = q_weight.to(module.dtype).mul_(weight_scale[0])
+        k_weight = k_weight.to(module.dtype).mul_(weight_scale[1])
+        v_weight = v_weight.to(module.dtype).mul_(weight_scale[2])
 
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
 
         # Load k and v scales, used for NVFP4 KV cache
@@ -486,14 +484,12 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)
 
-        gate_weight = gate_weight.to(module.dtype) * weight_scale[0]
-        up_weight = up_weight.to(module.dtype) * weight_scale[1]
+        # use in-place multiplication and division to avoid extra memory allocation
+        gate_weight = gate_weight.to(module.dtype).mul_(weight_scale[0])
+        up_weight = up_weight.to(module.dtype).mul_(weight_scale[1])
         fused_weight = torch.cat((gate_weight, up_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
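As context for the diff above, the following is a minimal standalone sketch of the in-place dequantize/requantize pattern the commit switches to. Tensor shapes, scale values, and variable names here are illustrative placeholders rather than the module's actual state, and it assumes a PyTorch build that provides torch.float8_e4m3fn (2.1 or newer).

import torch

dtype = torch.bfloat16  # stand-in for module.dtype

# Per-projection FP8 weights and their dequantization scales (hypothetical values).
q_weight = torch.randn(128, 64).to(torch.float8_e4m3fn)
k_weight = torch.randn(128, 64).to(torch.float8_e4m3fn)
v_weight = torch.randn(128, 64).to(torch.float8_e4m3fn)
weight_scale = [torch.tensor(0.5), torch.tensor(0.25), torch.tensor(0.75)]

# The fused tensor is requantized against the largest per-projection scale.
fused_scale = max(weight_scale)

# Dequantize each projection; mul_() scales the freshly allocated high-precision
# copy in place, whereas `tensor * scale` would allocate a second temporary.
q = q_weight.to(dtype).mul_(weight_scale[0])
k = k_weight.to(dtype).mul_(weight_scale[1])
v = v_weight.to(dtype).mul_(weight_scale[2])

# Concatenate, divide in place by the fused scale (moved to the weight's device
# without reassigning any Parameter), then cast back to FP8.
fused_weight = torch.cat((q, k, v))
fused_weight = fused_weight.div_(fused_scale.to(fused_weight.device)).to(
    torch.float8_e4m3fn)
print(fused_weight.shape, fused_weight.dtype)  # torch.Size([384, 64]) torch.float8_e4m3fn

The second hunk applies the same pattern to the fused gate/up weights.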
