feat: enable deepgemm jit for fp8 block-scale on SM90 #1969
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -125,7 +125,7 @@ std::vector<std::filesystem::path> getJitIncludeDirs() { | |
| static std::vector<std::filesystem::path> includeDirs; | ||
| if (includeDirs.empty()) { | ||
| // Command to execute | ||
| - char const* cmd = "pip show tensorrt_llm 2>/dev/null"; | ||
| + char const* cmd = "pip show flashinfer-python 2>/dev/null"; | ||
|
What's the purpose of this command?
The DeepGEMM JIT needs the header files in deep_gemm/. This command finds the installation path, which is then used further down to add deep_gemm/ to the -I include flags.
I would prefer to move this logic to Python.
Or we can obtain the include path from Python and pass the value to C++.
I think this is where a refactor might be necessary. Unfortunately, these deep_gemm kernels aren't captured as part of AOT.
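To make the thread above concrete, here is a minimal sketch of the lookup being discussed: take the text printed by `pip show flashinfer-python`, pull out its `Location:` field, and derive the include directory that the diff pushes. The helper names `parsePipShowLocation` and `deepGemmIncludeDir` are hypothetical and do not appear in the PR; the actual code runs the command and parses the output inline.

```cpp
#include <filesystem>
#include <optional>
#include <sstream>
#include <string>

// Hypothetical helper (not from the PR): extract the "Location:" field from
// the text printed by `pip show <package>`.
std::optional<std::string> parsePipShowLocation(std::string const& pipOutput) {
  std::string const key = "Location: ";
  std::istringstream stream(pipOutput);
  std::string line;
  while (std::getline(stream, line)) {
    if (line.compare(0, key.size(), key) != 0) {
      continue;  // Not the Location line.
    }
    std::string location = line.substr(key.size());
    // Trim trailing whitespace, mirroring the erase() call in the diff.
    auto const last = location.find_last_not_of(" \n\r\t");
    if (last == std::string::npos) {
      return std::nullopt;  // Field present but empty.
    }
    location.erase(last + 1);
    return location;
  }
  return std::nullopt;  // Package not installed or output not recognized.
}

// Hypothetical helper: build the include directory added by the new code.
std::filesystem::path deepGemmIncludeDir(std::string const& location) {
  return std::filesystem::path(location) / "flashinfer" / "data" / "csrc" /
         "nv_internal" / "tensorrt_llm";
}
```

Keeping the parsing separate from the process launch would let the path logic be tested without pip being installed, which may also ease a later move of the lookup into Python.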
||
|
|
||
| // Buffer to store the output | ||
| std::array<char, 128> buffer; | ||
|
|
@@ -174,15 +174,11 @@ std::vector<std::filesystem::path> getJitIncludeDirs() { | |
| location.erase(location.find_last_not_of(" \n\r\t") + 1); | ||
|
|
||
| // Set the include directory based on the package location | ||
| - includeDirs.push_back(std::filesystem::path(location) / "tensorrt_llm" / "include"); | ||
| - | ||
| - if (!kJitUseNvcc) { | ||
| - includeDirs.push_back(std::filesystem::path(location) / "tensorrt_llm" / "include" / | ||
| - "cuda" / "include"); | ||
| - } | ||
| + includeDirs.push_back(std::filesystem::path(location) / "flashinfer" / "data" / "csrc" / | ||
| + "nv_internal" / "tensorrt_llm"); | ||
| } | ||
| } else { | ||
| - TLLM_LOG_WARNING("Failed to find TensorRT LLM installation, DeepGEMM will be disabled."); | ||
| + TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled."); | ||
|
I guess we can safely assume FlashInfer is installed if this function is called?
||
| } | ||
| } | ||
| return includeDirs; | ||
|
|
||
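One reviewer suggests obtaining the include path in Python and passing it to C++. A rough sketch of the C++ side of that idea follows, assuming a small registry that a binding could fill before the first JIT compilation. `IncludeDirRegistry` and its methods are hypothetical names used only for illustration; nothing like it exists in the PR.

```cpp
#include <filesystem>
#include <utility>
#include <vector>

// Hypothetical registry (not from the PR): stores include directories that
// the caller supplies, instead of discovering them by running `pip show`.
class IncludeDirRegistry {
 public:
  // Replace the stored directories with the caller-provided list.
  void set(std::vector<std::filesystem::path> dirs) { mDirs = std::move(dirs); }

  // True once at least one directory has been registered.
  bool hasDirs() const { return !mDirs.empty(); }

  // Return a copy of the registered directories.
  std::vector<std::filesystem::path> get() const { return mDirs; }

 private:
  std::vector<std::filesystem::path> mDirs;
};
```

With something along these lines, `getJitIncludeDirs()` could return the registered directories when `hasDirs()` is true and fall back to the `pip show` lookup otherwise. This sketch is not thread-safe; a real version would likely need synchronization.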
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,8 +36,13 @@ static bool kJitDebugging = []() { | |
| }(); | ||
|
|
||
| static bool kJitUseNvcc = []() { | ||
| - char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC"); | ||
| - return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true"); | ||
| + // char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC"); | ||
| + // return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true"); | ||
| + // always use nvcc | ||
| + // TODO: Enable nvrtc -- need these headers: | ||
| + // [TensorRT-LLM][INFO] Compilation log: | ||
| + // kernel.cu(16): catastrophic error: cannot open source file "cuda_bf16.h" | ||
| + return true; | ||
| }(); | ||
|
Comment on lines 38 to 46
The
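The change above comments out the environment lookup and hard-codes `true`. As a possible alternative, the sketch below keeps the `TRTLLM_DG_JIT_USE_NVCC` variable readable while defaulting to nvcc when it is unset. The helper `envFlagOr` is a hypothetical name, not part of the PR, and it accepts the same two truthy spellings ("1" and "true") as the original lambda.

```cpp
#include <cstdlib>
#include <string>

// Hypothetical helper (not from the PR): read a boolean environment flag,
// falling back to `defaultValue` when the variable is not set.
bool envFlagOr(char const* name, bool defaultValue) {
  char const* value = std::getenv(name);
  if (value == nullptr) {
    return defaultValue;  // Unset: use the caller's default.
  }
  std::string const text(value);
  // Same truthy spellings as the original lambda; anything else is false.
  return text == "1" || text == "true";
}
```

Under that assumption, the initializer could read `static bool kJitUseNvcc = envFlagOr("TRTLLM_DG_JIT_USE_NVCC", true);`. Setting the variable to any other value would then select the nvrtc path, which the TODO in the diff says still fails on missing headers.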
||
|
|
||
| static bool kJitDumpCubin = []() { | ||
|
|
||