
Commit ea640a1

[https://nvbugs/5550283][fix] update test case to call post quantization explicitly due to code refactor (#8188)
Signed-off-by: xxi <[email protected]>
1 parent: a9a0969

1 file changed

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -740,6 +740,7 @@ def per_rank_test_fused_moe_alltoall_fp8_blockwise(job_id):
     )
     alltoall_model.to("cuda")
     alltoall_model.load_weights([weights])
+    alltoall_model.post_load_weights()
 
     # Use DeepGemmFusedMoE as reference
     ref_model = DeepGemmFusedMoE(
@@ -755,6 +756,7 @@ def per_rank_test_fused_moe_alltoall_fp8_blockwise(job_id):
     )
     ref_model.to("cuda")
     ref_model.load_weights([weights])
+    ref_model.post_load_weights()
 
     # Evaluate the outputs on variant sequence lengths
     m = MAX_NUM_TOKENS
```
