Commit 422ae7c

Adapt comments and optimize memory footprint.
Signed-off-by: Simeng Liu <[email protected]>
1 parent 2311624 commit 422ae7c

File tree

1 file changed: +11 −27 lines changed

tensorrt_llm/_torch/modules/linear.py

Lines changed: 11 additions & 27 deletions
@@ -444,22 +444,14 @@ def load_weights_fused_qkv_linear(self, module: Linear,

         copy_weight(module.weight_scale, max(weight_scale))

-        q_weight = q_weight.to(module.dtype) * weight_scale[0]
-        k_weight = k_weight.to(module.dtype) * weight_scale[1]
-        v_weight = v_weight.to(module.dtype) * weight_scale[2]
+        # use in-place multiplication and division to avoid extra memory allocation
+        q_weight = q_weight.to(module.dtype).mul_(weight_scale[0])
+        k_weight = k_weight.to(module.dtype).mul_(weight_scale[1])
+        v_weight = v_weight.to(module.dtype).mul_(weight_scale[2])

         fused_weight = torch.cat((q_weight, k_weight, v_weight))
-        original_device = module.weight_scale.device
-        # module.weight_scale and fused_weight must be on the same device for division operation
-        # Reset the device of module.weight_scale to the original device if changed
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
-        if original_device != module.weight_scale.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(original_device))
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)

         # Load k and v scales, used for NVFP4 KV cache
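
The new pattern above scales each projection in place with mul_, concatenates, then divides in place by the fused weight_scale (moved to the weight's device) before casting to FP8. A minimal standalone sketch of the same pattern, with made-up shapes and scale values rather than anything taken from linear.py (float8_e4m3fn needs a recent PyTorch):

import torch

# Stand-ins for the q/k/v weight shards and their per-shard scales (illustrative only).
q = torch.randn(128, 64)
k = torch.randn(128, 64)
v = torch.randn(128, 64)
weight_scale = torch.tensor([0.5, 0.25, 0.125])
fused_scale = weight_scale.max()          # plays the role of module.weight_scale

# Out-of-place version: every `*` and `/` materializes a new tensor.
ref = torch.cat((q * weight_scale[0],
                 k * weight_scale[1],
                 v * weight_scale[2]))
ref = (ref / fused_scale).to(torch.float8_e4m3fn)

# In-place version: mul_/div_ reuse the shard and fused buffers, so the only
# new allocations are the torch.cat result and the final FP8 copy.
q.mul_(weight_scale[0])
k.mul_(weight_scale[1])
v.mul_(weight_scale[2])
fused = torch.cat((q, k, v))
fused = fused.div_(fused_scale.to(fused.device)).to(torch.float8_e4m3fn)

# Both paths produce the same quantized weights.
assert torch.equal(ref.to(torch.float32), fused.to(torch.float32))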
@@ -492,20 +484,12 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)

-        gate_weight = gate_weight.to(module.dtype) * weight_scale[0]
-        up_weight = up_weight.to(module.dtype) * weight_scale[1]
+        # use in-place multiplication and division to avoid extra memory allocation
+        gate_weight = gate_weight.to(module.dtype).mul_(weight_scale[0])
+        up_weight = up_weight.to(module.dtype).mul_(weight_scale[1])
         fused_weight = torch.cat((gate_weight, up_weight))
-        original_device = module.weight_scale.device
-        # module.weight_scale and fused_weight must be on the same device for division operation
-        # Reset the device of module.weight_scale to the original device if changed
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
-        if original_device != module.weight_scale.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(original_device))
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)

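One detail worth calling out in both hunks: the deleted blocks temporarily rebuilt module.weight_scale as a new Parameter on fused_weight's device and then restored it, whereas the new code simply passes module.weight_scale.to(fused_weight.device) into div_. Tensor.to returns the tensor itself when device and dtype already match, so in the common case no copy is made and the Parameter is never touched. A quick sketch with stand-in names (not from linear.py):

import torch

scale = torch.nn.Parameter(torch.tensor(0.5), requires_grad=False)  # stand-in for module.weight_scale
weight = torch.randn(4, 4)                                          # stand-in for fused_weight

moved = scale.to(weight.device)   # same device/dtype: .to() returns `scale` itself, no copy
assert moved is scale

weight.div_(moved)                # divides the weight in place; `scale` is left untouched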