diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh
index 3b7ed113e7..9222bf19d2 100644
--- a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh
+++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh
@@ -125,7 +125,7 @@ std::vector getJitIncludeDirs() {
   static std::vector includeDirs;
   if (includeDirs.empty()) {
     // Command to execute
-    char const* cmd = "pip show tensorrt_llm 2>/dev/null";
+    char const* cmd = "pip show flashinfer-python 2>/dev/null";
 
     // Buffer to store the output
     std::array buffer;
@@ -174,15 +174,11 @@ std::vector getJitIncludeDirs() {
       location.erase(location.find_last_not_of(" \n\r\t") + 1);
 
       // Set the include directory based on the package location
-      includeDirs.push_back(std::filesystem::path(location) / "tensorrt_llm" / "include");
-
-      if (!kJitUseNvcc) {
-        includeDirs.push_back(std::filesystem::path(location) / "tensorrt_llm" / "include" /
-                              "cuda" / "include");
-      }
+      includeDirs.push_back(std::filesystem::path(location) / "flashinfer" / "data" / "csrc" /
+                            "nv_internal" / "tensorrt_llm");
     }
   } else {
-    TLLM_LOG_WARNING("Failed to find TensorRT LLM installation, DeepGEMM will be disabled.");
+    TLLM_LOG_WARNING("Failed to find FlashInfer installation, DeepGEMM will be disabled.");
   }
 }
 return includeDirs;
diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh
index 35af1fcd23..f4e6ab124e 100644
--- a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh
+++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh
@@ -36,8 +36,13 @@ static bool kJitDebugging = []() {
 }();
 
 static bool kJitUseNvcc = []() {
-  char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC");
-  return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true");
+  // char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC");
+  // return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true");
+  // always use nvcc
+  // TODO: Enable nvrtc -- need these headers:
+  // [TensorRT-LLM][INFO] Compilation log:
+  // kernel.cu(16): catastrophic error: cannot open source file "cuda_bf16.h"
+  return true;
 }();
 
 static bool kJitDumpCubin = []() {