feat: enable deepgemm jit for fp8 block-scale on SM90 #1969
Conversation
Signed-off-by: Duncan Moss <[email protected]>
> **Note:** CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

**Walkthrough**

The pull request modifies the JIT compilation configuration for the TensorRT LLM deep GEMM module. The include directory discovery mechanism now uses the flashinfer-python package instead of tensorrt_llm, with updated path resolution. Additionally, the NVCC usage logic is simplified to always use NVCC without reading an environment variable.
**Summary of Changes**

Hello @djmmoss, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates JIT compilation capabilities for FP8 DeepGEMM kernels into the FlashInfer project. The primary goal is to enhance performance by allowing dynamic compilation of these kernels. This is achieved by ensuring that NVCC is always used for JIT compilation and by correctly configuring the include paths to resolve dependencies.
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
Code Review
This pull request enables JIT compilation for FP8 DeepGEMM kernels, defaulting to NVCC. The changes correctly update package paths from tensorrt_llm to flashinfer-python to locate necessary headers. My review focuses on improving code clarity and maintainability. I've suggested updating a log message to reflect the package change and cleaning up commented-out code while preserving important context in a TODO comment.
```diff
 static bool kJitUseNvcc = []() {
-  char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC");
-  return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true");
+  // char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC");
+  // return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true");
+  // always use nvcc
+  // TODO: Enable nvrtc -- need these headers:
+  // [TensorRT-LLM][INFO] Compilation log:
+  // kernel.cu(16): catastrophic error: cannot open source file "cuda_bf16.h"
+  return true;
 }();
```
The kJitUseNvcc variable is now hardcoded to true, and the previous implementation that used an environment variable is commented out. While the TODO comment explains why NVRTC is currently disabled, leaving commented-out code can reduce readability and maintainability. It would be cleaner to remove the commented-out lines and refine the comments to keep only the essential context.
```cpp
static bool kJitUseNvcc = []() {
  // NVRTC is currently disabled due to include issues. Using NVCC by default.
  // TODO: Enable NVRTC. It fails with errors like:
  // [TensorRT-LLM][INFO] Compilation log:
  // kernel.cu(16): catastrophic error: cannot open source file "cuda_bf16.h"
  return true;
}();
```
Signed-off-by: Duncan Moss <[email protected]>
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh (2)

**154-179: Update the outdated comment to reflect the new package name.**

The comment at line 155 still references `tensorrt_llm`, but the code now uses `flashinfer-python`. Update the comment for consistency. Apply this diff to update the comment:

```diff
 // Parse the location using regex
-// `pip show tensorrt_llm` will output something like:
+// `pip show flashinfer-python` will output something like:
 // Location: /usr/local/lib/python3.12/dist-packages
 // Editable project location: /code
```
**180-182: Update the error message to reflect the new package dependency.**

The error message references "TensorRT LLM installation" but the code now searches for `flashinfer-python`. Update the message for consistency. Apply this diff to update the error message:

```diff
 } else {
-  TLLM_LOG_WARNING("Failed to find TensorRT LLM installation, DeepGEMM will be disabled.");
+  TLLM_LOG_WARNING("Failed to find flashinfer-python installation, DeepGEMM will be disabled.");
 }
```
📜 Review details
📒 Files selected for processing (2)
- csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh (2 hunks)
- csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh (1 hunks)
🔇 Additional comments (1)
csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh (1)
**38-46: LGTM! Appropriate simplification to force NVCC usage.**

The hardcoded `return true` with commented-out environment variable logic is appropriate given the TODO note about missing headers for NVRTC. This ensures consistent behavior until NVRTC support is fully enabled.
```diff
 if (includeDirs.empty()) {
   // Command to execute
-  char const* cmd = "pip show tensorrt_llm 2>/dev/null";
+  char const* cmd = "pip show flashinfer-python 2>/dev/null";
```
What's the purpose of this command?
For the DeepGEMM JIT, the compiler needs the header files in deep_gemm/; this command finds the installation path, which is then used further down to add deep_gemm/ to the -I include list.
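As a concrete sketch of that lookup (illustrative only — the function name and exact regex here are assumptions, not the code in this PR): the `pip show` output is scanned for its location lines, and the resulting prefix is what later gets appended to the `-I` list.

```cpp
#include <regex>
#include <string>

// Illustrative sketch: extract the installation prefix from the stdout of
// `pip show flashinfer-python`. Function name is hypothetical; the actual
// compiler.cuh logic may differ in detail.
static std::string packageLocation(std::string const& pipShowOutput) {
  std::smatch m;
  // Prefer an editable (development) install if one is reported.
  if (std::regex_search(pipShowOutput, m,
                        std::regex("Editable project location: (\\S+)"))) {
    return m[1].str();
  }
  if (std::regex_search(pipShowOutput, m, std::regex("Location: (\\S+)"))) {
    return m[1].str();
  }
  return "";  // empty result: caller would disable DeepGEMM
}
```

Note that `pip show` prints `Location:` for a normal install and additionally `Editable project location:` for an editable one, which is why the editable case is checked first.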
I tend to move the logic to Python; `pip show flashinfer-python` doesn't necessarily show the correct package information (e.g. at AOT time, when the package is not installed yet).
Or we can obtain the include path from python and pass the value to C++.
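One minimal way to do that hand-off (a sketch only — the variable name `FLASHINFER_INCLUDE_DIR` is an assumption, not an existing interface): the Python side resolves the package directory once and exports it, and the C++ side prefers that value before falling back to shelling out to pip.

```cpp
#include <cstdlib>
#include <string>
#include <vector>

// Sketch: collect -I directories, preferring a path handed over from Python.
// FLASHINFER_INCLUDE_DIR is a hypothetical variable name for illustration.
static std::vector<std::string> collectIncludeDirs() {
  std::vector<std::string> dirs;
  if (char const* fromPython = std::getenv("FLASHINFER_INCLUDE_DIR")) {
    dirs.emplace_back(fromPython);  // value resolved by the Python launcher
  }
  // else: fall back to the existing `pip show` discovery path
  return dirs;
}
```

This keeps the path resolution in Python (where `importlib` can locate the package reliably) while the C++ JIT layer stays a passive consumer of the value.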
I think this is where a refactor might be necessary, unfortunately these deep_gemm kernels aren't captured as part of AOT.
```diff
   }
 } else {
-  TLLM_LOG_WARNING("Failed to find TensorRT LLM installation, DeepGEMM will be disabled.");
+  TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled.");
```
I guess we can safely assume flashinfer is installed if this function is called?
📌 Description
Enable JIT compilation for the FP8 DeepGEMM kernels. NVRTC is currently disabled; NVCC is used by default.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] Installed the hooks with `pre-commit install`.
- [x] Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Added or updated tests as needed (unittest, etc.).