
Commit 2311624

[https://nvbugs/5599086][fix] Fix FP8 Linear module for spark

Signed-off-by: Simeng Liu <[email protected]>

1 parent: 9c4432f

File tree: 1 file changed

tensorrt_llm/_torch/modules/linear.py

Lines changed: 12 additions & 0 deletions
@@ -449,11 +449,17 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         v_weight = v_weight.to(module.dtype) * weight_scale[2]
 
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
+        original_device = module.weight_scale.device
+        # module.weight_scale and fused_weight must be on the same device for division operation
+        # Reset the device of module.weight_scale to the original device if changed
         if module.weight_scale.device != fused_weight.device:
             module.weight_scale = Parameter(
                 module.weight_scale.data.to(fused_weight.device))
         fused_weight = (fused_weight / module.weight_scale).to(
             torch.float8_e4m3fn)
+        if original_device != module.weight_scale.device:
+            module.weight_scale = Parameter(
+                module.weight_scale.data.to(original_device))
         copy_weight(module.weight, fused_weight)
 
         # Load k and v scales, used for NVFP4 KV cache
@@ -489,11 +495,17 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         gate_weight = gate_weight.to(module.dtype) * weight_scale[0]
         up_weight = up_weight.to(module.dtype) * weight_scale[1]
         fused_weight = torch.cat((gate_weight, up_weight))
+        original_device = module.weight_scale.device
+        # module.weight_scale and fused_weight must be on the same device for division operation
+        # Reset the device of module.weight_scale to the original device if changed
         if module.weight_scale.device != fused_weight.device:
             module.weight_scale = Parameter(
                 module.weight_scale.data.to(fused_weight.device))
         fused_weight = (fused_weight / module.weight_scale).to(
             torch.float8_e4m3fn)
+        if original_device != module.weight_scale.device:
+            module.weight_scale = Parameter(
+                module.weight_scale.data.to(original_device))
         copy_weight(module.weight, fused_weight)
 
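Both hunks apply the same fix: at load time, module.weight_scale may reside on a different device than the fused weight tensor, so the scale is moved to fused_weight's device for the elementwise division and then moved back, leaving the module's placement as the loader found it. Below is a minimal standalone sketch of this move-compute-restore pattern; quantize_with_scale and its signature are illustrative, not the TensorRT-LLM API.

import torch
from torch.nn import Parameter

def quantize_with_scale(fused_weight: torch.Tensor,
                        weight_scale: Parameter) -> tuple[torch.Tensor, Parameter]:
    # Remember where the scale originally lived so it can be restored later.
    original_device = weight_scale.device
    # The elementwise division requires both operands on the same device.
    if weight_scale.device != fused_weight.device:
        weight_scale = Parameter(weight_scale.data.to(fused_weight.device))
    # Cast to FP8; requires a PyTorch build with float8_e4m3fn support (>= 2.1).
    quantized = (fused_weight / weight_scale).to(torch.float8_e4m3fn)
    # Restore the scale to its original device if it was moved above.
    if original_device != weight_scale.device:
        weight_scale = Parameter(weight_scale.data.to(original_device))
    return quantized, weight_scale

The restore step is the new part of the commit: the device-matching move before the division already existed, but without moving the scale back, loading fused weights would permanently relocate module.weight_scale to fused_weight's device.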
