
Contributor

@Cjkkkk Cjkkkk commented Dec 5, 2025

📝 Summary of Changes
Fix cuBLASLt FP8 GEMMs to pass a dummy nullptr C pointer when beta = 0.

🎯 Justification
Since CUDA 13.0, cuBLASLt enforces a new check: a GEMM that passes the same pointer for C and D must also use the same descriptors for both. XLA passes the same C and D pointer for in-place GEMMs, which is fine. However, XLA also passes the same pointer for GEMMs with beta = 0, even though cuBLAS does not need the C pointer in that case. For FP8 GEMMs the new check now fails, since C and D usually have different descriptors. The correct approach is to pass a dummy nullptr as the C pointer for GEMMs with beta = 0.
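
The pointer selection described above could be sketched roughly as follows. This is an illustration only, not the code in this PR: `GemmPointers` and `SelectGemmPointers` are hypothetical names, and the sketch assumes, as stated above, that cuBLASLt does not read C when beta = 0.

```cpp
// Hypothetical illustration; these names do not come from XLA or cuBLASLt.
struct GemmPointers {
  const void* c;  // C operand handed to the matmul call
  void* d;        // D (output) operand handed to the matmul call
};

// Chooses the C and D pointers for a GEMM of the form
// D = alpha * (A * B) + beta * C.
// Assumption (from the description above): when beta == 0, C is never read,
// so a dummy nullptr can be passed instead of aliasing C to D.
inline GemmPointers SelectGemmPointers(float beta, const void* c_buffer,
                                       void* d_buffer) {
  if (beta == 0.0f) {
    // C is unused: avoid aliasing it to D so the same-pointer descriptor
    // check is not triggered.
    return {nullptr, d_buffer};
  }
  // beta != 0: C is read. An in-place GEMM may legitimately pass
  // c_buffer == d_buffer here.
  return {c_buffer, d_buffer};
}
```

With beta = 0 the caller ends up with a null C and the original D. With a nonzero beta the in-place aliasing is preserved, which is the case the description above says is fine.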

🚀 Kind of Contribution
🐛 Bug Fix

📊 Benchmark (for Performance Improvements)
None

🧪 Unit Tests:
These tests now pass with CUDA 13:

DotTests/ParametricDotTest.TestF8E4M3FN/270x270x520_MajorToMinorFT
DotTests/ParametricDotTest.TestF8E4M3FN/12x117x7_MajorToMinorFT
DotTests/ParametricDotTest.TestF8E4M3FN/260x3x520_MajorToMinorTT
DotTests/ParametricDotTest.TestF8E4M3FN/12x117x7_MajorToMinorTT
DotTests/ParametricDotTest.TestF8E4M3FN/260x3x520_MajorToMinorTF
DotTests/ParametricDotTest.TestF8E4M3FN/260x3x520_MajorToMinorFT
DotTests/ParametricDotTest.TestF8E4M3FN/12x117x7_MajorToMinorTF
DotTests/ParametricDotTest.TestF8E4M3FN/12x117x7_MajorToMinorFF
DotTests/ParametricDotTest.TestF8E4M3FN/270x270x520_MajorToMinorTF
DotTests/ParametricDotTest.TestF8E4M3FN/260x3x520_MajorToMinorFF
DotTests/ParametricDotTest.TestF8E4M3FN/270x270x520_MajorToMinorFF
DotTests/ParametricDotTest.TestF8E4M3FN/270x270x520_MajorToMinorTT
F8E5M2Tests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_with_lhs_f8e5m2_rhs_f8e4m3fn_output_f8e4m3fn_from_cc_8_9_rocm_63_no_restriction_c_16_nc_2
F8E4M3FNTests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_fast_accum_with_lhs_f8e4m3fn_rhs_f8e4m3fn_output_f8e4m3fn_from_cc_8_9_rocm_63_no_restriction_c_16_nc_2
F8E5M2Tests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_with_lhs_f8e5m2_rhs_f8e4m3fn_output_f8e4m3fn_from_cc_8_9_rocm_63_no_restriction_c_32_nc_32
F8E4M3FNTests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_with_lhs_f8e4m3fn_rhs_f8e4m3fn_output_f8e4m3fn_from_cc_8_9_rocm_63_no_restriction_c_32_nc_32
F8E5M2Tests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_fast_accum_with_lhs_f8e5m2_rhs_f8e4m3fn_output_f8e4m3fn_from_cc_8_9_rocm_63_no_restriction_c_16_nc_2
F8E5M2Tests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_with_lhs_f8e5m2_rhs_f8e4m3fn_output_f8e5m2_from_cc_8_9_rocm_63_no_restriction_c_16_nc_2
F8E4M3FNTests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_with_lhs_f8e4m3fn_rhs_f8e4m3fn_output_f8e4m3fn_from_cc_8_9_rocm_63_no_restriction_c_16_nc_2
F8E4M3FNTests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_fast_accum_with_lhs_f8e4m3fn_rhs_f8e4m3fn_output_f8e4m3fn_from_cc_8_9_rocm_63_no_restriction_c_32_nc_32
F8E5M2Tests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_with_lhs_f8e5m2_rhs_f8e4m3fn_output_f8e5m2_from_cc_8_9_rocm_63_no_restriction_c_32_nc_32
F8E5M2Tests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_fast_accum_with_lhs_f8e5m2_rhs_f8e4m3fn_output_f8e5m2_from_cc_8_9_rocm_63_no_restriction_c_16_nc_2
F8E5M2Tests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_fast_accum_with_lhs_f8e5m2_rhs_f8e4m3fn_output_f8e5m2_from_cc_8_9_rocm_63_no_restriction_c_32_nc_32
F8E5M2Tests/DotAlgorithmSupportTest.AlgorithmIsSupportedFromCudaCapability/dot_any_f8_any_f8_f32_fast_accum_with_lhs_f8e5m2_rhs_f8e4m3fn_output_f8e4m3fn_from_cc_8_9_rocm_63_no_restriction_c_32_nc_32

🧪 Execution Tests:
None

@Cjkkkk Cjkkkk requested a review from beckerhe December 5, 2025 20:54
Contributor Author

Cjkkkk commented Dec 5, 2025

Hi @beckerhe, could you take a look at this? It affects many FP8 tests on our CUDA 13.1 release.
